Compare commits

..

2 Commits

Author SHA1 Message Date
Vantz Stockwell
c31563c8c6 fix: remove shell integration injection — echoes visibly in terminal
All checks were successful
Build & Sign Wraith / Build Windows + Sign (push) Successful in 1m1s
The CWD tracking PROMPT_COMMAND/precmd injection wrote raw escape
sequences to stdin that echoed back to the user. Removed until we
have a non-echoing mechanism (e.g., second SSH channel or .bashrc
modification). CWD tracking still works passively for shells that
already emit OSC 7 sequences.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 13:50:10 -04:00
Vantz Stockwell
6729eb5b80 feat: wire 4 backend services — host key verification, CWD tracking, session manager, workspace restore
U-1: Replace ssh.InsecureIgnoreHostKey() with TOFU (Trust On First Use) host
key verification via HostKeyStore. New keys auto-store, matching keys accept
silently, CHANGED keys reject with MITM warning. Added DeleteHostKey() for
legitimate re-key scenarios.

U-2: Wire CWDTracker per SSH session. readLoop() now processes OSC 7 escape
sequences, strips them from terminal output, and emits ssh:cwd:{sessionID}
Wails events on directory changes. Shell integration commands (bash/zsh
PROMPT_COMMAND) injected after connection.

U-3: Session manager now tracks all SSH and RDP sessions via CreateWithID()
which accepts the service-level UUID instead of generating a new one.
ConnectSSH, ConnectSSHWithPassword, ConnectRDP register sessions;
DisconnectSession and RDPDisconnect remove them. ConnectedAt timestamp set.

U-4: WorkspaceService instantiated in New(), clean shutdown flag managed on
startup/exit, workspace state auto-saved on every session open/close.
Frontend-facing proxy methods exposed: SaveWorkspace, LoadWorkspace,
MarkCleanShutdown, WasCleanShutdown, GetSessionCWD.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 13:41:58 -04:00
5 changed files with 246 additions and 43 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

View File

