feat: SSH host key verification + OSC 7 CWD tracker

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Vantz Stockwell 2026-03-17 06:52:59 -04:00
parent a75e21138e
commit cab286b4a6
4 changed files with 366 additions and 0 deletions

128
internal/ssh/cwd.go Normal file
View File

@ -0,0 +1,128 @@
package ssh
import (
"bytes"
"fmt"
"net/url"
"sync"
)
// CWDTracker parses OSC 7 escape sequences from terminal output to track the
// remote working directory.
type CWDTracker struct {
currentPath string
mu sync.RWMutex
}
// NewCWDTracker creates a new CWDTracker.
func NewCWDTracker() *CWDTracker {
return &CWDTracker{}
}
// osc7Prefix is the escape sequence that starts an OSC 7 directive.
var osc7Prefix = []byte("\033]7;")
// stTerminator is the ST (String Terminator) escape: ESC + backslash.
var stTerminator = []byte("\033\\")
// belTerminator is the BEL character, an alternative OSC terminator.
var belTerminator = []byte{0x07}
// ProcessOutput scans data for OSC 7 sequences of the form:
//
// \033]7;file://hostname/path\033\\ (ST terminator)
// \033]7;file://hostname/path\007 (BEL terminator)
//
// It returns cleaned output with all OSC 7 sequences stripped and the new CWD
// path (or "" if no OSC 7 was found).
func (t *CWDTracker) ProcessOutput(data []byte) (cleaned []byte, newCWD string) {
var result []byte
remaining := data
var lastCWD string
for {
idx := bytes.Index(remaining, osc7Prefix)
if idx == -1 {
result = append(result, remaining...)
break
}
// Append everything before the OSC 7 sequence.
result = append(result, remaining[:idx]...)
// Find the end of the OSC 7 payload (after the prefix).
afterPrefix := remaining[idx+len(osc7Prefix):]
// Try ST terminator first (\033\\), then BEL (\007).
endIdx := -1
terminatorLen := 0
if stIdx := bytes.Index(afterPrefix, stTerminator); stIdx != -1 {
endIdx = stIdx
terminatorLen = len(stTerminator)
}
if belIdx := bytes.Index(afterPrefix, belTerminator); belIdx != -1 {
if endIdx == -1 || belIdx < endIdx {
endIdx = belIdx
terminatorLen = len(belTerminator)
}
}
if endIdx == -1 {
// No terminator found; treat the rest as literal output.
result = append(result, remaining[idx:]...)
break
}
// Extract the URI payload between prefix and terminator.
payload := string(afterPrefix[:endIdx])
if path := extractPathFromOSC7(payload); path != "" {
lastCWD = path
}
remaining = afterPrefix[endIdx+terminatorLen:]
}
if lastCWD != "" {
t.mu.Lock()
t.currentPath = lastCWD
t.mu.Unlock()
}
return result, lastCWD
}
// GetCWD returns the current tracked working directory.
func (t *CWDTracker) GetCWD() string {
t.mu.RLock()
defer t.mu.RUnlock()
return t.currentPath
}
// extractPathFromOSC7 parses a file:// URI and returns the path component.
func extractPathFromOSC7(uri string) string {
u, err := url.Parse(uri)
if err != nil {
return ""
}
if u.Scheme != "file" {
return ""
}
return u.Path
}
// ShellIntegrationCommand returns the shell command to inject for CWD tracking
// via OSC 7. The returned command sets up a prompt hook that emits the OSC 7
// escape sequence after every command.
func ShellIntegrationCommand(shellType string) string {
switch shellType {
case "bash":
return fmt.Sprintf(`PROMPT_COMMAND='printf "\033]7;file://%%s%%s\033\\" "$HOSTNAME" "$PWD"'`)
case "zsh":
return fmt.Sprintf(`precmd() { printf '\033]7;file://%%s%%s\033\\' "$HOST" "$PWD" }`)
case "fish":
return `function __wraith_osc7 --on-event fish_prompt; printf '\033]7;file://%s%s\033\\' (hostname) (pwd); end`
default:
return ""
}
}

72
internal/ssh/cwd_test.go Normal file
View File

@ -0,0 +1,72 @@
package ssh
import (
"testing"
)
func TestProcessOutputBasicOSC7(t *testing.T) {
tracker := NewCWDTracker()
input := []byte("hello\033]7;file://myhost/home/user\033\\world")
cleaned, cwd := tracker.ProcessOutput(input)
if string(cleaned) != "helloworld" {
t.Errorf("cleaned = %q", string(cleaned))
}
if cwd != "/home/user" {
t.Errorf("cwd = %q", cwd)
}
}
func TestProcessOutputBELTerminator(t *testing.T) {
tracker := NewCWDTracker()
input := []byte("output\033]7;file://host/tmp\007more")
cleaned, cwd := tracker.ProcessOutput(input)
if string(cleaned) != "outputmore" {
t.Errorf("cleaned = %q", string(cleaned))
}
if cwd != "/tmp" {
t.Errorf("cwd = %q", cwd)
}
}
func TestProcessOutputNoOSC7(t *testing.T) {
tracker := NewCWDTracker()
input := []byte("just normal output")
cleaned, cwd := tracker.ProcessOutput(input)
if string(cleaned) != "just normal output" {
t.Errorf("cleaned = %q", string(cleaned))
}
if cwd != "" {
t.Errorf("cwd should be empty, got %q", cwd)
}
}
func TestProcessOutputMultipleOSC7(t *testing.T) {
tracker := NewCWDTracker()
input := []byte("\033]7;file://h/dir1\033\\text\033]7;file://h/dir2\033\\end")
cleaned, cwd := tracker.ProcessOutput(input)
if string(cleaned) != "textend" {
t.Errorf("cleaned = %q", string(cleaned))
}
if cwd != "/dir2" {
t.Errorf("cwd = %q, want /dir2", cwd)
}
}
func TestGetCWDPersists(t *testing.T) {
tracker := NewCWDTracker()
tracker.ProcessOutput([]byte("\033]7;file://h/home/user\033\\"))
if tracker.GetCWD() != "/home/user" {
t.Errorf("GetCWD = %q", tracker.GetCWD())
}
}
func TestShellIntegrationCommand(t *testing.T) {
cmd := ShellIntegrationCommand("bash")
if cmd == "" {
t.Error("bash command should not be empty")
}
cmd = ShellIntegrationCommand("zsh")
if cmd == "" {
t.Error("zsh command should not be empty")
}
}

