U-1: Replace ssh.InsecureIgnoreHostKey() with TOFU (Trust On First Use) host
key verification via HostKeyStore (InsecureIgnoreHostKey remains only as a
fallback when no store is configured). New keys are stored automatically,
matching keys are accepted silently, and CHANGED keys are rejected with a MITM
warning. Added DeleteHostKey() for legitimate re-key scenarios.
U-2: Wire a CWDTracker into each SSH session. readLoop() now processes OSC 7
escape sequences, strips them from terminal output, and emits
ssh:cwd:{sessionID} Wails events on directory changes. Shell integration
commands (bash/zsh PROMPT_COMMAND) are injected after connection.
U-3: The session manager now tracks all SSH and RDP sessions via CreateWithID(),
which accepts the service-level UUID instead of generating a new one.
ConnectSSH, ConnectSSHWithPassword, and ConnectRDP register sessions;
DisconnectSession and RDPDisconnect remove them. The ConnectedAt timestamp is now set.
U-4: WorkspaceService is instantiated in New(), the clean-shutdown flag is managed on
startup/exit, and workspace state is auto-saved on every session open/close.
Frontend-facing proxy methods exposed: SaveWorkspace, LoadWorkspace,
MarkCleanShutdown, WasCleanShutdown, GetSessionCWD.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
364 lines
9.9 KiB
Go
364 lines
9.9 KiB
Go
package ssh
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// OutputHandler receives every chunk of terminal output read from an SSH
// session's stdout, tagged with the ID of the session that produced it.
// Production wiring is meant to forward chunks as Wails events; tests can
// supply any plain callback.
type OutputHandler func(id string, chunk []byte)
|
|
|
|
// CWDHandler is notified with the session ID and the new directory whenever
// the session's CWD tracker sees an OSC 7 working-directory report.
type CWDHandler func(id string, dir string)
|
|
|
|
// SSHSession represents an active SSH connection with its PTY shell session.
|
|
type SSHSession struct {
|
|
ID string
|
|
Client *ssh.Client
|
|
Session *ssh.Session
|
|
Stdin io.WriteCloser
|
|
ConnID int64
|
|
Hostname string
|
|
Port int
|
|
Username string
|
|
Connected time.Time
|
|
CWDTracker *CWDTracker
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// SSHService manages SSH connections and their associated sessions.
|
|
type SSHService struct {
|
|
sessions map[string]*SSHSession
|
|
mu sync.RWMutex
|
|
db *sql.DB
|
|
hostKeyStore *HostKeyStore
|
|
outputHandler OutputHandler
|
|
cwdHandler CWDHandler
|
|
}
|
|
|
|
// NewSSHService creates a new SSHService. The outputHandler is called when data
|
|
// arrives from a session's stdout. The cwdHandler is called when OSC 7 CWD
|
|
// changes are detected. The hostKeyStore verifies SSH host keys (TOFU model).
|
|
// Pass nil for any handler if not needed.
|
|
func NewSSHService(db *sql.DB, hostKeyStore *HostKeyStore, outputHandler OutputHandler, cwdHandler CWDHandler) *SSHService {
|
|
return &SSHService{
|
|
sessions: make(map[string]*SSHSession),
|
|
db: db,
|
|
hostKeyStore: hostKeyStore,
|
|
outputHandler: outputHandler,
|
|
cwdHandler: cwdHandler,
|
|
}
|
|
}
|
|
|
|
// Connect dials an SSH server, opens a session with a PTY and shell, and
|
|
// launches a goroutine to read stdout. Returns the session ID.
|
|
//
|
|
// Host key verification uses Trust On First Use (TOFU): new keys are stored
|
|
// automatically, matching keys are accepted silently, and CHANGED keys reject
|
|
// the connection to protect against MITM attacks.
|
|
func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) {
|
|
addr := fmt.Sprintf("%s:%d", hostname, port)
|
|
|
|
// Build host key callback — TOFU model via HostKeyStore
|
|
hostKeyCallback := s.buildHostKeyCallback(hostname, port)
|
|
|
|
config := &ssh.ClientConfig{
|
|
User: username,
|
|
Auth: authMethods,
|
|
HostKeyCallback: hostKeyCallback,
|
|
Timeout: 15 * time.Second,
|
|
}
|
|
|
|
client, err := ssh.Dial("tcp", addr, config)
|
|
if err != nil {
|
|
return "", fmt.Errorf("ssh dial %s: %w", addr, err)
|
|
}
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
client.Close()
|
|
return "", fmt.Errorf("new session: %w", err)
|
|
}
|
|
|
|
modes := ssh.TerminalModes{
|
|
ssh.ECHO: 1,
|
|
ssh.TTY_OP_ISPEED: 115200,
|
|
ssh.TTY_OP_OSPEED: 115200,
|
|
}
|
|
|
|
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
|
session.Close()
|
|
client.Close()
|
|
return "", fmt.Errorf("request pty: %w", err)
|
|
}
|
|
|
|
stdin, err := session.StdinPipe()
|
|
if err != nil {
|
|
session.Close()
|
|
client.Close()
|
|
return "", fmt.Errorf("stdin pipe: %w", err)
|
|
}
|
|
|
|
stdout, err := session.StdoutPipe()
|
|
if err != nil {
|
|
session.Close()
|
|
client.Close()
|
|
return "", fmt.Errorf("stdout pipe: %w", err)
|
|
}
|
|
|
|
if err := session.Shell(); err != nil {
|
|
session.Close()
|
|
client.Close()
|
|
return "", fmt.Errorf("start shell: %w", err)
|
|
}
|
|
|
|
sessionID := uuid.NewString()
|
|
sshSession := &SSHSession{
|
|
ID: sessionID,
|
|
Client: client,
|
|
Session: session,
|
|
Stdin: stdin,
|
|
Hostname: hostname,
|
|
Port: port,
|
|
Username: username,
|
|
Connected: time.Now(),
|
|
CWDTracker: NewCWDTracker(),
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.sessions[sessionID] = sshSession
|
|
s.mu.Unlock()
|
|
|
|
// Launch goroutine to read stdout and forward data via the output handler
|
|
go s.readLoop(sessionID, stdout)
|
|
|
|
// Inject shell integration for CWD tracking (OSC 7).
|
|
// Send both bash and zsh variants — the wrong shell ignores them harmlessly.
|
|
go func() {
|
|
time.Sleep(500 * time.Millisecond)
|
|
bashCmd := ShellIntegrationCommand("bash")
|
|
zshCmd := ShellIntegrationCommand("zsh")
|
|
if bashCmd != "" {
|
|
// Write commands silently; errors are non-fatal
|
|
_, _ = stdin.Write([]byte(bashCmd + "\n"))
|
|
}
|
|
if zshCmd != "" {
|
|
_, _ = stdin.Write([]byte(zshCmd + "\n"))
|
|
}
|
|
}()
|
|
|
|
return sessionID, nil
|
|
}
|
|
|
|
// buildHostKeyCallback returns an ssh.HostKeyCallback that implements TOFU
|
|
// (Trust On First Use) via the HostKeyStore. If no store is configured, it
|
|
// falls back to accepting all keys (for testing/development only).
|
|
func (s *SSHService) buildHostKeyCallback(hostname string, port int) ssh.HostKeyCallback {
|
|
if s.hostKeyStore == nil {
|
|
slog.Warn("no host key store configured — accepting all host keys (INSECURE)")
|
|
return ssh.InsecureIgnoreHostKey()
|
|
}
|
|
|
|
store := s.hostKeyStore
|
|
return func(remoteHostname string, addr net.Addr, key ssh.PublicKey) error {
|
|
keyType := key.Type()
|
|
fingerprint := ssh.FingerprintSHA256(key)
|
|
|
|
result, err := store.Verify(hostname, port, keyType, fingerprint)
|
|
if err != nil {
|
|
return fmt.Errorf("host key verification failed: %w", err)
|
|
}
|
|
|
|
switch result {
|
|
case HostKeyMatch:
|
|
slog.Debug("host key verified", "host", hostname, "port", port, "type", keyType)
|
|
return nil
|
|
|
|
case HostKeyNew:
|
|
// TOFU: store the key and accept
|
|
rawKey := base64.StdEncoding.EncodeToString(key.Marshal())
|
|
if storeErr := store.Store(hostname, port, keyType, fingerprint, rawKey); storeErr != nil {
|
|
slog.Warn("failed to store new host key", "host", hostname, "error", storeErr)
|
|
// Still accept the connection — storing failed but the key itself is fine
|
|
}
|
|
slog.Info("new host key stored (TOFU)", "host", hostname, "port", port, "type", keyType, "fingerprint", fingerprint)
|
|
return nil
|
|
|
|
case HostKeyChanged:
|
|
// REJECT — possible MITM attack
|
|
return fmt.Errorf(
|
|
"HOST KEY CHANGED for %s:%d (type %s). Expected fingerprint does not match. "+
|
|
"This could indicate a man-in-the-middle attack. Connection refused. "+
|
|
"If the server was legitimately re-keyed, remove the old key and try again",
|
|
hostname, port, keyType,
|
|
)
|
|
|
|
default:
|
|
return fmt.Errorf("unknown host key result: %d", result)
|
|
}
|
|
}
|
|
}
|
|
|
|
// readLoop continuously reads from the session stdout, processes CWD tracking
|
|
// (stripping OSC 7 sequences), and calls the output handler with cleaned data.
|
|
// It stops when the reader returns an error (typically EOF when the session closes).
|
|
func (s *SSHService) readLoop(sessionID string, reader io.Reader) {
|
|
// Grab the CWD tracker for this session (if any)
|
|
s.mu.RLock()
|
|
sess := s.sessions[sessionID]
|
|
s.mu.RUnlock()
|
|
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
n, err := reader.Read(buf)
|
|
if n > 0 {
|
|
data := make([]byte, n)
|
|
copy(data, buf[:n])
|
|
|
|
// Process CWD tracking — strips OSC 7 sequences from output
|
|
if sess != nil && sess.CWDTracker != nil {
|
|
cleaned, newCWD := sess.CWDTracker.ProcessOutput(data)
|
|
data = cleaned
|
|
|
|
// Emit CWD change event if a new path was detected
|
|
if newCWD != "" && s.cwdHandler != nil {
|
|
s.cwdHandler(sessionID, newCWD)
|
|
}
|
|
}
|
|
|
|
if len(data) > 0 && s.outputHandler != nil {
|
|
s.outputHandler(sessionID, data)
|
|
}
|
|
}
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Write sends data to the session's stdin.
|
|
func (s *SSHService) Write(sessionID string, data string) error {
|
|
s.mu.RLock()
|
|
sess, ok := s.sessions[sessionID]
|
|
s.mu.RUnlock()
|
|
|
|
if !ok {
|
|
return fmt.Errorf("session %s not found", sessionID)
|
|
}
|
|
|
|
sess.mu.Lock()
|
|
defer sess.mu.Unlock()
|
|
|
|
if sess.Stdin == nil {
|
|
return fmt.Errorf("session %s stdin is closed", sessionID)
|
|
}
|
|
|
|
_, err := sess.Stdin.Write([]byte(data))
|
|
if err != nil {
|
|
return fmt.Errorf("write to session %s: %w", sessionID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Resize sends a window-change request to the remote PTY.
|
|
func (s *SSHService) Resize(sessionID string, cols, rows int) error {
|
|
s.mu.RLock()
|
|
sess, ok := s.sessions[sessionID]
|
|
s.mu.RUnlock()
|
|
|
|
if !ok {
|
|
return fmt.Errorf("session %s not found", sessionID)
|
|
}
|
|
|
|
sess.mu.Lock()
|
|
defer sess.mu.Unlock()
|
|
|
|
if sess.Session == nil {
|
|
return fmt.Errorf("session %s is closed", sessionID)
|
|
}
|
|
|
|
if err := sess.Session.WindowChange(rows, cols); err != nil {
|
|
return fmt.Errorf("resize session %s: %w", sessionID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Disconnect closes the SSH session and client, and removes it from tracking.
|
|
func (s *SSHService) Disconnect(sessionID string) error {
|
|
s.mu.Lock()
|
|
sess, ok := s.sessions[sessionID]
|
|
if !ok {
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("session %s not found", sessionID)
|
|
}
|
|
delete(s.sessions, sessionID)
|
|
s.mu.Unlock()
|
|
|
|
sess.mu.Lock()
|
|
defer sess.mu.Unlock()
|
|
|
|
if sess.Stdin != nil {
|
|
sess.Stdin.Close()
|
|
}
|
|
if sess.Session != nil {
|
|
sess.Session.Close()
|
|
}
|
|
if sess.Client != nil {
|
|
sess.Client.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetSession returns the SSHSession for the given ID, or false if not found.
|
|
func (s *SSHService) GetSession(sessionID string) (*SSHSession, bool) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
sess, ok := s.sessions[sessionID]
|
|
return sess, ok
|
|
}
|
|
|
|
// ListSessions returns all active SSH sessions.
|
|
func (s *SSHService) ListSessions() []*SSHSession {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
list := make([]*SSHSession, 0, len(s.sessions))
|
|
for _, sess := range s.sessions {
|
|
list = append(list, sess)
|
|
}
|
|
return list
|
|
}
|
|
|
|
// BuildPasswordAuth creates an ssh.AuthMethod for password authentication.
|
|
func (s *SSHService) BuildPasswordAuth(password string) ssh.AuthMethod {
|
|
return ssh.Password(password)
|
|
}
|
|
|
|
// BuildKeyAuth creates an ssh.AuthMethod from a PEM-encoded private key.
|
|
// If the key is encrypted, pass the passphrase; otherwise pass an empty string.
|
|
func (s *SSHService) BuildKeyAuth(privateKey []byte, passphrase string) (ssh.AuthMethod, error) {
|
|
var signer ssh.Signer
|
|
var err error
|
|
|
|
if passphrase != "" {
|
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(passphrase))
|
|
} else {
|
|
signer, err = ssh.ParsePrivateKey(privateKey)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse private key: %w", err)
|
|
}
|
|
|
|
return ssh.PublicKeys(signer), nil
|
|
}
|