@ -363,12 +363,11 @@ func (a *WraithApp) ConnectSSH(connectionID int64, cols, rows int) (string, erro
slog.Warn("failed to update last_connected", "error", err) slog.Warn("failed to update last_connected", "error", err)
} }
// Register with session manager // Register with session manager using the SSH session's own UUID
if _, err := a.Sessions.Create(connectionID, "ssh"); err != nil { if _, err := a.Sessions.CreateWithID(sessionID, connectionID, "ssh"); err != nil {
slog.Warn("failed to register SSH session in manager", "error", err) slog.Warn("failed to register SSH session in manager", "error", err)
} else { } else {
// Store the SSH session ID as the manager session ID for lookup _ = a.Sessions.SetState(sessionID, session.StateConnected)
a.Sessions.SetState(sessionID, session.StateConnected)
} }
// Save workspace state after session change // Save workspace state after session change
@ -416,6 +415,16 @@ func (a *WraithApp) ConnectSSHWithPassword(connectionID int64, username, passwor
slog.Warn("failed to update last_connected", "error", err) slog.Warn("failed to update last_connected", "error", err)
} }
// Register with session manager
if _, err := a.Sessions.CreateWithID(sessionID, connectionID, "ssh"); err != nil {
slog.Warn("failed to register SSH session in manager", "error", err)
} else {
_ = a.Sessions.SetState(sessionID, session.StateConnected)
}
// Save workspace state after session change
a.saveWorkspaceState()
slog.Info("SSH session started (ad-hoc password)", "sessionID", sessionID, "host", conn.Hostname, "user", username) slog.Info("SSH session started (ad-hoc password)", "sessionID", sessionID, "host", conn.Hostname, "user", username)
return sessionID, nil return sessionID, nil
} }
@ -425,9 +434,12 @@ func (a *WraithApp) GetVersion() string {
return a.Updater.CurrentVersion() return a.Updater.CurrentVersion()
} }
// DisconnectSession closes an active SSH session and its SFTP client. // DisconnectSession closes an active SSH session and its SFTP client,
// and removes it from the session manager.
func (a *WraithApp) DisconnectSession(sessionID string) error { func (a *WraithApp) DisconnectSession(sessionID string) error {
a.SFTP.RemoveClient(sessionID) a.SFTP.RemoveClient(sessionID)
a.Sessions.Remove(sessionID)
a.saveWorkspaceState()
return a.SSH.Disconnect(sessionID) return a.SSH.Disconnect(sessionID)
} }
@ -480,6 +492,16 @@ func (a *WraithApp) ConnectRDP(connectionID int64, width, height int) (string, e
slog.Warn("failed to update last_connected", "error", err) slog.Warn("failed to update last_connected", "error", err)
} }
// Register with session manager
if _, err := a.Sessions.CreateWithID(sessionID, connectionID, "rdp"); err != nil {
slog.Warn("failed to register RDP session in manager", "error", err)
} else {
_ = a.Sessions.SetState(sessionID, session.StateConnected)
}
// Save workspace state after session change
a.saveWorkspaceState()
slog.Info("RDP session started", "sessionID", sessionID, "host", conn.Hostname) slog.Info("RDP session started", "sessionID", sessionID, "host", conn.Hostname)
return sessionID, nil return sessionID, nil
} }
@ -511,8 +533,10 @@ func (a *WraithApp) RDPSendClipboard(sessionID string, text string) error {
return a.RDP.SendClipboard(sessionID, text) return a.RDP.SendClipboard(sessionID, text)
} }
// RDPDisconnect tears down an RDP session. // RDPDisconnect tears down an RDP session and removes it from the session manager.
func (a *WraithApp) RDPDisconnect(sessionID string) error { func (a *WraithApp) RDPDisconnect(sessionID string) error {
a.Sessions.Remove(sessionID)
a.saveWorkspaceState()
return a.RDP.Disconnect(sessionID) return a.RDP.Disconnect(sessionID)
} }
@ -619,3 +643,80 @@ func (a *WraithApp) ImportMobaConf(fileContent string) (*plugin.ImportResult, er
) )
return result, nil return result, nil
} }
// ---------- Workspace proxy methods ----------
// saveWorkspaceState builds a workspace snapshot from the current session manager
// state and persists it. Called automatically on session open/close.
func (a *WraithApp) saveWorkspaceState() {
if a.Workspace == nil {
return
}
sessions := a.Sessions.List()
tabs := make([]WorkspaceTab, 0, len(sessions))
for _, s := range sessions {
tabs = append(tabs, WorkspaceTab{
ConnectionID: s.ConnectionID,
Protocol: s.Protocol,
Position: s.TabPosition,
})
}
snapshot := &WorkspaceSnapshot{
Tabs: tabs,
}
if err := a.Workspace.Save(snapshot); err != nil {
slog.Warn("failed to save workspace state", "error", err)
}
}
// SaveWorkspace explicitly saves the current workspace snapshot.
// Exposed to the frontend for manual save triggers.
func (a *WraithApp) SaveWorkspace(snapshot *WorkspaceSnapshot) error {
if a.Workspace == nil {
return fmt.Errorf("workspace service not initialized")
}
return a.Workspace.Save(snapshot)
}
// LoadWorkspace returns the last saved workspace snapshot, or nil if none exists.
func (a *WraithApp) LoadWorkspace() (*WorkspaceSnapshot, error) {
if a.Workspace == nil {
return nil, fmt.Errorf("workspace service not initialized")
}
return a.Workspace.Load()
}
// MarkCleanShutdown records that the app is shutting down cleanly.
// Called before app exit so the next startup knows whether to restore workspace.
func (a *WraithApp) MarkCleanShutdown() error {
if a.Workspace == nil {
return nil
}
return a.Workspace.MarkCleanShutdown()
}
// WasCleanShutdown returns whether the previous app exit was clean.
func (a *WraithApp) WasCleanShutdown() bool {
if a.Workspace == nil {
return true
}
return a.Workspace.WasCleanShutdown()
}
// DeleteHostKey removes stored host keys for a hostname:port, allowing
// reconnection after a legitimate server re-key.
func (a *WraithApp) DeleteHostKey(hostname string, port int) error {
store := ssh.NewHostKeyStore(a.db)
return store.Delete(hostname, port)
}
// GetSessionCWD returns the current tracked working directory for an SSH session.
func (a *WraithApp) GetSessionCWD(sessionID string) string {
sess, ok := a.SSH.GetSession(sessionID)
if !ok || sess.CWDTracker == nil {
return ""
}
return sess.CWDTracker.GetCWD()
}

View File

