diff --git a/src/proxy.go b/src/proxy.go index d53805b..caddee4 100644 --- a/src/proxy.go +++ b/src/proxy.go @@ -9,6 +9,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "regexp" "strings" "time" @@ -32,7 +33,7 @@ func fifo(name string) (string, error) { return output, nil } -func runProxy(commandPrefix string, cmdBuilder func(temp string) *exec.Cmd, opts *Options, withExports bool) (int, error) { +func runProxy(commandPrefix string, cmdBuilder func(temp string, needBash bool) (*exec.Cmd, error), opts *Options, withExports bool) (int, error) { output, err := fifo("proxy-output") if err != nil { return ExitError, err @@ -92,17 +93,28 @@ func runProxy(commandPrefix string, cmdBuilder func(temp string) *exec.Cmd, opts // To ensure that the options are processed by a POSIX-compliant shell, // we need to write the command to a temporary file and execute it with sh. var exports []string + needBash := false if withExports { - exports = os.Environ() - for idx, pairStr := range exports { + validIdentifier := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + for _, pairStr := range os.Environ() { pair := strings.SplitN(pairStr, "=", 2) - exports[idx] = fmt.Sprintf("export %s=%s", pair[0], escapeSingleQuote(pair[1])) + if validIdentifier.MatchString(pair[0]) { + exports = append(exports, fmt.Sprintf("export %s=%s", pair[0], escapeSingleQuote(pair[1]))) + } else if strings.HasPrefix(pair[0], "BASH_FUNC_") && strings.HasSuffix(pair[0], "%%") { + name := pair[0][10 : len(pair[0])-2] + exports = append(exports, name+pair[1]) + exports = append(exports, "export -f "+name) + needBash = true + } } } temp := WriteTemporaryFile(append(exports, command), "\n") defer os.Remove(temp) - cmd := cmdBuilder(temp) + cmd, err := cmdBuilder(temp, needBash) + if err != nil { + return ExitError, err + } cmd.Stderr = os.Stderr intChan := make(chan os.Signal, 1) defer close(intChan) diff --git a/src/proxy_unix.go b/src/proxy_unix.go index 189d0e5..b670053 100644 --- a/src/proxy_unix.go +++ b/src/proxy_unix.go @@ -9,7 +9,10 @@ import ( "golang.org/x/sys/unix" ) -func sh() (string, error) { +func sh(bash bool) (string, error) { + if bash { + return "bash", nil + } return "sh", nil } diff --git a/src/proxy_windows.go b/src/proxy_windows.go index a957da8..2aa61ab 100644 --- a/src/proxy_windows.go +++ b/src/proxy_windows.go @@ -13,12 +13,16 @@ import ( var shPath atomic.Value -func sh() (string, error) { +func sh(bash bool) (string, error) { if cached := shPath.Load(); cached != nil { return cached.(string), nil } - cmd := exec.Command("cygpath", "-w", "/usr/bin/sh") + name := "sh" + if bash { + name = "bash" + } + cmd := exec.Command("cygpath", "-w", "/usr/bin/"+name) bytes, err := cmd.Output() if err != nil { return "", err @@ -31,7 +35,7 @@ func sh() (string, error) { func mkfifo(path string, mode uint32) (string, error) { m := strconv.FormatUint(uint64(mode), 8) - sh, err := sh() + sh, err := sh(false) if err != nil { return path, err } @@ -43,7 +47,7 @@ func mkfifo(path string, mode uint32) (string, error) { } func withOutputPipe(output string, task func(io.ReadCloser)) error { - sh, err := sh() + sh, err := sh(false) if err != nil { return err } @@ -62,7 +66,7 @@ func withOutputPipe(output string, task func(io.ReadCloser)) error { } func withInputPipe(input string, task func(io.WriteCloser)) error { - sh, err := sh() + sh, err := sh(false) if err != nil { return err } diff --git a/src/tmux.go b/src/tmux.go index 246222f..b2315dc 100644 --- a/src/tmux.go +++ b/src/tmux.go @@ -49,9 +49,12 @@ func runTmux(args []string, opts *Options) (int, error) { tmuxArgs = append(tmuxArgs, "-w"+opts.Tmux.width.String()) tmuxArgs = append(tmuxArgs, "-h"+opts.Tmux.height.String()) - return runProxy(argStr, func(temp string) *exec.Cmd { - sh, _ := sh() + return runProxy(argStr, func(temp string, needBash bool) (*exec.Cmd, error) { + sh, err := sh(needBash) + if err != nil { + return nil, err + } tmuxArgs = append(tmuxArgs, sh, temp) - return exec.Command("tmux", tmuxArgs...) + return exec.Command("tmux", tmuxArgs...), nil }, opts, true) } diff --git a/src/winpty_windows.go b/src/winpty_windows.go index 78020d7..aba02ce 100644 --- a/src/winpty_windows.go +++ b/src/winpty_windows.go @@ -44,11 +44,6 @@ func needWinpty(opts *Options) bool { } func runWinpty(args []string, opts *Options) (int, error) { - sh, err := sh() - if err != nil { - return ExitError, err - } - argStr := escapeSingleQuote(args[0]) for _, arg := range args[1:] { argStr += " " + escapeSingleQuote(arg) @@ -56,20 +51,30 @@ func runWinpty(args []string, opts *Options) (int, error) { argStr += ` --no-winpty` if isMintty345() { - return runProxy(argStr, func(temp string) *exec.Cmd { + return runProxy(argStr, func(temp string, needBash bool) (*exec.Cmd, error) { + sh, err := sh(needBash) + if err != nil { + return nil, err + } + cmd := exec.Command(sh, temp) cmd.Env = append(os.Environ(), "MSYS=enable_pcon") cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - return cmd + return cmd, nil }, opts, false) } - return runProxy(argStr, func(temp string) *exec.Cmd { + return runProxy(argStr, func(temp string, needBash bool) (*exec.Cmd, error) { + sh, err := sh(needBash) + if err != nil { + return nil, err + } + cmd := exec.Command(sh, "-c", fmt.Sprintf(`winpty < /dev/tty > /dev/tty -- sh %q`, temp)) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - return cmd + return cmd, nil }, opts, false) }