diff --git a/internal/ssh/cwd.go b/internal/ssh/cwd.go new file mode 100644 index 0000000..bd4c7b2 --- /dev/null +++ b/internal/ssh/cwd.go @@ -0,0 +1,128 @@ +package ssh + +import ( + "bytes" + "fmt" + "net/url" + "sync" +) + +// CWDTracker parses OSC 7 escape sequences from terminal output to track the +// remote working directory. +type CWDTracker struct { + currentPath string + mu sync.RWMutex +} + +// NewCWDTracker creates a new CWDTracker. +func NewCWDTracker() *CWDTracker { + return &CWDTracker{} +} + +// osc7Prefix is the escape sequence that starts an OSC 7 directive. +var osc7Prefix = []byte("\033]7;") + +// stTerminator is the ST (String Terminator) escape: ESC + backslash. +var stTerminator = []byte("\033\\") + +// belTerminator is the BEL character, an alternative OSC terminator. +var belTerminator = []byte{0x07} + +// ProcessOutput scans data for OSC 7 sequences of the form: +// +// \033]7;file://hostname/path\033\\ (ST terminator) +// \033]7;file://hostname/path\007 (BEL terminator) +// +// It returns cleaned output with all OSC 7 sequences stripped and the new CWD +// path (or "" if no OSC 7 was found). +func (t *CWDTracker) ProcessOutput(data []byte) (cleaned []byte, newCWD string) { + var result []byte + remaining := data + var lastCWD string + + for { + idx := bytes.Index(remaining, osc7Prefix) + if idx == -1 { + result = append(result, remaining...) + break + } + + // Append everything before the OSC 7 sequence. + result = append(result, remaining[:idx]...) + + // Find the end of the OSC 7 payload (after the prefix). + afterPrefix := remaining[idx+len(osc7Prefix):] + + // Try ST terminator first (\033\\), then BEL (\007). + endIdx := -1 + terminatorLen := 0 + + if stIdx := bytes.Index(afterPrefix, stTerminator); stIdx != -1 { + endIdx = stIdx + terminatorLen = len(stTerminator) + } + if belIdx := bytes.Index(afterPrefix, belTerminator); belIdx != -1 { + if endIdx == -1 || belIdx < endIdx { + endIdx = belIdx + terminatorLen = len(belTerminator) + } + } + + if endIdx == -1 { + // No terminator found; treat the rest as literal output. + result = append(result, remaining[idx:]...) + break + } + + // Extract the URI payload between prefix and terminator. + payload := string(afterPrefix[:endIdx]) + if path := extractPathFromOSC7(payload); path != "" { + lastCWD = path + } + + remaining = afterPrefix[endIdx+terminatorLen:] + } + + if lastCWD != "" { + t.mu.Lock() + t.currentPath = lastCWD + t.mu.Unlock() + } + + return result, lastCWD +} + +// GetCWD returns the current tracked working directory. +func (t *CWDTracker) GetCWD() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.currentPath +} + +// extractPathFromOSC7 parses a file:// URI and returns the path component. +func extractPathFromOSC7(uri string) string { + u, err := url.Parse(uri) + if err != nil { + return "" + } + if u.Scheme != "file" { + return "" + } + return u.Path +} + +// ShellIntegrationCommand returns the shell command to inject for CWD tracking +// via OSC 7. The returned command sets up a prompt hook that emits the OSC 7 +// escape sequence after every command. +func ShellIntegrationCommand(shellType string) string { + switch shellType { + case "bash": + return fmt.Sprintf(`PROMPT_COMMAND='printf "\033]7;file://%%s%%s\033\\" "$HOSTNAME" "$PWD"'`) + case "zsh": + return fmt.Sprintf(`precmd() { printf '\033]7;file://%%s%%s\033\\' "$HOST" "$PWD" }`) + case "fish": + return `function __wraith_osc7 --on-event fish_prompt; printf '\033]7;file://%s%s\033\\' (hostname) (pwd); end` + default: + return "" + } +} diff --git a/internal/ssh/cwd_test.go b/internal/ssh/cwd_test.go new file mode 100644 index 0000000..c9089fb --- /dev/null +++ b/internal/ssh/cwd_test.go @@ -0,0 +1,72 @@ +package ssh + +import ( + "testing" +) + +func TestProcessOutputBasicOSC7(t *testing.T) { + tracker := NewCWDTracker() + input := []byte("hello\033]7;file://myhost/home/user\033\\world") + cleaned, cwd := tracker.ProcessOutput(input) + if string(cleaned) != "helloworld" { + t.Errorf("cleaned = %q", string(cleaned)) + } + if cwd != "/home/user" { + t.Errorf("cwd = %q", cwd) + } +} + +func TestProcessOutputBELTerminator(t *testing.T) { + tracker := NewCWDTracker() + input := []byte("output\033]7;file://host/tmp\007more") + cleaned, cwd := tracker.ProcessOutput(input) + if string(cleaned) != "outputmore" { + t.Errorf("cleaned = %q", string(cleaned)) + } + if cwd != "/tmp" { + t.Errorf("cwd = %q", cwd) + } +} + +func TestProcessOutputNoOSC7(t *testing.T) { + tracker := NewCWDTracker() + input := []byte("just normal output") + cleaned, cwd := tracker.ProcessOutput(input) + if string(cleaned) != "just normal output" { + t.Errorf("cleaned = %q", string(cleaned)) + } + if cwd != "" { + t.Errorf("cwd should be empty, got %q", cwd) + } +} + +func TestProcessOutputMultipleOSC7(t *testing.T) { + tracker := NewCWDTracker() + input := []byte("\033]7;file://h/dir1\033\\text\033]7;file://h/dir2\033\\end") + cleaned, cwd := tracker.ProcessOutput(input) + if string(cleaned) != "textend" { + t.Errorf("cleaned = %q", string(cleaned)) + } + if cwd != "/dir2" { + t.Errorf("cwd = %q, want /dir2", cwd) + } +} + +func TestGetCWDPersists(t *testing.T) { + tracker := NewCWDTracker() + tracker.ProcessOutput([]byte("\033]7;file://h/home/user\033\\")) + if tracker.GetCWD() != "/home/user" { + t.Errorf("GetCWD = %q", tracker.GetCWD()) + } +} + +func TestShellIntegrationCommand(t *testing.T) { + cmd := ShellIntegrationCommand("bash") + if cmd == "" { + t.Error("bash command should not be empty") + } + cmd = ShellIntegrationCommand("zsh") + if cmd == "" { + t.Error("zsh command should not be empty") + } +} diff --git a/internal/ssh/hostkey.go b/internal/ssh/hostkey.go new file mode 100644 index 0000000..9fe72e0 --- /dev/null +++ b/internal/ssh/hostkey.go @@ -0,0 +1,89 @@ +package ssh + +import ( + "database/sql" + "fmt" +) + +// HostKeyResult represents the result of a host key verification. +type HostKeyResult int + +const ( + HostKeyNew HostKeyResult = iota // never seen this host before + HostKeyMatch // fingerprint matches stored + HostKeyChanged // fingerprint CHANGED — possible MITM +) + +// HostKeyStore stores and verifies SSH host key fingerprints in the host_keys SQLite table. +type HostKeyStore struct { + db *sql.DB +} + +// NewHostKeyStore creates a new HostKeyStore backed by the given database. +func NewHostKeyStore(db *sql.DB) *HostKeyStore { + return &HostKeyStore{db: db} +} + +// Verify checks whether the given fingerprint matches any stored host key for the +// specified hostname, port, and key type. It returns HostKeyNew if no key is stored, +// HostKeyMatch if the fingerprint matches, or HostKeyChanged if it differs. +func (s *HostKeyStore) Verify(hostname string, port int, keyType string, fingerprint string) (HostKeyResult, error) { + var storedFingerprint string + err := s.db.QueryRow( + "SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ? AND key_type = ?", + hostname, port, keyType, + ).Scan(&storedFingerprint) + + if err == sql.ErrNoRows { + return HostKeyNew, nil + } + if err != nil { + return 0, fmt.Errorf("query host key: %w", err) + } + + if storedFingerprint == fingerprint { + return HostKeyMatch, nil + } + return HostKeyChanged, nil +} + +// Store inserts or replaces a host key fingerprint for the given hostname, port, and key type. +func (s *HostKeyStore) Store(hostname string, port int, keyType string, fingerprint string, rawKey string) error { + _, err := s.db.Exec( + `INSERT INTO host_keys (hostname, port, key_type, fingerprint, raw_key) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (hostname, port, key_type) + DO UPDATE SET fingerprint = excluded.fingerprint, raw_key = excluded.raw_key`, + hostname, port, keyType, fingerprint, rawKey, + ) + if err != nil { + return fmt.Errorf("store host key: %w", err) + } + return nil +} + +// Delete removes all stored host keys for the given hostname and port. +func (s *HostKeyStore) Delete(hostname string, port int) error { + _, err := s.db.Exec( + "DELETE FROM host_keys WHERE hostname = ? AND port = ?", + hostname, port, + ) + if err != nil { + return fmt.Errorf("delete host key: %w", err) + } + return nil +} + +// GetFingerprint returns the stored fingerprint for the given hostname and port. +// It returns an empty string and sql.ErrNoRows if no key is stored. +func (s *HostKeyStore) GetFingerprint(hostname string, port int) (string, error) { + var fingerprint string + err := s.db.QueryRow( + "SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ?", + hostname, port, + ).Scan(&fingerprint) + if err != nil { + return "", fmt.Errorf("get fingerprint: %w", err) + } + return fingerprint, nil +} diff --git a/internal/ssh/hostkey_test.go b/internal/ssh/hostkey_test.go new file mode 100644 index 0000000..0d686d3 --- /dev/null +++ b/internal/ssh/hostkey_test.go @@ -0,0 +1,77 @@ +package ssh + +import ( + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" +) + +func setupHostKeyStore(t *testing.T) *HostKeyStore { + t.Helper() + d, err := db.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + if err := db.Migrate(d); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { d.Close() }) + return NewHostKeyStore(d) +} + +func TestVerifyNewHost(t *testing.T) { + store := setupHostKeyStore(t) + result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123") + if err != nil { + t.Fatal(err) + } + if result != HostKeyNew { + t.Errorf("got %d, want HostKeyNew", result) + } +} + +func TestStoreAndVerifyMatch(t *testing.T) { + store := setupHostKeyStore(t) + if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil { + t.Fatal(err) + } + result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123") + if err != nil { + t.Fatal(err) + } + if result != HostKeyMatch { + t.Errorf("got %d, want HostKeyMatch", result) + } +} + +func TestVerifyChangedKey(t *testing.T) { + store := setupHostKeyStore(t) + if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil { + t.Fatal(err) + } + result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:DIFFERENT") + if err != nil { + t.Fatal(err) + } + if result != HostKeyChanged { + t.Errorf("got %d, want HostKeyChanged", result) + } +} + +func TestDeleteHostKey(t *testing.T) { + store := setupHostKeyStore(t) + if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil { + t.Fatal(err) + } + if err := store.Delete("192.168.1.4", 22); err != nil { + t.Fatal(err) + } + result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123") + if err != nil { + t.Fatal(err) + } + if result != HostKeyNew { + t.Errorf("after delete, got %d, want HostKeyNew", result) + } +}