package ssh

import (
	"database/sql"
	"encoding/base64"
	"fmt"
	"io"
	"log/slog"
	"net"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/crypto/ssh"
)

// OutputHandler receives every chunk of data read from an SSH session's
// stdout, tagged with the ID of the session that produced it.
// In production this will emit Wails events; for testing, a simple callback.
type OutputHandler func(sessionID string, data []byte)

// CWDHandler is invoked when the CWD tracker reports a directory change
// (detected via OSC 7) for the given session.
type CWDHandler func(sessionID string, path string)

// SSHSession bundles an active SSH connection with its PTY shell session.
type SSHSession struct {
	ID         string         // unique session identifier (a UUID string assigned by Connect)
	Client     *ssh.Client    // underlying SSH connection
	Session    *ssh.Session   // shell session opened on Client
	Stdin      io.WriteCloser // write end of the remote shell's stdin
	ConnID     int64          // NOTE(review): not set by Connect in this file; presumably a saved-connection ID — confirm
	Hostname   string         // host passed to Connect
	Port       int            // port passed to Connect
	Username   string         // user passed to Connect
	Connected  time.Time      // time the session was created
	CWDTracker *CWDTracker    // processes stdout for CWD changes
	mu         sync.Mutex     // held by Write, Resize and Disconnect while touching Stdin/Session/Client
}

// SSHService owns the set of live SSH sessions, keyed by session ID.
type SSHService struct {
	sessions      map[string]*SSHSession // live sessions; guarded by mu
	mu            sync.RWMutex           // guards sessions
	db            *sql.DB                // stored by NewSSHService
	hostKeyStore  *HostKeyStore          // TOFU host key verification; nil means accept all keys
	outputHandler OutputHandler          // optional; receives session stdout
	cwdHandler    CWDHandler             // optional; receives CWD changes
}

// NewSSHService builds an SSHService with an empty session table.
//
// outputHandler is called with data arriving from a session's stdout and
// cwdHandler with directory changes detected via OSC 7. hostKeyStore is used
// to verify SSH host keys (TOFU model). Pass nil for any handler if not needed.
func NewSSHService(db *sql.DB, hostKeyStore *HostKeyStore, outputHandler OutputHandler, cwdHandler CWDHandler) *SSHService {
	svc := new(SSHService)
	svc.sessions = make(map[string]*SSHSession)
	svc.db = db
	svc.hostKeyStore = hostKeyStore
	svc.outputHandler = outputHandler
	svc.cwdHandler = cwdHandler
	return svc
}