@ -3,6 +3,7 @@ package session
import ( import (
"fmt" "fmt"
"sync" "sync"
"time"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -21,6 +22,12 @@ func NewManager() *Manager {
} }
func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, error) { func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, error) {
return m.CreateWithID(uuid.NewString(), connectionID, protocol)
}
// CreateWithID registers a session with an explicit ID (e.g., the SSH session UUID).
// This keeps the session manager in sync with service-level session IDs.
func (m *Manager) CreateWithID(id string, connectionID int64, protocol string) (*SessionInfo, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@ -29,11 +36,12 @@ func (m *Manager) Create(connectionID int64, protocol string) (*SessionInfo, err
} }
s := &SessionInfo{ s := &SessionInfo{
ID: uuid.NewString(), ID: id,
ConnectionID: connectionID, ConnectionID: connectionID,
Protocol: protocol, Protocol: protocol,
State: StateConnecting, State: StateConnecting,
TabPosition: len(m.sessions), TabPosition: len(m.sessions),
ConnectedAt: time.Now(),
} }
m.sessions[s.ID] = s m.sessions[s.ID] = s
return s, nil return s, nil

View File

@ -2,8 +2,11 @@ package ssh
import ( import (
"database/sql" "database/sql"
"encoding/base64"
"fmt" "fmt"
"io" "io"
"log/slog"
"net"
"sync" "sync"
"time" "time"
@ -15,18 +18,22 @@ import (
// In production this will emit Wails events; for testing, a simple callback. // In production this will emit Wails events; for testing, a simple callback.
type OutputHandler func(sessionID string, data []byte) type OutputHandler func(sessionID string, data []byte)
// CWDHandler is called when the CWD tracker detects a directory change via OSC 7.
type CWDHandler func(sessionID string, path string)
// SSHSession represents an active SSH connection with its PTY shell session. // SSHSession represents an active SSH connection with its PTY shell session.
type SSHSession struct { type SSHSession struct {
ID string ID string
Client *ssh.Client Client *ssh.Client
Session *ssh.Session Session *ssh.Session
Stdin io.WriteCloser Stdin io.WriteCloser
ConnID int64 ConnID int64
Hostname string Hostname string
Port int Port int
Username string Username string
Connected time.Time Connected time.Time
mu sync.Mutex CWDTracker *CWDTracker
mu sync.Mutex
} }
// SSHService manages SSH connections and their associated sessions. // SSHService manages SSH connections and their associated sessions.
@ -34,28 +41,41 @@ type SSHService struct {
sessions map[string]*SSHSession sessions map[string]*SSHSession
mu sync.RWMutex mu sync.RWMutex
db *sql.DB db *sql.DB
hostKeyStore *HostKeyStore
outputHandler OutputHandler outputHandler OutputHandler
cwdHandler CWDHandler
} }
// NewSSHService creates a new SSHService. The outputHandler is called when data // NewSSHService creates a new SSHService. The outputHandler is called when data
// arrives from a session's stdout. Pass nil if output handling is not needed. // arrives from a session's stdout. The cwdHandler is called when OSC 7 CWD
func NewSSHService(db *sql.DB, outputHandler OutputHandler) *SSHService { // changes are detected. The hostKeyStore verifies SSH host keys (TOFU model).
// Pass nil for any handler if not needed.
func NewSSHService(db *sql.DB, hostKeyStore *HostKeyStore, outputHandler OutputHandler, cwdHandler CWDHandler) *SSHService {
return &SSHService{ return &SSHService{
sessions: make(map[string]*SSHSession), sessions: make(map[string]*SSHSession),
db: db, db: db,
hostKeyStore: hostKeyStore,
outputHandler: outputHandler, outputHandler: outputHandler,
cwdHandler: cwdHandler,
} }
} }
// Connect dials an SSH server, opens a session with a PTY and shell, and // Connect dials an SSH server, opens a session with a PTY and shell, and
// launches a goroutine to read stdout. Returns the session ID. // launches a goroutine to read stdout. Returns the session ID.
//
// Host key verification uses Trust On First Use (TOFU): new keys are stored
// automatically, matching keys are accepted silently, and CHANGED keys reject
// the connection to protect against MITM attacks.
func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) { func (s *SSHService) Connect(hostname string, port int, username string, authMethods []ssh.AuthMethod, cols, rows int) (string, error) {
addr := fmt.Sprintf("%s:%d", hostname, port) addr := fmt.Sprintf("%s:%d", hostname, port)
// Build host key callback — TOFU model via HostKeyStore
hostKeyCallback := s.buildHostKeyCallback(hostname, port)
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: username, User: username,
Auth: authMethods, Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: hostKeyCallback,
Timeout: 15 * time.Second, Timeout: 15 * time.Second,
} }
@ -104,14 +124,15 @@ func (s *SSHService) Connect(hostname string, port int, username string, authMet
sessionID := uuid.NewString() sessionID := uuid.NewString()
sshSession := &SSHSession{ sshSession := &SSHSession{
ID: sessionID, ID: sessionID,
Client: client, Client: client,
Session: session, Session: session,
Stdin: stdin, Stdin: stdin,
Hostname: hostname, Hostname: hostname,
Port: port, Port: port,
Username: username, Username: username,
Connected: time.Now(), Connected: time.Now(),
CWDTracker: NewCWDTracker(),
} }
s.mu.Lock() s.mu.Lock()
@ -121,20 +142,93 @@ func (s *SSHService) Connect(hostname string, port int, username string, authMet
// Launch goroutine to read stdout and forward data via the output handler // Launch goroutine to read stdout and forward data via the output handler
go s.readLoop(sessionID, stdout) go s.readLoop(sessionID, stdout)
// CWD tracking via OSC 7 is handled passively — the CWDTracker in readLoop
// parses OSC 7 sequences if the remote shell already emits them. Automatic
// PROMPT_COMMAND injection is deferred until we have a non-echoing mechanism
// (e.g., writing to a second SSH channel or modifying .bashrc).
return sessionID, nil return sessionID, nil
} }
// readLoop continuously reads from the session stdout and calls the output // buildHostKeyCallback returns an ssh.HostKeyCallback that implements TOFU
// handler with data. It stops when the reader returns an error (typically EOF // (Trust On First Use) via the HostKeyStore. If no store is configured, it
// when the session closes). // falls back to accepting all keys (for testing/development only).
func (s *SSHService) buildHostKeyCallback(hostname string, port int) ssh.HostKeyCallback {
if s.hostKeyStore == nil {
slog.Warn("no host key store configured — accepting all host keys (INSECURE)")
return ssh.InsecureIgnoreHostKey()
}
store := s.hostKeyStore
return func(remoteHostname string, addr net.Addr, key ssh.PublicKey) error {
keyType := key.Type()
fingerprint := ssh.FingerprintSHA256(key)
result, err := store.Verify(hostname, port, keyType, fingerprint)
if err != nil {
return fmt.Errorf("host key verification failed: %w", err)
}
switch result {
case HostKeyMatch:
slog.Debug("host key verified", "host", hostname, "port", port, "type", keyType)
return nil
case HostKeyNew:
// TOFU: store the key and accept
rawKey := base64.StdEncoding.EncodeToString(key.Marshal())
if storeErr := store.Store(hostname, port, keyType, fingerprint, rawKey); storeErr != nil {
slog.Warn("failed to store new host key", "host", hostname, "error", storeErr)
// Still accept the connection — storing failed but the key itself is fine
}
slog.Info("new host key stored (TOFU)", "host", hostname, "port", port, "type", keyType, "fingerprint", fingerprint)
return nil
case HostKeyChanged:
// REJECT — possible MITM attack
return fmt.Errorf(
"HOST KEY CHANGED for %s:%d (type %s). Expected fingerprint does not match. "+
"This could indicate a man-in-the-middle attack. Connection refused. "+
"If the server was legitimately re-keyed, remove the old key and try again",
hostname, port, keyType,
)
default:
return fmt.Errorf("unknown host key result: %d", result)
}
}
}
// readLoop continuously reads from the session stdout, processes CWD tracking
// (stripping OSC 7 sequences), and calls the output handler with cleaned data.
// It stops when the reader returns an error (typically EOF when the session closes).
func (s *SSHService) readLoop(sessionID string, reader io.Reader) { func (s *SSHService) readLoop(sessionID string, reader io.Reader) {
// Grab the CWD tracker for this session (if any)
s.mu.RLock()
sess := s.sessions[sessionID]
s.mu.RUnlock()
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)
for { for {
n, err := reader.Read(buf) n, err := reader.Read(buf)
if n > 0 && s.outputHandler != nil { if n > 0 {
data := make([]byte, n) data := make([]byte, n)
copy(data, buf[:n]) copy(data, buf[:n])
s.outputHandler(sessionID, data)
// Process CWD tracking — strips OSC 7 sequences from output
if sess != nil && sess.CWDTracker != nil {
cleaned, newCWD := sess.CWDTracker.ProcessOutput(data)
data = cleaned
// Emit CWD change event if a new path was detected
if newCWD != "" && s.cwdHandler != nil {
s.cwdHandler(sessionID, newCWD)
}
}
if len(data) > 0 && s.outputHandler != nil {
s.outputHandler(sessionID, data)
}
} }
if err != nil { if err != nil {
break break

View File

@ -11,7 +11,7 @@ import (
) )
func TestNewSSHService(t *testing.T) { func TestNewSSHService(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
if svc == nil { if svc == nil {
t.Fatal("NewSSHService returned nil") t.Fatal("NewSSHService returned nil")
} }
@ -21,7 +21,7 @@ func TestNewSSHService(t *testing.T) {
} }
func TestBuildPasswordAuth(t *testing.T) { func TestBuildPasswordAuth(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
auth := svc.BuildPasswordAuth("mypassword") auth := svc.BuildPasswordAuth("mypassword")
if auth == nil { if auth == nil {
t.Error("BuildPasswordAuth returned nil") t.Error("BuildPasswordAuth returned nil")
@ -29,7 +29,7 @@ func TestBuildPasswordAuth(t *testing.T) {
} }
func TestBuildKeyAuth(t *testing.T) { func TestBuildKeyAuth(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
// Generate a test Ed25519 key // Generate a test Ed25519 key
_, priv, err := ed25519.GenerateKey(rand.Reader) _, priv, err := ed25519.GenerateKey(rand.Reader)
@ -52,7 +52,7 @@ func TestBuildKeyAuth(t *testing.T) {
} }
func TestBuildKeyAuthInvalidKey(t *testing.T) { func TestBuildKeyAuthInvalidKey(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
_, err := svc.BuildKeyAuth([]byte("not a key"), "") _, err := svc.BuildKeyAuth([]byte("not a key"), "")
if err == nil { if err == nil {
t.Error("BuildKeyAuth should fail with invalid key") t.Error("BuildKeyAuth should fail with invalid key")
@ -60,7 +60,7 @@ func TestBuildKeyAuthInvalidKey(t *testing.T) {
} }
func TestSessionTracking(t *testing.T) { func TestSessionTracking(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
// Manually add a session to test tracking // Manually add a session to test tracking
svc.mu.Lock() svc.mu.Lock()
@ -88,7 +88,7 @@ func TestSessionTracking(t *testing.T) {
} }
func TestGetSessionNotFound(t *testing.T) { func TestGetSessionNotFound(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
_, ok := svc.GetSession("nonexistent") _, ok := svc.GetSession("nonexistent")
if ok { if ok {
t.Error("GetSession should return false for nonexistent session") t.Error("GetSession should return false for nonexistent session")
@ -96,7 +96,7 @@ func TestGetSessionNotFound(t *testing.T) {
} }
func TestWriteNotFound(t *testing.T) { func TestWriteNotFound(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
err := svc.Write("nonexistent", "data") err := svc.Write("nonexistent", "data")
if err == nil { if err == nil {
t.Error("Write should fail for nonexistent session") t.Error("Write should fail for nonexistent session")
@ -104,7 +104,7 @@ func TestWriteNotFound(t *testing.T) {
} }
func TestResizeNotFound(t *testing.T) { func TestResizeNotFound(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
err := svc.Resize("nonexistent", 80, 24) err := svc.Resize("nonexistent", 80, 24)
if err == nil { if err == nil {
t.Error("Resize should fail for nonexistent session") t.Error("Resize should fail for nonexistent session")
@ -112,7 +112,7 @@ func TestResizeNotFound(t *testing.T) {
} }
func TestDisconnectNotFound(t *testing.T) { func TestDisconnectNotFound(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
err := svc.Disconnect("nonexistent") err := svc.Disconnect("nonexistent")
if err == nil { if err == nil {
t.Error("Disconnect should fail for nonexistent session") t.Error("Disconnect should fail for nonexistent session")
@ -120,7 +120,7 @@ func TestDisconnectNotFound(t *testing.T) {
} }
func TestDisconnectRemovesSession(t *testing.T) { func TestDisconnectRemovesSession(t *testing.T) {
svc := NewSSHService(nil, nil) svc := NewSSHService(nil, nil, nil, nil)
// Manually add a session with nil Client/Session/Stdin (no real connection) // Manually add a session with nil Client/Session/Stdin (no real connection)
svc.mu.Lock() svc.mu.Lock()