From c48c0de042c44ba2f6ac872be9f990caad6fe2e8 Mon Sep 17 00:00:00 2001 From: Vantz Stockwell Date: Tue, 17 Mar 2026 06:51:30 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20SSH=20service=20=E2=80=94=20connect,=20?= =?UTF-8?q?PTY,=20shell=20I/O=20with=20goroutine=20pipes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/ssh/service.go | 259 +++++++++++++++++++++++++++++++++++ internal/ssh/service_test.go | 148 ++++++++++++++++++++ 2 files changed, 407 insertions(+) create mode 100644 internal/ssh/service.go create mode 100644 internal/ssh/service_test.go diff --git a/internal/ssh/service.go b/internal/ssh/service.go new file mode 100644 index 0000000..be88557 --- /dev/null +++ b/internal/ssh/service.go @@ -0,0 +1,259 @@ +package ssh + +import ( + "database/sql" + "fmt" + "io" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/crypto/ssh" +) + +// OutputHandler is called when data is read from an SSH session's stdout. +// In production this will emit Wails events; for testing, a simple callback. +type OutputHandler func(sessionID string, data []byte) + +// SSHSession represents an active SSH connection with its PTY shell session. +type SSHSession struct { + ID string + Client *ssh.Client + Session *ssh.Session + Stdin io.WriteCloser + ConnID int64 + Hostname string + Port int + Username string + Connected time.Time + mu sync.Mutex +} + +// SSHService manages SSH connections and their associated sessions. +type SSHService struct { + sessions map[string]*SSHSession + mu sync.RWMutex + db *sql.DB + outputHandler OutputHandler +} + +// NewSSHService creates a new SSHService. The outputHandler is called when data +// arrives from a session's stdout. Pass nil if output handling is not needed. +func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService { + return &SSHService{ + sessions: make(map[string]*SSHSession), + db: db, + outputHandler: outputHandler, + } +} + +// Connect dials an SSH server, opens a session with a PTY and shell, and +// launches a goroutine to read stdout. Returns the session ID. +func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) { + addr := fmt.Sprintf("%s:%d", hostname, port) + + config := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 15 * time.Second, + } + + client, err := ssh.Dial("tcp", addr, config) + if err != nil { + return "", fmt.Errorf("ssh dial %s: %w", addr, err) + } + + session, err := client.NewSession() + if err != nil { + client.Close() + return "", fmt.Errorf("new session: %w", err) + } + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } + + if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("request pty: %w", err) + } + + stdin, err := session.StdinPipe() + if err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("stdin pipe: %w", err) + } + + stdout, err := session.StdoutPipe() + if err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("stdout pipe: %w", err) + } + + if err := session.Shell(); err != nil { + session.Close() + client.Close() + return "", fmt.Errorf("start shell: %w", err) + } + + sessionID := uuid.NewString() + sshSession := &SSHSession{ + ID: sessionID, + Client: client, + Session: session, + Stdin: stdin, + Hostname: hostname, + Port: port, + Username: username, + Connected: time.Now(), + } + + s.mu.Lock() + s.sessions[sessionID] = sshSession + s.mu.Unlock() + + // Launch goroutine to read stdout and forward data via the output handler + go s.readLoop(sessionID, stdout) + + return sessionID, nil +} + +// readLoop continuously reads from the session stdout and calls the output +// handler with data. It stops when the reader returns an error (typically EOF +// when the session closes). +func (s *SSHService) readLoop(sessionID string, reader io.Reader) { + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if n > 0 && s.outputHandler != nil { + data := make([]byte, n) + copy(data, buf[:n]) + s.outputHandler(sessionID, data) + } + if err != nil { + break + } + } +} + +// Write sends data to the session's stdin. +func (s *SSHService) Write(sessionID string, data string) error { + s.mu.RLock() + sess, ok := s.sessions[sessionID] + s.mu.RUnlock() + + if !ok { + return fmt.Errorf("session %s not found", sessionID) + } + + sess.mu.Lock() + defer sess.mu.Unlock() + + if sess.Stdin == nil { + return fmt.Errorf("session %s stdin is closed", sessionID) + } + + _, err := sess.Stdin.Write([]byte(data)) + if err != nil { + return fmt.Errorf("write to session %s: %w", sessionID, err) + } + return nil +} + +// Resize sends a window-change request to the remote PTY. +func (s *SSHService) Resize(sessionID string, cols, rows int) error { + s.mu.RLock() + sess, ok := s.sessions[sessionID] + s.mu.RUnlock() + + if !ok { + return fmt.Errorf("session %s not found", sessionID) + } + + sess.mu.Lock() + defer sess.mu.Unlock() + + if sess.Session == nil { + return fmt.Errorf("session %s is closed", sessionID) + } + + if err := sess.Session.WindowChange(rows, cols); err != nil { + return fmt.Errorf("resize session %s: %w", sessionID, err) + } + return nil +} + +// Disconnect closes the SSH session and client, and removes it from tracking. +func (s *SSHService) Disconnect(sessionID string) error { + s.mu.Lock() + sess, ok := s.sessions[sessionID] + if !ok { + s.mu.Unlock() + return fmt.Errorf("session %s not found", sessionID) + } + delete(s.sessions, sessionID) + s.mu.Unlock() + + sess.mu.Lock() + defer sess.mu.Unlock() + + if sess.Stdin != nil { + sess.Stdin.Close() + } + if sess.Session != nil { + sess.Session.Close() + } + if sess.Client != nil { + sess.Client.Close() + } + + return nil +} + +// GetSession returns the SSHSession for the given ID, or false if not found. +func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[sessionID] + return sess, ok +} + +// ListSessions returns all active SSH sessions. +func (s *SSHService) ListSessions() []*SSHSession { + s.mu.RLock() + defer s.mu.RUnlock() + list := make([]*SSHSession, 0, len(s.sessions)) + for _, sess := range s.sessions { + list = append(list, sess) + } + return list +} + +// BuildPasswordAuth creates an ssh.AuthMethod for password authentication. +func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod { + return ssh.Password(password) +} + +// BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key. +// If the key is encrypted, pass the passphrase; otherwise pass an empty string. +func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) { + var signer ssh.Signer + var err error + + if passphrase != "" { + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase)) + } else { + signer, err = ssh.ParsePrivateKey(privateKey) + } + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + return ssh.PublicKeys(signer), nil +} diff --git a/internal/ssh/service_test.go b/internal/ssh/service_test.go new file mode 100644 index 0000000..0a880b4 --- /dev/null +++ b/internal/ssh/service_test.go @@ -0,0 +1,148 @@ +package ssh + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func TestNewSSHService(t *testing.T) { + svc := NewSSHService(nil, nil) + if svc == nil { + t.Fatal("NewSSHService returned nil") + } + if len(svc.ListSessions()) != 0 { + t.Error("new service should have no sessions") + } +} + +func TestBuildPasswordAuth(t *testing.T) { + svc := NewSSHService(nil, nil) + auth := svc.BuildPasswordAuth("mypassword") + if auth == nil { + t.Error("BuildPasswordAuth returned nil") + } +} + +func TestBuildKeyAuth(t *testing.T) { + svc := NewSSHService(nil, nil) + + // Generate a test Ed25519 key + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("GenerateKey error: %v", err) + } + pemBlock, err := ssh.MarshalPrivateKey(priv, "") + if err != nil { + t.Fatalf("MarshalPrivateKey error: %v", err) + } + keyBytes := pem.EncodeToMemory(pemBlock) + + auth, err := svc.BuildKeyAuth(keyBytes, "") + if err != nil { + t.Fatalf("BuildKeyAuth error: %v", err) + } + if auth == nil { + t.Error("BuildKeyAuth returned nil") + } +} + +func TestBuildKeyAuthInvalidKey(t *testing.T) { + svc := NewSSHService(nil, nil) + _, err := svc.BuildKeyAuth([]byte("not a key"), "") + if err == nil { + t.Error("BuildKeyAuth should fail with invalid key") + } +} + +func TestSessionTracking(t *testing.T) { + svc := NewSSHService(nil, nil) + + // Manually add a session to test tracking + svc.mu.Lock() + svc.sessions["test-123"] = &SSHSession{ + ID: "test-123", + Hostname: "192.168.1.4", + Port: 22, + Username: "vstockwell", + Connected: time.Now(), + } + svc.mu.Unlock() + + s, ok := svc.GetSession("test-123") + if !ok { + t.Fatal("session not found") + } + if s.Hostname != "192.168.1.4" { + t.Errorf("Hostname = %q, want %q", s.Hostname, "192.168.1.4") + } + + sessions := svc.ListSessions() + if len(sessions) != 1 { + t.Errorf("ListSessions() = %d, want 1", len(sessions)) + } +} + +func TestGetSessionNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + _, ok := svc.GetSession("nonexistent") + if ok { + t.Error("GetSession should return false for nonexistent session") + } +} + +func TestWriteNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + err := svc.Write("nonexistent", "data") + if err == nil { + t.Error("Write should fail for nonexistent session") + } +} + +func TestResizeNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + err := svc.Resize("nonexistent", 80, 24) + if err == nil { + t.Error("Resize should fail for nonexistent session") + } +} + +func TestDisconnectNotFound(t *testing.T) { + svc := NewSSHService(nil, nil) + err := svc.Disconnect("nonexistent") + if err == nil { + t.Error("Disconnect should fail for nonexistent session") + } +} + +func TestDisconnectRemovesSession(t *testing.T) { + svc := NewSSHService(nil, nil) + + // Manually add a session with nil Client/Session/Stdin (no real connection) + svc.mu.Lock() + svc.sessions["test-dc"] = &SSHSession{ + ID: "test-dc", + Hostname: "10.0.0.1", + Port: 22, + Username: "admin", + Connected: time.Now(), + } + svc.mu.Unlock() + + if err := svc.Disconnect("test-dc"); err != nil { + t.Fatalf("Disconnect error: %v", err) + } + + _, ok := svc.GetSession("test-dc") + if ok { + t.Error("session should be removed after Disconnect") + } + + if len(svc.ListSessions()) != 0 { + t.Error("ListSessions should be empty after Disconnect") + } +}