feat: SSH host key verification + OSC 7 CWD tracker
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a75e21138e
commit
cab286b4a6
128
internal/ssh/cwd.go
Normal file
128
internal/ssh/cwd.go
Normal file
@ -0,0 +1,128 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CWDTracker parses OSC 7 escape sequences from terminal output to track the
|
||||
// remote working directory.
|
||||
type CWDTracker struct {
|
||||
currentPath string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCWDTracker creates a new CWDTracker.
|
||||
func NewCWDTracker() *CWDTracker {
|
||||
return &CWDTracker{}
|
||||
}
|
||||
|
||||
// osc7Prefix is the escape sequence that starts an OSC 7 directive.
|
||||
var osc7Prefix = []byte("\033]7;")
|
||||
|
||||
// stTerminator is the ST (String Terminator) escape: ESC + backslash.
|
||||
var stTerminator = []byte("\033\\")
|
||||
|
||||
// belTerminator is the BEL character, an alternative OSC terminator.
|
||||
var belTerminator = []byte{0x07}
|
||||
|
||||
// ProcessOutput scans data for OSC 7 sequences of the form:
|
||||
//
|
||||
// \033]7;file://hostname/path\033\\ (ST terminator)
|
||||
// \033]7;file://hostname/path\007 (BEL terminator)
|
||||
//
|
||||
// It returns cleaned output with all OSC 7 sequences stripped and the new CWD
|
||||
// path (or "" if no OSC 7 was found).
|
||||
func (t *CWDTracker) ProcessOutput(data []byte) (cleaned []byte, newCWD string) {
|
||||
var result []byte
|
||||
remaining := data
|
||||
var lastCWD string
|
||||
|
||||
for {
|
||||
idx := bytes.Index(remaining, osc7Prefix)
|
||||
if idx == -1 {
|
||||
result = append(result, remaining...)
|
||||
break
|
||||
}
|
||||
|
||||
// Append everything before the OSC 7 sequence.
|
||||
result = append(result, remaining[:idx]...)
|
||||
|
||||
// Find the end of the OSC 7 payload (after the prefix).
|
||||
afterPrefix := remaining[idx+len(osc7Prefix):]
|
||||
|
||||
// Try ST terminator first (\033\\), then BEL (\007).
|
||||
endIdx := -1
|
||||
terminatorLen := 0
|
||||
|
||||
if stIdx := bytes.Index(afterPrefix, stTerminator); stIdx != -1 {
|
||||
endIdx = stIdx
|
||||
terminatorLen = len(stTerminator)
|
||||
}
|
||||
if belIdx := bytes.Index(afterPrefix, belTerminator); belIdx != -1 {
|
||||
if endIdx == -1 || belIdx < endIdx {
|
||||
endIdx = belIdx
|
||||
terminatorLen = len(belTerminator)
|
||||
}
|
||||
}
|
||||
|
||||
if endIdx == -1 {
|
||||
// No terminator found; treat the rest as literal output.
|
||||
result = append(result, remaining[idx:]...)
|
||||
break
|
||||
}
|
||||
|
||||
// Extract the URI payload between prefix and terminator.
|
||||
payload := string(afterPrefix[:endIdx])
|
||||
if path := extractPathFromOSC7(payload); path != "" {
|
||||
lastCWD = path
|
||||
}
|
||||
|
||||
remaining = afterPrefix[endIdx+terminatorLen:]
|
||||
}
|
||||
|
||||
if lastCWD != "" {
|
||||
t.mu.Lock()
|
||||
t.currentPath = lastCWD
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
return result, lastCWD
|
||||
}
|
||||
|
||||
// GetCWD returns the current tracked working directory.
|
||||
func (t *CWDTracker) GetCWD() string {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.currentPath
|
||||
}
|
||||
|
||||
// extractPathFromOSC7 parses a file:// URI and returns the path component.
|
||||
func extractPathFromOSC7(uri string) string {
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if u.Scheme != "file" {
|
||||
return ""
|
||||
}
|
||||
return u.Path
|
||||
}
|
||||
|
||||
// ShellIntegrationCommand returns the shell command to inject for CWD tracking
|
||||
// via OSC 7. The returned command sets up a prompt hook that emits the OSC 7
|
||||
// escape sequence after every command.
|
||||
func ShellIntegrationCommand(shellType string) string {
|
||||
switch shellType {
|
||||
case "bash":
|
||||
return fmt.Sprintf(`PROMPT_COMMAND='printf "\033]7;file://%%s%%s\033\\" "$HOSTNAME" "$PWD"'`)
|
||||
case "zsh":
|
||||
return fmt.Sprintf(`precmd() { printf '\033]7;file://%%s%%s\033\\' "$HOST" "$PWD" }`)
|
||||
case "fish":
|
||||
return `function __wraith_osc7 --on-event fish_prompt; printf '\033]7;file://%s%s\033\\' (hostname) (pwd); end`
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
72
internal/ssh/cwd_test.go
Normal file
72
internal/ssh/cwd_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessOutputBasicOSC7(t *testing.T) {
|
||||
tracker := NewCWDTracker()
|
||||
input := []byte("hello\033]7;file://myhost/home/user\033\\world")
|
||||
cleaned, cwd := tracker.ProcessOutput(input)
|
||||
if string(cleaned) != "helloworld" {
|
||||
t.Errorf("cleaned = %q", string(cleaned))
|
||||
}
|
||||
if cwd != "/home/user" {
|
||||
t.Errorf("cwd = %q", cwd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessOutputBELTerminator(t *testing.T) {
|
||||
tracker := NewCWDTracker()
|
||||
input := []byte("output\033]7;file://host/tmp\007more")
|
||||
cleaned, cwd := tracker.ProcessOutput(input)
|
||||
if string(cleaned) != "outputmore" {
|
||||
t.Errorf("cleaned = %q", string(cleaned))
|
||||
}
|
||||
if cwd != "/tmp" {
|
||||
t.Errorf("cwd = %q", cwd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessOutputNoOSC7(t *testing.T) {
|
||||
tracker := NewCWDTracker()
|
||||
input := []byte("just normal output")
|
||||
cleaned, cwd := tracker.ProcessOutput(input)
|
||||
if string(cleaned) != "just normal output" {
|
||||
t.Errorf("cleaned = %q", string(cleaned))
|
||||
}
|
||||
if cwd != "" {
|
||||
t.Errorf("cwd should be empty, got %q", cwd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessOutputMultipleOSC7(t *testing.T) {
|
||||
tracker := NewCWDTracker()
|
||||
input := []byte("\033]7;file://h/dir1\033\\text\033]7;file://h/dir2\033\\end")
|
||||
cleaned, cwd := tracker.ProcessOutput(input)
|
||||
if string(cleaned) != "textend" {
|
||||
t.Errorf("cleaned = %q", string(cleaned))
|
||||
}
|
||||
if cwd != "/dir2" {
|
||||
t.Errorf("cwd = %q, want /dir2", cwd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCWDPersists(t *testing.T) {
|
||||
tracker := NewCWDTracker()
|
||||
tracker.ProcessOutput([]byte("\033]7;file://h/home/user\033\\"))
|
||||
if tracker.GetCWD() != "/home/user" {
|
||||
t.Errorf("GetCWD = %q", tracker.GetCWD())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellIntegrationCommand(t *testing.T) {
|
||||
cmd := ShellIntegrationCommand("bash")
|
||||
if cmd == "" {
|
||||
t.Error("bash command should not be empty")
|
||||
}
|
||||
cmd = ShellIntegrationCommand("zsh")
|
||||
if cmd == "" {
|
||||
t.Error("zsh command should not be empty")
|
||||
}
|
||||
}
|
||||
89
internal/ssh/hostkey.go
Normal file
89
internal/ssh/hostkey.go
Normal file
@ -0,0 +1,89 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// HostKeyResult represents the result of a host key verification.
|
||||
type HostKeyResult int
|
||||
|
||||
const (
|
||||
HostKeyNew HostKeyResult = iota // never seen this host before
|
||||
HostKeyMatch // fingerprint matches stored
|
||||
HostKeyChanged // fingerprint CHANGED — possible MITM
|
||||
)
|
||||
|
||||
// HostKeyStore stores and verifies SSH host key fingerprints in the host_keys SQLite table.
|
||||
type HostKeyStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewHostKeyStore creates a new HostKeyStore backed by the given database.
|
||||
func NewHostKeyStore(db *sql.DB) *HostKeyStore {
|
||||
return &HostKeyStore{db: db}
|
||||
}
|
||||
|
||||
// Verify checks whether the given fingerprint matches any stored host key for the
|
||||
// specified hostname, port, and key type. It returns HostKeyNew if no key is stored,
|
||||
// HostKeyMatch if the fingerprint matches, or HostKeyChanged if it differs.
|
||||
func (s *HostKeyStore) Verify(hostname string, port int, keyType string, fingerprint string) (HostKeyResult, error) {
|
||||
var storedFingerprint string
|
||||
err := s.db.QueryRow(
|
||||
"SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ? AND key_type = ?",
|
||||
hostname, port, keyType,
|
||||
).Scan(&storedFingerprint)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return HostKeyNew, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("query host key: %w", err)
|
||||
}
|
||||
|
||||
if storedFingerprint == fingerprint {
|
||||
return HostKeyMatch, nil
|
||||
}
|
||||
return HostKeyChanged, nil
|
||||
}
|
||||
|
||||
// Store inserts or replaces a host key fingerprint for the given hostname, port, and key type.
|
||||
func (s *HostKeyStore) Store(hostname string, port int, keyType string, fingerprint string, rawKey string) error {
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO host_keys (hostname, port, key_type, fingerprint, raw_key)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT (hostname, port, key_type)
|
||||
DO UPDATE SET fingerprint = excluded.fingerprint, raw_key = excluded.raw_key`,
|
||||
hostname, port, keyType, fingerprint, rawKey,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store host key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes all stored host keys for the given hostname and port.
|
||||
func (s *HostKeyStore) Delete(hostname string, port int) error {
|
||||
_, err := s.db.Exec(
|
||||
"DELETE FROM host_keys WHERE hostname = ? AND port = ?",
|
||||
hostname, port,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete host key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFingerprint returns the stored fingerprint for the given hostname and port.
|
||||
// It returns an empty string and sql.ErrNoRows if no key is stored.
|
||||
func (s *HostKeyStore) GetFingerprint(hostname string, port int) (string, error) {
|
||||
var fingerprint string
|
||||
err := s.db.QueryRow(
|
||||
"SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ?",
|
||||
hostname, port,
|
||||
).Scan(&fingerprint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get fingerprint: %w", err)
|
||||
}
|
||||
return fingerprint, nil
|
||||
}
|
||||
77
internal/ssh/hostkey_test.go
Normal file
77
internal/ssh/hostkey_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/vstockwell/wraith/internal/db"
|
||||
)
|
||||
|
||||
func setupHostKeyStore(t *testing.T) *HostKeyStore {
|
||||
t.Helper()
|
||||
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.Migrate(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { d.Close() })
|
||||
return NewHostKeyStore(d)
|
||||
}
|
||||
|
||||
func TestVerifyNewHost(t *testing.T) {
|
||||
store := setupHostKeyStore(t)
|
||||
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != HostKeyNew {
|
||||
t.Errorf("got %d, want HostKeyNew", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreAndVerifyMatch(t *testing.T) {
|
||||
store := setupHostKeyStore(t)
|
||||
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != HostKeyMatch {
|
||||
t.Errorf("got %d, want HostKeyMatch", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyChangedKey(t *testing.T) {
|
||||
store := setupHostKeyStore(t)
|
||||
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:DIFFERENT")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != HostKeyChanged {
|
||||
t.Errorf("got %d, want HostKeyChanged", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteHostKey(t *testing.T) {
|
||||
store := setupHostKeyStore(t)
|
||||
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := store.Delete("192.168.1.4", 22); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != HostKeyNew {
|
||||
t.Errorf("after delete, got %d, want HostKeyNew", result)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user