feat: SSH host key verification + OSC 7 CWD tracker
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
a75e21138e
commit
cab286b4a6
128
internal/ssh/cwd.go
Normal file
128
internal/ssh/cwd.go
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CWDTracker parses OSC 7 escape sequences from terminal output to track the
|
||||||
|
// remote working directory.
|
||||||
|
type CWDTracker struct {
|
||||||
|
currentPath string
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCWDTracker creates a new CWDTracker.
|
||||||
|
func NewCWDTracker() *CWDTracker {
|
||||||
|
return &CWDTracker{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// osc7Prefix is the escape sequence that starts an OSC 7 directive.
|
||||||
|
var osc7Prefix = []byte("\033]7;")
|
||||||
|
|
||||||
|
// stTerminator is the ST (String Terminator) escape: ESC + backslash.
|
||||||
|
var stTerminator = []byte("\033\\")
|
||||||
|
|
||||||
|
// belTerminator is the BEL character, an alternative OSC terminator.
|
||||||
|
var belTerminator = []byte{0x07}
|
||||||
|
|
||||||
|
// ProcessOutput scans data for OSC 7 sequences of the form:
|
||||||
|
//
|
||||||
|
// \033]7;file://hostname/path\033\\ (ST terminator)
|
||||||
|
// \033]7;file://hostname/path\007 (BEL terminator)
|
||||||
|
//
|
||||||
|
// It returns cleaned output with all OSC 7 sequences stripped and the new CWD
|
||||||
|
// path (or "" if no OSC 7 was found).
|
||||||
|
func (t *CWDTracker) ProcessOutput(data []byte) (cleaned []byte, newCWD string) {
|
||||||
|
var result []byte
|
||||||
|
remaining := data
|
||||||
|
var lastCWD string
|
||||||
|
|
||||||
|
for {
|
||||||
|
idx := bytes.Index(remaining, osc7Prefix)
|
||||||
|
if idx == -1 {
|
||||||
|
result = append(result, remaining...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append everything before the OSC 7 sequence.
|
||||||
|
result = append(result, remaining[:idx]...)
|
||||||
|
|
||||||
|
// Find the end of the OSC 7 payload (after the prefix).
|
||||||
|
afterPrefix := remaining[idx+len(osc7Prefix):]
|
||||||
|
|
||||||
|
// Try ST terminator first (\033\\), then BEL (\007).
|
||||||
|
endIdx := -1
|
||||||
|
terminatorLen := 0
|
||||||
|
|
||||||
|
if stIdx := bytes.Index(afterPrefix, stTerminator); stIdx != -1 {
|
||||||
|
endIdx = stIdx
|
||||||
|
terminatorLen = len(stTerminator)
|
||||||
|
}
|
||||||
|
if belIdx := bytes.Index(afterPrefix, belTerminator); belIdx != -1 {
|
||||||
|
if endIdx == -1 || belIdx < endIdx {
|
||||||
|
endIdx = belIdx
|
||||||
|
terminatorLen = len(belTerminator)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if endIdx == -1 {
|
||||||
|
// No terminator found; treat the rest as literal output.
|
||||||
|
result = append(result, remaining[idx:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the URI payload between prefix and terminator.
|
||||||
|
payload := string(afterPrefix[:endIdx])
|
||||||
|
if path := extractPathFromOSC7(payload); path != "" {
|
||||||
|
lastCWD = path
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining = afterPrefix[endIdx+terminatorLen:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastCWD != "" {
|
||||||
|
t.mu.Lock()
|
||||||
|
t.currentPath = lastCWD
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, lastCWD
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCWD returns the current tracked working directory.
|
||||||
|
func (t *CWDTracker) GetCWD() string {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
return t.currentPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPathFromOSC7 parses a file:// URI and returns the path component.
|
||||||
|
func extractPathFromOSC7(uri string) string {
|
||||||
|
u, err := url.Parse(uri)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if u.Scheme != "file" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return u.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShellIntegrationCommand returns the shell command to inject for CWD tracking
|
||||||
|
// via OSC 7. The returned command sets up a prompt hook that emits the OSC 7
|
||||||
|
// escape sequence after every command.
|
||||||
|
func ShellIntegrationCommand(shellType string) string {
|
||||||
|
switch shellType {
|
||||||
|
case "bash":
|
||||||
|
return fmt.Sprintf(`PROMPT_COMMAND='printf "\033]7;file://%%s%%s\033\\" "$HOSTNAME" "$PWD"'`)
|
||||||
|
case "zsh":
|
||||||
|
return fmt.Sprintf(`precmd() { printf '\033]7;file://%%s%%s\033\\' "$HOST" "$PWD" }`)
|
||||||
|
case "fish":
|
||||||
|
return `function __wraith_osc7 --on-event fish_prompt; printf '\033]7;file://%s%s\033\\' (hostname) (pwd); end`
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
72
internal/ssh/cwd_test.go
Normal file
72
internal/ssh/cwd_test.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProcessOutputBasicOSC7(t *testing.T) {
|
||||||
|
tracker := NewCWDTracker()
|
||||||
|
input := []byte("hello\033]7;file://myhost/home/user\033\\world")
|
||||||
|
cleaned, cwd := tracker.ProcessOutput(input)
|
||||||
|
if string(cleaned) != "helloworld" {
|
||||||
|
t.Errorf("cleaned = %q", string(cleaned))
|
||||||
|
}
|
||||||
|
if cwd != "/home/user" {
|
||||||
|
t.Errorf("cwd = %q", cwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessOutputBELTerminator(t *testing.T) {
|
||||||
|
tracker := NewCWDTracker()
|
||||||
|
input := []byte("output\033]7;file://host/tmp\007more")
|
||||||
|
cleaned, cwd := tracker.ProcessOutput(input)
|
||||||
|
if string(cleaned) != "outputmore" {
|
||||||
|
t.Errorf("cleaned = %q", string(cleaned))
|
||||||
|
}
|
||||||
|
if cwd != "/tmp" {
|
||||||
|
t.Errorf("cwd = %q", cwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessOutputNoOSC7(t *testing.T) {
|
||||||
|
tracker := NewCWDTracker()
|
||||||
|
input := []byte("just normal output")
|
||||||
|
cleaned, cwd := tracker.ProcessOutput(input)
|
||||||
|
if string(cleaned) != "just normal output" {
|
||||||
|
t.Errorf("cleaned = %q", string(cleaned))
|
||||||
|
}
|
||||||
|
if cwd != "" {
|
||||||
|
t.Errorf("cwd should be empty, got %q", cwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessOutputMultipleOSC7(t *testing.T) {
|
||||||
|
tracker := NewCWDTracker()
|
||||||
|
input := []byte("\033]7;file://h/dir1\033\\text\033]7;file://h/dir2\033\\end")
|
||||||
|
cleaned, cwd := tracker.ProcessOutput(input)
|
||||||
|
if string(cleaned) != "textend" {
|
||||||
|
t.Errorf("cleaned = %q", string(cleaned))
|
||||||
|
}
|
||||||
|
if cwd != "/dir2" {
|
||||||
|
t.Errorf("cwd = %q, want /dir2", cwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCWDPersists(t *testing.T) {
|
||||||
|
tracker := NewCWDTracker()
|
||||||
|
tracker.ProcessOutput([]byte("\033]7;file://h/home/user\033\\"))
|
||||||
|
if tracker.GetCWD() != "/home/user" {
|
||||||
|
t.Errorf("GetCWD = %q", tracker.GetCWD())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShellIntegrationCommand(t *testing.T) {
|
||||||
|
cmd := ShellIntegrationCommand("bash")
|
||||||
|
if cmd == "" {
|
||||||
|
t.Error("bash command should not be empty")
|
||||||
|
}
|
||||||
|
cmd = ShellIntegrationCommand("zsh")
|
||||||
|
if cmd == "" {
|
||||||
|
t.Error("zsh command should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
89
internal/ssh/hostkey.go
Normal file
89
internal/ssh/hostkey.go
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HostKeyResult represents the result of a host key verification.
|
||||||
|
type HostKeyResult int
|
||||||
|
|
||||||
|
const (
|
||||||
|
HostKeyNew HostKeyResult = iota // never seen this host before
|
||||||
|
HostKeyMatch // fingerprint matches stored
|
||||||
|
HostKeyChanged // fingerprint CHANGED — possible MITM
|
||||||
|
)
|
||||||
|
|
||||||
|
// HostKeyStore stores and verifies SSH host key fingerprints in the host_keys SQLite table.
|
||||||
|
type HostKeyStore struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHostKeyStore creates a new HostKeyStore backed by the given database.
|
||||||
|
func NewHostKeyStore(db *sql.DB) *HostKeyStore {
|
||||||
|
return &HostKeyStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify checks whether the given fingerprint matches any stored host key for the
|
||||||
|
// specified hostname, port, and key type. It returns HostKeyNew if no key is stored,
|
||||||
|
// HostKeyMatch if the fingerprint matches, or HostKeyChanged if it differs.
|
||||||
|
func (s *HostKeyStore) Verify(hostname string, port int, keyType string, fingerprint string) (HostKeyResult, error) {
|
||||||
|
var storedFingerprint string
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
"SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ? AND key_type = ?",
|
||||||
|
hostname, port, keyType,
|
||||||
|
).Scan(&storedFingerprint)
|
||||||
|
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return HostKeyNew, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("query host key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if storedFingerprint == fingerprint {
|
||||||
|
return HostKeyMatch, nil
|
||||||
|
}
|
||||||
|
return HostKeyChanged, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store inserts or replaces a host key fingerprint for the given hostname, port, and key type.
|
||||||
|
func (s *HostKeyStore) Store(hostname string, port int, keyType string, fingerprint string, rawKey string) error {
|
||||||
|
_, err := s.db.Exec(
|
||||||
|
`INSERT INTO host_keys (hostname, port, key_type, fingerprint, raw_key)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT (hostname, port, key_type)
|
||||||
|
DO UPDATE SET fingerprint = excluded.fingerprint, raw_key = excluded.raw_key`,
|
||||||
|
hostname, port, keyType, fingerprint, rawKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("store host key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes all stored host keys for the given hostname and port.
|
||||||
|
func (s *HostKeyStore) Delete(hostname string, port int) error {
|
||||||
|
_, err := s.db.Exec(
|
||||||
|
"DELETE FROM host_keys WHERE hostname = ? AND port = ?",
|
||||||
|
hostname, port,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete host key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFingerprint returns the stored fingerprint for the given hostname and port.
|
||||||
|
// It returns an empty string and sql.ErrNoRows if no key is stored.
|
||||||
|
func (s *HostKeyStore) GetFingerprint(hostname string, port int) (string, error) {
|
||||||
|
var fingerprint string
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
"SELECT fingerprint FROM host_keys WHERE hostname = ? AND port = ?",
|
||||||
|
hostname, port,
|
||||||
|
).Scan(&fingerprint)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get fingerprint: %w", err)
|
||||||
|
}
|
||||||
|
return fingerprint, nil
|
||||||
|
}
|
||||||
77
internal/ssh/hostkey_test.go
Normal file
77
internal/ssh/hostkey_test.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupHostKeyStore(t *testing.T) *HostKeyStore {
|
||||||
|
t.Helper()
|
||||||
|
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(d); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { d.Close() })
|
||||||
|
return NewHostKeyStore(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyNewHost(t *testing.T) {
|
||||||
|
store := setupHostKeyStore(t)
|
||||||
|
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != HostKeyNew {
|
||||||
|
t.Errorf("got %d, want HostKeyNew", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreAndVerifyMatch(t *testing.T) {
|
||||||
|
store := setupHostKeyStore(t)
|
||||||
|
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != HostKeyMatch {
|
||||||
|
t.Errorf("got %d, want HostKeyMatch", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyChangedKey(t *testing.T) {
|
||||||
|
store := setupHostKeyStore(t)
|
||||||
|
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:DIFFERENT")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != HostKeyChanged {
|
||||||
|
t.Errorf("got %d, want HostKeyChanged", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteHostKey(t *testing.T) {
|
||||||
|
store := setupHostKeyStore(t)
|
||||||
|
if err := store.Store("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123", "AAAA..."); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := store.Delete("192.168.1.4", 22); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
result, err := store.Verify("192.168.1.4", 22, "ssh-ed25519", "SHA256:abc123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != HostKeyNew {
|
||||||
|
t.Errorf("after delete, got %d, want HostKeyNew", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user