89
internal/ssh/hostkey.go Normal file
View File

@ -0,0 +1,89 @@
package ssh
import (
"database/sql"
"fmt"
)
// HostKeyResult represents the result of a host key verification.
type HostKeyResult int
const (
HostKeyNew HostKeyResult = iota // never seen this host before
HostKeyMatch // fingerprint matches stored
HostKeyChanged // fingerprint CHANGED — possible MITM
)
// HostKeyStore stores and verifies SSH host key fingerprints in the host_keys SQLite table.
type HostKeyStore struct {
db *sql.DB
}
// NewHostKeyStore creates a new HostKeyStore backed by the given database.
func NewHostKeyStore(db *sql.DB) *HostKeyStore {
return &HostKeyStore{db: db}
}
// Verify checks whether the given fingerprint matches any stored host key for the
// specified hostname, port, and key type. It returns HostKeyNew if no key is stored,
// HostKeyMatch if the fingerprint matches, or HostKeyChanged if it differs.
func (s *HostKeyStore) Verify(hostname string, port int, keyType string, fingerprint string) (HostKeyResult, error) {
var storedFingerprint string
err := s.db.QueryRow(
"SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ? AND key_type = ?",
hostname, port, keyType,
).Scan(&storedFingerprint)
if err == sql.ErrNoRows {
return HostKeyNew, nil
}
if err != nil {
return 0, fmt.Errorf("query host key: %w", err)
}
if storedFingerprint == fingerprint {
return HostKeyMatch, nil
}
return HostKeyChanged, nil
}
// Store inserts or replaces a host key fingerprint for the given hostname, port, and key type.
func (s *HostKeyStore) Store(hostname string, port int, keyType string, fingerprint string, rawKey string) error {
_, err := s.db.Exec(
`INSERT INTO host_keys (hostname, port, key_type, fingerprint, raw_key)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (hostname, port, key_type)
DO UPDATE SET fingerprint = excluded.fingerprint, raw_key = excluded.raw_key`,
hostname, port, keyType, fingerprint, rawKey,
)
if err != nil {
return fmt.Errorf("store host key: %w", err)
}
return nil
}
// Delete removes all stored host keys for the given hostname and port.
func (s *HostKeyStore) Delete(hostname string, port int) error {
_, err := s.db.Exec(
"DELETE FROM host_keys WHERE hostname = ? AND port = ?",
hostname, port,
)
if err != nil {
return fmt.Errorf("delete host key: %w", err)
}
return nil
}
// GetFingerprint returns the stored fingerprint for the given hostname and port.
// It returns an empty string and sql.ErrNoRows if no key is stored.
func (s *HostKeyStore) GetFingerprint(hostname string, port int) (string, error) {
var fingerprint string
err := s.db.QueryRow(
"SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ?",
hostname, port,
).Scan(&fingerprint)
if err != nil {
return "", fmt.Errorf("get fingerprint: %w", err)
}
return fingerprint, nil
}

View File

@ -0,0 +1,77 @@
package ssh
import (
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
)
func setupHostKeyStore(t *testing.T) *HostKeyStore {
t.Helper()
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
if err := db.Migrate(d); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { d.Close() })
return NewHostKeyStore(d)
}
func TestVerifyNewHost(t *testing.T) {
store := setupHostKeyStore(t)
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
if err != nil {
t.Fatal(err)
}
if result != HostKeyNew {
t.Errorf("got %d, want HostKeyNew", result)
}
}
func TestStoreAndVerifyMatch(t *testing.T) {
store := setupHostKeyStore(t)
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
t.Fatal(err)
}
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
if err != nil {
t.Fatal(err)
}
if result != HostKeyMatch {
t.Errorf("got %d, want HostKeyMatch", result)
}
}
func TestVerifyChangedKey(t *testing.T) {
store := setupHostKeyStore(t)
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
t.Fatal(err)
}
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:DIFFERENT")
if err != nil {
t.Fatal(err)
}
if result != HostKeyChanged {
t.Errorf("got %d, want HostKeyChanged", result)
}
}
func TestDeleteHostKey(t *testing.T) {
store := setupHostKeyStore(t)
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
t.Fatal(err)
}
if err := store.Delete("192.168.1.4", 22); err != nil {
t.Fatal(err)
}
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
if err != nil {
t.Fatal(err)
}
if result != HostKeyNew {
t.Errorf("after delete, got %d, want HostKeyNew", result)
}
}