package control import ( "bytes" "context" "fmt" "os" "os/exec" "path/filepath" "strconv" "strings" ) type SSHRunner struct{} type SSHExecutor interface { Run(ctx context.Context, node Node, script string) (*CommandResult, error) Check(ctx context.Context, node Node) (*CommandResult, error) CopyFile(ctx context.Context, node Node, localPath, remotePath string) error } type CommandResult struct { Stdout string Stderr string } func (r SSHRunner) Run(ctx context.Context, node Node, script string) (*CommandResult, error) { target := sshTarget(node) cmd, err := sshCommand(ctx, node, target, "sh -s") if err != nil { return &CommandResult{}, err } cmd.Stdin = strings.NewReader(script) var stdout bytes.Buffer var stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, fmt.Errorf("ssh %s: %w", target, err) } return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, nil } func (r SSHRunner) Check(ctx context.Context, node Node) (*CommandResult, error) { target := sshTarget(node) cmd, err := sshCommand(ctx, node, target, "printf ok") if err != nil { return &CommandResult{}, err } var stdout bytes.Buffer var stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, fmt.Errorf("ssh %s: %w", target, err) } return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, nil } func CopyFileOverSCP(ctx context.Context, node Node, localPath, remotePath string) error { target := fmt.Sprintf("%s:%s", sshTarget(node), remotePath) cmd, err := scpCommand(ctx, node, localPath, target) if err != nil { return err } output, err := cmd.CombinedOutput() if err != nil { return fmt.Errorf("scp %s -> %s: %w: %s", localPath, target, err, string(output)) } return nil } func (r SSHRunner) CopyFile(ctx context.Context, node Node, localPath, remotePath string) error { return CopyFileOverSCP(ctx, node, localPath, remotePath) } func CopyDirContentsOverSCP(ctx context.Context, node Node, localDir, remoteDir string) error { target := fmt.Sprintf("%s:%s", sshTarget(node), remoteDir) cmd, err := scpCommand(ctx, node, "-r", filepath.Clean(localDir)+"/.", target) if err != nil { return err } output, err := cmd.CombinedOutput() if err != nil { return fmt.Errorf("scp %s -> %s: %w: %s", localDir, target, err, string(output)) } return nil } func sshBaseArgs(node Node) []string { args := []string{ "-o", "StrictHostKeyChecking=accept-new", "-p", strconv.Itoa(defaultSSHPort(node.SSH.Port)), } if strings.TrimSpace(node.SSH.Auth) == "password" { args = append(args, "-o", "BatchMode=no") } else { args = append(args, "-o", "BatchMode=yes") } if strings.TrimSpace(node.SSH.IdentityFile) != "" { args = append(args, "-i", expandHome(node.SSH.IdentityFile)) } return args } func scpBaseArgs(node Node) []string { args := []string{ "-o", "StrictHostKeyChecking=accept-new", "-P", strconv.Itoa(defaultSSHPort(node.SSH.Port)), } if strings.TrimSpace(node.SSH.Auth) == "password" { args = append(args, "-o", "BatchMode=no") } else { args = append(args, "-o", "BatchMode=yes") } if strings.TrimSpace(node.SSH.IdentityFile) != "" { args = append(args, "-i", expandHome(node.SSH.IdentityFile)) } return args } func sshTarget(node Node) string { return fmt.Sprintf("%s@%s", node.SSH.User, node.Host) } func defaultSSHPort(port int) int { if port == 0 { return 22 } return port } func expandHome(path string) string { if path == "" || path[0] != '~' { return path } home, err := exec.Command("sh", "-lc", "printf %s \"$HOME\"").Output() if err != nil { return path } return filepath.Join(strings.TrimSpace(string(home)), strings.TrimPrefix(path, "~/")) } func sshCommand(ctx context.Context, node Node, extraArgs ...string) (*exec.Cmd, error) { args := sshBaseArgs(node) args = append(args, extraArgs...) return wrapWithPassword(ctx, node, "ssh", args...) } func scpCommand(ctx context.Context, node Node, extraArgs ...string) (*exec.Cmd, error) { args := scpBaseArgs(node) args = append(args, extraArgs...) return wrapWithPassword(ctx, node, "scp", args...) } func wrapWithPassword(ctx context.Context, node Node, command string, args ...string) (*exec.Cmd, error) { if strings.TrimSpace(node.SSH.Auth) != "password" { return exec.CommandContext(ctx, command, args...), nil } password := node.SSH.Password if password == "" { envName := strings.TrimSpace(node.SSH.PasswordEnv) if envName == "" { return nil, fmt.Errorf("ssh password auth for %s requires ssh.password_env", sshTarget(node)) } password = os.Getenv(envName) if password == "" { return nil, fmt.Errorf("ssh password env %s is empty", envName) } } wrappedArgs := append([]string{"-p", password, command}, args...) cmd := exec.CommandContext(ctx, "sshpass", wrappedArgs...) return cmd, nil }