diff --git a/internal/app/app.go b/internal/app/app.go index 177ffb2..6aef899 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -363,12 +363,11 @@ func (a *WraithApp) ConnectSSH(connectionID int64, cols, rows int) (string, erro slog.Warn("failed to update last_connected", "error", err) } - // Register with session manager - if _, err := a.Sessions.Create(connectionID, "ssh"); err != nil { + // Register with session manager using the SSH session's own UUID + if _, err := a.Sessions.CreateWithID(sessionID, connectionID, "ssh"); err != nil { slog.Warn("failed to register SSH session in manager", "error", err) } else { - // Store the SSH session ID as the manager session ID for lookup - a.Sessions.SetState(sessionID, session.StateConnected) + _ = a.Sessions.SetState(sessionID, session.StateConnected) } // Save workspace state after session change @@ -416,6 +415,16 @@ func (a *WraithApp) ConnectSSHWithPassword(connectionID int64, username, passwor slog.Warn("failed to update last_connected", "error", err) } + // Register with session manager + if _, err := a.Sessions.CreateWithID(sessionID, connectionID, "ssh"); err != nil { + slog.Warn("failed to register SSH session in manager", "error", err) + } else { + _ = a.Sessions.SetState(sessionID, session.StateConnected) + } + + // Save workspace state after session change + a.saveWorkspaceState() + slog.Info("SSH session started (ad-hoc password)", "sessionID", sessionID, "host", conn.Hostname, "user", username) return sessionID, nil } @@ -425,9 +434,12 @@ func (a *WraithApp) GetVersion() string { return a.Updater.CurrentVersion() } -// DisconnectSession closes an active SSH session and its SFTP client. +// DisconnectSession closes an active SSH session and its SFTP client, +// and removes it from the session manager. func (a *WraithApp) DisconnectSession(sessionID string) error { a.SFTP.RemoveClient(sessionID) + a.Sessions.Remove(sessionID) + a.saveWorkspaceState() return a.SSH.Disconnect(sessionID) } @@ -480,6 +492,16 @@ func (a *WraithApp) ConnectRDP(connectionID int64, width, height int) (string, e slog.Warn("failed to update last_connected", "error", err) } + // Register with session manager + if _, err := a.Sessions.CreateWithID(sessionID, connectionID, "rdp"); err != nil { + slog.Warn("failed to register RDP session in manager", "error", err) + } else { + _ = a.Sessions.SetState(sessionID, session.StateConnected) + } + + // Save workspace state after session change + a.saveWorkspaceState() + slog.Info("RDP session started", "sessionID", sessionID, "host", conn.Hostname) return sessionID, nil } @@ -511,8 +533,10 @@ func (a *WraithApp) RDPSendClipboard(sessionID string, text string) error { return a.RDP.SendClipboard(sessionID, text) } -// RDPDisconnect tears down an RDP session. +// RDPDisconnect tears down an RDP session and removes it from the session manager. func (a *WraithApp) RDPDisconnect(sessionID string) error { + a.Sessions.Remove(sessionID) + a.saveWorkspaceState() return a.RDP.Disconnect(sessionID) } @@ -619,3 +643,80 @@ func (a *WraithApp) ImportMobaConf(fileContent string) (*plugin.ImportResult, er ) return result, nil } + +// ---------- Workspace proxy methods ---------- + +// saveWorkspaceState builds a workspace snapshot from the current session manager +// state and persists it. Called automatically on session open/close. +func (a *WraithApp) saveWorkspaceState() { + if a.Workspace == nil { + return + } + + sessions := a.Sessions.List() + tabs := make([]WorkspaceTab, 0, len(sessions)) + for _, s := range sessions { + tabs = append(tabs, WorkspaceTab{ + ConnectionID: s.ConnectionID, + Protocol: s.Protocol, + Position: s.TabPosition, + }) + } + + snapshot := &WorkspaceSnapshot{ + Tabs: tabs, + } + if err := a.Workspace.Save(snapshot); err != nil { + slog.Warn("failed to save workspace state", "error", err) + } +} + +// SaveWorkspace explicitly saves the current workspace snapshot. +// Exposed to the frontend for manual save triggers. +func (a *WraithApp) SaveWorkspace(snapshot *WorkspaceSnapshot) error { + if a.Workspace == nil { + return fmt.Errorf("workspace service not initialized") + } + return a.Workspace.Save(snapshot) +} + +// LoadWorkspace returns the last saved workspace snapshot, or nil if none exists. +func (a *WraithApp) LoadWorkspace() (*WorkspaceSnapshot, error) { + if a.Workspace == nil { + return nil, fmt.Errorf("workspace service not initialized") + } + return a.Workspace.Load() +} + +// MarkCleanShutdown records that the app is shutting down cleanly. +// Called before app exit so the next startup knows whether to restore workspace. +func (a *WraithApp) MarkCleanShutdown() error { + if a.Workspace == nil { + return nil + } + return a.Workspace.MarkCleanShutdown() +} + +// WasCleanShutdown returns whether the previous app exit was clean. +func (a *WraithApp) WasCleanShutdown() bool { + if a.Workspace == nil { + return true + } + return a.Workspace.WasCleanShutdown() +} + +// DeleteHostKey removes stored host keys for a hostname:port, allowing +// reconnection after a legitimate server re-key. +func (a *WraithApp) DeleteHostKey(hostname string, port int) error { + store := ssh.NewHostKeyStore(a.db) + return store.Delete(hostname, port) +} + +// GetSessionCWD returns the current tracked working directory for an SSH session. +func (a *WraithApp) GetSessionCWD(sessionID string) string { + sess, ok := a.SSH.GetSession(sessionID) + if !ok || sess.CWDTracker == nil { + return "" + } + return sess.CWDTracker.GetCWD() +} diff --git a/internal/session/manager.go b/internal/session/manager.go index 74e760c..0a3ea32 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -3,6 +3,7 @@ package session import ( "fmt" "sync" + "time" "github.com/google/uuid" ) @@ -21,6 +22,12 @@ func NewManager() *Manager { } func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, error) { + return m.CreateWithID(uuid.NewString(), connectionID, protocol) +} + +// CreateWithID registers a session with an explicit ID (e.g., the SSH session UUID). +// This keeps the session manager in sync with service-level session IDs. +func (m *Manager) CreateWithID(id string, connectionID int64, protocol string) (*SessionInfo, error) { m.mu.Lock() defer m.mu.Unlock() @@ -29,11 +36,12 @@ func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, err } s := &SessionInfo{ - ID: uuid.NewString(), + ID: id, ConnectionID: connectionID, Protocol: protocol, State: StateConnecting, TabPosition: len(m.sessions), + ConnectedAt: time.Now(), } m.sessions[s.ID] = s return s, nil diff --git a/internal/ssh/service.go b/internal/ssh/service.go index 181f002..2e3b5c1 100644 --- a/internal/ssh/service.go +++ b/internal/ssh/service.go @@ -2,8 +2,11 @@ package ssh import ( "database/sql" + "encoding/base64" "fmt" "io" + "log/slog" + "net" "sync" "time" @@ -15,18 +18,22 @@ import ( // In production this will emit Wails events; for testing, a simple callback. type OutputHandler func(sessionID string, data []byte) +// CWDHandler is called when the CWD tracker detects a directory change via OSC 7. +type CWDHandler func(sessionID string, path string) + // SSHSession represents an active SSH connection with its PTY shell session. type SSHSession struct { - ID string - Client *ssh.Client - Session *ssh.Session - Stdin io.WriteCloser - ConnID int64 - Hostname string - Port int - Username string - Connected time.Time - mu sync.Mutex + ID string + Client *ssh.Client + Session *ssh.Session + Stdin io.WriteCloser + ConnID int64 + Hostname string + Port int + Username string + Connected time.Time + CWDTracker *CWDTracker + mu sync.Mutex } // SSHService manages SSH connections and their associated sessions. @@ -34,28 +41,41 @@ type SSHService struct { sessions map[string]*SSHSession mu sync.RWMutex db *sql.DB + hostKeyStore *HostKeyStore outputHandler OutputHandler + cwdHandler CWDHandler } // NewSSHService creates a new SSHService. The outputHandler is called when data -// arrives from a session's stdout. Pass nil if output handling is not needed. -func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService { +// arrives from a session's stdout. The cwdHandler is called when OSC 7 CWD +// changes are detected. The hostKeyStore verifies SSH host keys (TOFU model). +// Pass nil for any handler if not needed. +func NewSSHService(db *sql.DB, hostKeyStore *HostKeyStore, outputHandler OutputHandler, cwdHandler CWDHandler) *SSHService { return &SSHService{ sessions: make(map[string]*SSHSession), db: db, + hostKeyStore: hostKeyStore, outputHandler: outputHandler, + cwdHandler: cwdHandler, } } // Connect dials an SSH server, opens a session with a PTY and shell, and // launches a goroutine to read stdout. Returns the session ID. +// +// Host key verification uses Trust On First Use (TOFU): new keys are stored +// automatically, matching keys are accepted silently, and CHANGED keys reject +// the connection to protect against MITM attacks. func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) { addr := fmt.Sprintf("%s:%d", hostname, port) + // Build host key callback — TOFU model via HostKeyStore + hostKeyCallback := s.buildHostKeyCallback(hostname, port) + config := &ssh.ClientConfig{ User: username, Auth: authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: hostKeyCallback, Timeout: 15 * time.Second, } @@ -104,14 +124,15 @@ func (s *SSHService) Connect(hostname string, port int, username string, authMet sessionID := uuid.NewString() sshSession := &SSHSession{ - ID: sessionID, - Client: client, - Session: session, - Stdin: stdin, - Hostname: hostname, - Port: port, - Username: username, - Connected: time.Now(), + ID: sessionID, + Client: client, + Session: session, + Stdin: stdin, + Hostname: hostname, + Port: port, + Username: username, + Connected: time.Now(), + CWDTracker: NewCWDTracker(), } s.mu.Lock() @@ -121,20 +142,103 @@ func (s *SSHService) Connect(hostname string, port int, username string, authMet // Launch goroutine to read stdout and forward data via the output handler go s.readLoop(sessionID, stdout) + // Inject shell integration for CWD tracking (OSC 7). + // Send both bash and zsh variants — the wrong shell ignores them harmlessly. + go func() { + time.Sleep(500 * time.Millisecond) + bashCmd := ShellIntegrationCommand("bash") + zshCmd := ShellIntegrationCommand("zsh") + if bashCmd != "" { + // Write commands silently; errors are non-fatal + _, _ = stdin.Write([]byte(bashCmd + "\n")) + } + if zshCmd != "" { + _, _ = stdin.Write([]byte(zshCmd + "\n")) + } + }() + return sessionID, nil } -// readLoop continuously reads from the session stdout and calls the output -// handler with data. It stops when the reader returns an error (typically EOF -// when the session closes). +// buildHostKeyCallback returns an ssh.HostKeyCallback that implements TOFU +// (Trust On First Use) via the HostKeyStore. If no store is configured, it +// falls back to accepting all keys (for testing/development only). +func (s *SSHService) buildHostKeyCallback(hostname string, port int) ssh.HostKeyCallback { + if s.hostKeyStore == nil { + slog.Warn("no host key store configured — accepting all host keys (INSECURE)") + return ssh.InsecureIgnoreHostKey() + } + + store := s.hostKeyStore + return func(remoteHostname string, addr net.Addr, key ssh.PublicKey) error { + keyType := key.Type() + fingerprint := ssh.FingerprintSHA256(key) + + result, err := store.Verify(hostname, port, keyType, fingerprint) + if err != nil { + return fmt.Errorf("host key verification failed: %w", err) + } + + switch result { + case HostKeyMatch: + slog.Debug("host key verified", "host", hostname, "port", port, "type", keyType) + return nil + + case HostKeyNew: + // TOFU: store the key and accept + rawKey := base64.StdEncoding.EncodeToString(key.Marshal()) + if storeErr := store.Store(hostname, port, keyType, fingerprint, rawKey); storeErr != nil { + slog.Warn("failed to store new host key", "host", hostname, "error", storeErr) + // Still accept the connection — storing failed but the key itself is fine + } + slog.Info("new host key stored (TOFU)", "host", hostname, "port", port, "type", keyType, "fingerprint", fingerprint) + return nil + + case HostKeyChanged: + // REJECT — possible MITM attack + return fmt.Errorf( + "HOST KEY CHANGED for %s:%d (type %s). Expected fingerprint does not match. "+ + "This could indicate a man-in-the-middle attack. Connection refused. "+ + "If the server was legitimately re-keyed, remove the old key and try again", + hostname, port, keyType, + ) + + default: + return fmt.Errorf("unknown host key result: %d", result) + } + } +} + +// readLoop continuously reads from the session stdout, processes CWD tracking +// (stripping OSC 7 sequences), and calls the output handler with cleaned data. +// It stops when the reader returns an error (typically EOF when the session closes). func (s *SSHService) readLoop(sessionID string, reader io.Reader) { + // Grab the CWD tracker for this session (if any) + s.mu.RLock() + sess := s.sessions[sessionID] + s.mu.RUnlock() + buf := make([]byte, 32*1024) for { n, err := reader.Read(buf) - if n > 0 && s.outputHandler != nil { + if n > 0 { data := make([]byte, n) copy(data, buf[:n]) - s.outputHandler(sessionID, data) + + // Process CWD tracking — strips OSC 7 sequences from output + if sess != nil && sess.CWDTracker != nil { + cleaned, newCWD := sess.CWDTracker.ProcessOutput(data) + data = cleaned + + // Emit CWD change event if a new path was detected + if newCWD != "" && s.cwdHandler != nil { + s.cwdHandler(sessionID, newCWD) + } + } + + if len(data) > 0 && s.outputHandler != nil { + s.outputHandler(sessionID, data) + } } if err != nil { break diff --git a/internal/ssh/service_test.go b/internal/ssh/service_test.go index 0a880b4..458c6e4 100644 --- a/internal/ssh/service_test.go +++ b/internal/ssh/service_test.go @@ -11,7 +11,7 @@ import ( ) func TestNewSSHService(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) if svc == nil { t.Fatal("NewSSHService returned nil") } @@ -21,7 +21,7 @@ func TestNewSSHService(t *testing.T) { } func TestBuildPasswordAuth(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) auth := svc.BuildPasswordAuth("mypassword") if auth == nil { t.Error("BuildPasswordAuth returned nil") @@ -29,7 +29,7 @@ func TestBuildPasswordAuth(t *testing.T) { } func TestBuildKeyAuth(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) // Generate a test Ed25519 key _, priv, err := ed25519.GenerateKey(rand.Reader) @@ -52,7 +52,7 @@ func TestBuildKeyAuth(t *testing.T) { } func TestBuildKeyAuthInvalidKey(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) _, err := svc.BuildKeyAuth([]byte("not a key"), "") if err == nil { t.Error("BuildKeyAuth should fail with invalid key") @@ -60,7 +60,7 @@ func TestBuildKeyAuthInvalidKey(t *testing.T) { } func TestSessionTracking(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) // Manually add a session to test tracking svc.mu.Lock() @@ -88,7 +88,7 @@ func TestSessionTracking(t *testing.T) { } func TestGetSessionNotFound(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) _, ok := svc.GetSession("nonexistent") if ok { t.Error("GetSession should return false for nonexistent session") @@ -96,7 +96,7 @@ func TestGetSessionNotFound(t *testing.T) { } func TestWriteNotFound(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) err := svc.Write("nonexistent", "data") if err == nil { t.Error("Write should fail for nonexistent session") @@ -104,7 +104,7 @@ func TestWriteNotFound(t *testing.T) { } func TestResizeNotFound(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) err := svc.Resize("nonexistent", 80, 24) if err == nil { t.Error("Resize should fail for nonexistent session") @@ -112,7 +112,7 @@ func TestResizeNotFound(t *testing.T) { } func TestDisconnectNotFound(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) err := svc.Disconnect("nonexistent") if err == nil { t.Error("Disconnect should fail for nonexistent session") @@ -120,7 +120,7 @@ func TestDisconnectNotFound(t *testing.T) { } func TestDisconnectRemovesSession(t *testing.T) { - svc := NewSSHService(nil, nil) + svc := NewSSHService(nil, nil, nil, nil) // Manually add a session with nil Client/Session/Stdin (no real connection) svc.mu.Lock()