// Connect dials an SSH server, opens a session with a PTY and shell, and
// launches a goroutine to read stdout. Returns the session ID.
//
// Host key verification uses Trust On First Use (TOFU): new keys are stored
// automatically, matching keys are accepted silently, and CHANGED keys reject
// the connection to protect against MITM attacks.
func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) { addr := fmt.Sprintf("%s:%d", hostname, port) // Build host key callback — TOFU model via HostKeyStore hostKeyCallback := s.buildHostKeyCallback(hostname, port) config := &ssh.ClientConfig{ User: username, Auth: authMethods, HostKeyCallback: hostKeyCallback, Timeout: 15 * time.Second, } client, err := ssh.Dial("tcp", addr, config) if err != nil { return "", fmt.Errorf("ssh dial %s: %w", addr, err) } session, err := client.NewSession() if err != nil { client.Close() return "", fmt.Errorf("new session: %w", err) } modes := ssh.TerminalModes{ ssh.ECHO: 1, ssh.TTY_OP_ISPEED: 115200, ssh.TTY_OP_OSPEED: 115200, } if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { session.Close() client.Close() return "", fmt.Errorf("request pty: %w", err) } stdin, err := session.StdinPipe() if err != nil { session.Close() client.Close() return "", fmt.Errorf("stdin pipe: %w", err) } stdout, err := session.StdoutPipe() if err != nil { session.Close() client.Close() return "", fmt.Errorf("stdout pipe: %w", err) } if err := session.Shell(); err != nil { session.Close() client.Close() return "", fmt.Errorf("start shell: %w", err) } sessionID := uuid.NewString() sshSession := &SSHSession{ ID: sessionID, Client: client, Session: session, Stdin: stdin, Hostname: hostname, Port: port, Username: username, Connected: time.Now(), CWDTracker: NewCWDTracker(), } s.mu.Lock() s.sessions[sessionID] = sshSession s.mu.Unlock() // Launch goroutine to read stdout and forward data via the output handler go s.readLoop(sessionID, stdout) // Inject shell integration for CWD tracking (OSC 7). // Uses stty -echo to suppress the command from appearing in the terminal, // then restores echo. A leading space keeps it out of shell history. 
go func() { time.Sleep(500 * time.Millisecond) // Suppress echo, set PROMPT_COMMAND for bash (zsh uses precmd), // restore echo, then clear the current line so no visual artifact remains. injection := " stty -echo 2>/dev/null; " + ShellIntegrationCommand("bash") + "; " + "if [ -n \"$ZSH_VERSION\" ]; then " + ShellIntegrationCommand("zsh") + "; fi; " + "stty echo 2>/dev/null\n" sshSession.mu.Lock() if sshSession.Stdin != nil { _, _ = sshSession.Stdin.Write([]byte(injection)) } sshSession.mu.Unlock() }() return sessionID, nil } // buildHostKeyCallback returns an ssh.HostKeyCallback that implements TOFU // (Trust On First Use) via the HostKeyStore. If no store is configured, it // falls back to accepting all keys (for testing/development only). func (s *SSHService) buildHostKeyCallback(hostname string, port int) ssh.HostKeyCallback { if s.hostKeyStore == nil { slog.Warn("no host key store configured — accepting all host keys (INSECURE)") return ssh.InsecureIgnoreHostKey() } store := s.hostKeyStore return func(remoteHostname string, addr net.Addr, key ssh.PublicKey) error { keyType := key.Type() fingerprint := ssh.FingerprintSHA256(key) result, err := store.Verify(hostname, port, keyType, fingerprint) if err != nil { return fmt.Errorf("host key verification failed: %w", err) } switch result { case HostKeyMatch: slog.Debug("host key verified", "host", hostname, "port", port, "type", keyType) return nil case HostKeyNew: // TOFU: store the key and accept rawKey := base64.StdEncoding.EncodeToString(key.Marshal()) if storeErr := store.Store(hostname, port, keyType, fingerprint, rawKey); storeErr != nil { slog.Warn("failed to store new host key", "host", hostname, "error", storeErr) // Still accept the connection — storing failed but the key itself is fine } slog.Info("new host key stored (TOFU)", "host", hostname, "port", port, "type", keyType, "fingerprint", fingerprint) return nil case HostKeyChanged: // REJECT — possible MITM attack return fmt.Errorf( "HOST KEY CHANGED 
for %s:%d (type %s). Expected fingerprint does not match. "+ "This could indicate a man-in-the-middle attack. Connection refused. "+ "If the server was legitimately re-keyed, remove the old key and try again", hostname, port, keyType, ) default: return fmt.Errorf("unknown host key result: %d", result) } } } // readLoop continuously reads from the session stdout, processes CWD tracking // (stripping OSC 7 sequences), and calls the output handler with cleaned data. // It stops when the reader returns an error (typically EOF when the session closes). func (s *SSHService) readLoop(sessionID string, reader io.Reader) { // Grab the CWD tracker for this session (if any) s.mu.RLock() sess := s.sessions[sessionID] s.mu.RUnlock() buf := make([]byte, 32*1024) for { n, err := reader.Read(buf) if n > 0 { data := make([]byte, n) copy(data, buf[:n]) // Process CWD tracking — strips OSC 7 sequences from output if sess != nil && sess.CWDTracker != nil { cleaned, newCWD := sess.CWDTracker.ProcessOutput(data) data = cleaned // Emit CWD change event if a new path was detected if newCWD != "" && s.cwdHandler != nil { s.cwdHandler(sessionID, newCWD) } } if len(data) > 0 && s.outputHandler != nil { s.outputHandler(sessionID, data) } } if err != nil { break } } } // Write sends data to the session's stdin. func (s *SSHService) Write(sessionID string, data string) error { s.mu.RLock() sess, ok := s.sessions[sessionID] s.mu.RUnlock() if !ok { return fmt.Errorf("session %s not found", sessionID) } sess.mu.Lock() defer sess.mu.Unlock() if sess.Stdin == nil { return fmt.Errorf("session %s stdin is closed", sessionID) } _, err := sess.Stdin.Write([]byte(data)) if err != nil { return fmt.Errorf("write to session %s: %w", sessionID, err) } return nil } // Resize sends a window-change request to the remote PTY. 
func (s *SSHService) Resize(sessionID string, cols, rows int) error { s.mu.RLock() sess, ok := s.sessions[sessionID] s.mu.RUnlock() if !ok { return fmt.Errorf("session %s not found", sessionID) } sess.mu.Lock() defer sess.mu.Unlock() if sess.Session == nil { return fmt.Errorf("session %s is closed", sessionID) } if err := sess.Session.WindowChange(rows, cols); err != nil { return fmt.Errorf("resize session %s: %w", sessionID, err) } return nil } // Disconnect closes the SSH session and client, and removes it from tracking. func (s *SSHService) Disconnect(sessionID string) error { s.mu.Lock() sess, ok := s.sessions[sessionID] if !ok { s.mu.Unlock() return fmt.Errorf("session %s not found", sessionID) } delete(s.sessions, sessionID) s.mu.Unlock() sess.mu.Lock() defer sess.mu.Unlock() if sess.Stdin != nil { sess.Stdin.Close() } if sess.Session != nil { sess.Session.Close() } if sess.Client != nil { sess.Client.Close() } return nil } // GetSession returns the SSHSession for the given ID, or false if not found. func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) { s.mu.RLock() defer s.mu.RUnlock() sess, ok := s.sessions[sessionID] return sess, ok } // ListSessions returns all active SSH sessions. func (s *SSHService) ListSessions() []*SSHSession { s.mu.RLock() defer s.mu.RUnlock() list := make([]*SSHSession, 0, len(s.sessions)) for _, sess := range s.sessions { list = append(list, sess) } return list } // BuildPasswordAuth creates an ssh.AuthMethod for password authentication. func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod { return ssh.Password(password) } // BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key. // If the key is encrypted, pass the passphrase; otherwise pass an empty string. 
func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) { var signer ssh.Signer var err error if passphrase != "" { signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase)) } else { signer, err = ssh.ParsePrivateKey(privateKey) } if err != nil { return nil, fmt.Errorf("parse private key: %w", err) } return ssh.PublicKeys(signer), nil }