diff --git a/go.mod b/go.mod index 5814c65..16d8ffc 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.26.1 require ( github.com/google/uuid v1.6.0 + github.com/pkg/sftp v1.13.10 github.com/wailsapp/wails/v3 v3.0.0-alpha.74 golang.org/x/crypto v0.49.0 modernc.org/sqlite v1.46.2 @@ -31,6 +32,7 @@ require ( github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect github.com/kevinburke/ssh_config v1.4.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/kr/fs v0.1.0 // indirect github.com/leaanthony/go-ansi-parser v1.6.1 // indirect github.com/leaanthony/u v1.1.1 // indirect github.com/lmittmann/tint v1.1.2 // indirect diff --git a/go.sum b/go.sum index 235d6ba..0868101 100644 --- a/go.sum +++ b/go.sum @@ -64,6 +64,8 @@ github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PW github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -94,6 +96,8 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmd github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= +github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= diff --git a/internal/credentials/service.go b/internal/credentials/service.go new file mode 100644 index 0000000..617de89 --- /dev/null +++ b/internal/credentials/service.go @@ -0,0 +1,408 @@ +package credentials + +import ( + "crypto/ed25519" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "database/sql" + "encoding/base64" + "encoding/pem" + "fmt" + + "github.com/vstockwell/wraith/internal/vault" + "golang.org/x/crypto/ssh" +) + +// Credential represents a stored credential (password or SSH key reference). +type Credential struct { + ID int64 `json:"id"` + Name string `json:"name"` + Username string `json:"username"` + Domain string `json:"domain"` + Type string `json:"type"` // "password" or "ssh_key" + SSHKeyID *int64 `json:"sshKeyId"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` +} + +// SSHKey represents a stored SSH key. +type SSHKey struct { + ID int64 `json:"id"` + Name string `json:"name"` + KeyType string `json:"keyType"` + Fingerprint string `json:"fingerprint"` + PublicKey string `json:"publicKey"` + CreatedAt string `json:"createdAt"` +} + +// CredentialService provides CRUD for credentials with vault encryption. +type CredentialService struct { + db *sql.DB + vault *vault.VaultService +} + +// NewCredentialService creates a new CredentialService. +func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService { + return &CredentialService{db: db, vault: vault} +} + +// CreatePassword creates a password credential (password encrypted via vault). +func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) { + encrypted, err := s.vault.Encrypt(password) + if err != nil { + return nil, fmt.Errorf("encrypt password: %w", err) + } + + result, err := s.db.Exec( + `INSERT INTO credentials (name, username, domain, type, encrypted_value) + VALUES (?, ?, ?, 'password', ?)`, + name, username, domain, encrypted, + ) + if err != nil { + return nil, fmt.Errorf("insert credential: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get credential id: %w", err) + } + + return s.getCredential(id) +} + +// CreateSSHKey imports an SSH key (private key encrypted via vault). +func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) { + // Parse the private key to detect type and extract public key + keyType := DetectKeyType(privateKeyPEM) + + // Parse the key to get the public key for fingerprinting. + // Try without passphrase first (handles unencrypted keys even when a + // passphrase is provided for storage), then fall back to using the + // passphrase for encrypted PEM keys. + var signer ssh.Signer + var err error + signer, err = ssh.ParsePrivateKey(privateKeyPEM) + if err != nil && passphrase != "" { + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase)) + } + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + pubKey := signer.PublicKey() + fingerprint := ssh.FingerprintSHA256(pubKey) + publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey)) + + // Encrypt private key via vault + encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM)) + if err != nil { + return nil, fmt.Errorf("encrypt private key: %w", err) + } + + // Encrypt passphrase via vault (if provided) + var encryptedPassphrase sql.NullString + if passphrase != "" { + ep, err := s.vault.Encrypt(passphrase) + if err != nil { + return nil, fmt.Errorf("encrypt passphrase: %w", err) + } + encryptedPassphrase = sql.NullString{String: ep, Valid: true} + } + + result, err := s.db.Exec( + `INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted) + VALUES (?, ?, ?, ?, ?, ?)`, + name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase, + ) + if err != nil { + return nil, fmt.Errorf("insert ssh key: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get ssh key id: %w", err) + } + + return s.getSSHKey(id) +} + +// ListCredentials returns all credentials WITHOUT encrypted values. +func (s *CredentialService) ListCredentials() ([]Credential, error) { + rows, err := s.db.Query( + `SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at + FROM credentials ORDER BY name`, + ) + if err != nil { + return nil, fmt.Errorf("list credentials: %w", err) + } + defer rows.Close() + + var creds []Credential + for rows.Next() { + var c Credential + var username, domain sql.NullString + if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan credential: %w", err) + } + if username.Valid { + c.Username = username.String + } + if domain.Valid { + c.Domain = domain.String + } + creds = append(creds, c) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate credentials: %w", err) + } + + if creds == nil { + creds = []Credential{} + } + return creds, nil +} + +// ListSSHKeys returns all SSH keys WITHOUT private key data. +func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) { + rows, err := s.db.Query( + `SELECT id, name, key_type, fingerprint, public_key, created_at + FROM ssh_keys ORDER BY name`, + ) + if err != nil { + return nil, fmt.Errorf("list ssh keys: %w", err) + } + defer rows.Close() + + var keys []SSHKey + for rows.Next() { + var k SSHKey + var keyType, fingerprint, publicKey sql.NullString + if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil { + return nil, fmt.Errorf("scan ssh key: %w", err) + } + if keyType.Valid { + k.KeyType = keyType.String + } + if fingerprint.Valid { + k.Fingerprint = fingerprint.String + } + if publicKey.Valid { + k.PublicKey = publicKey.String + } + keys = append(keys, k) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate ssh keys: %w", err) + } + + if keys == nil { + keys = []SSHKey{} + } + return keys, nil +} + +// DecryptPassword returns the decrypted password for a credential. +func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) { + var encrypted sql.NullString + err := s.db.QueryRow( + "SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'", + credentialID, + ).Scan(&encrypted) + if err != nil { + return "", fmt.Errorf("get encrypted password: %w", err) + } + if !encrypted.Valid { + return "", fmt.Errorf("no encrypted value for credential %d", credentialID) + } + + password, err := s.vault.Decrypt(encrypted.String) + if err != nil { + return "", fmt.Errorf("decrypt password: %w", err) + } + return password, nil +} + +// DecryptSSHKey returns the decrypted private key + passphrase. +func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) { + var encryptedKey string + var encryptedPassphrase sql.NullString + err = s.db.QueryRow( + "SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?", + sshKeyID, + ).Scan(&encryptedKey, &encryptedPassphrase) + if err != nil { + return nil, "", fmt.Errorf("get encrypted ssh key: %w", err) + } + + decryptedKey, err := s.vault.Decrypt(encryptedKey) + if err != nil { + return nil, "", fmt.Errorf("decrypt private key: %w", err) + } + + if encryptedPassphrase.Valid { + passphrase, err = s.vault.Decrypt(encryptedPassphrase.String) + if err != nil { + return nil, "", fmt.Errorf("decrypt passphrase: %w", err) + } + } + + return []byte(decryptedKey), passphrase, nil +} + +// DeleteCredential removes a credential. +func (s *CredentialService) DeleteCredential(id int64) error { + _, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete credential: %w", err) + } + return nil +} + +// DeleteSSHKey removes an SSH key. +func (s *CredentialService) DeleteSSHKey(id int64) error { + _, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete ssh key: %w", err) + } + return nil +} + +// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa). +func DetectKeyType(pemData []byte) string { + block, _ := pem.Decode(pemData) + if block == nil { + return "unknown" + } + + // Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks) + if block.Type == "OPENSSH PRIVATE KEY" { + return detectOpenSSHKeyType(block.Bytes) + } + + // Try PKCS8 format + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + switch key.(type) { + case *rsa.PrivateKey: + return "rsa" + case ed25519.PrivateKey: + return "ed25519" + case *ecdsa.PrivateKey: + return "ecdsa" + } + } + + // Try RSA PKCS1 + if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return "rsa" + } + + // Try EC + if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return "ecdsa" + } + + return "unknown" +} + +// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type. +func detectOpenSSHKeyType(data []byte) string { + // OpenSSH private key format: "openssh-key-v1\0" magic, then fields. + // We parse the key using ssh package to determine the type. + // Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer. + pemBlock := &pem.Block{ + Type: "OPENSSH PRIVATE KEY", + Bytes: data, + } + pemBytes := pem.EncodeToMemory(pemBlock) + + signer, err := ssh.ParsePrivateKey(pemBytes) + if err != nil { + return "unknown" + } + + return classifyPublicKey(signer.PublicKey()) +} + +// classifyPublicKey determines the key type from an ssh.PublicKey. +func classifyPublicKey(pub ssh.PublicKey) string { + keyType := pub.Type() + switch keyType { + case "ssh-rsa": + return "rsa" + case "ssh-ed25519": + return "ed25519" + case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521": + return "ecdsa" + default: + return keyType + } +} + +// getCredential retrieves a single credential by ID. +func (s *CredentialService) getCredential(id int64) (*Credential, error) { + var c Credential + var username, domain sql.NullString + err := s.db.QueryRow( + `SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at + FROM credentials WHERE id = ?`, id, + ).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt) + if err != nil { + return nil, fmt.Errorf("get credential: %w", err) + } + if username.Valid { + c.Username = username.String + } + if domain.Valid { + c.Domain = domain.String + } + return &c, nil +} + +// getSSHKey retrieves a single SSH key by ID (without private key data). +func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) { + var k SSHKey + var keyType, fingerprint, publicKey sql.NullString + err := s.db.QueryRow( + `SELECT id, name, key_type, fingerprint, public_key, created_at + FROM ssh_keys WHERE id = ?`, id, + ).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt) + if err != nil { + return nil, fmt.Errorf("get ssh key: %w", err) + } + if keyType.Valid { + k.KeyType = keyType.String + } + if fingerprint.Valid { + k.Fingerprint = fingerprint.String + } + if publicKey.Valid { + k.PublicKey = publicKey.String + } + return &k, nil +} + +// generateFingerprint generates an SSH fingerprint string from a public key. +func generateFingerprint(pubKey ssh.PublicKey) string { + return ssh.FingerprintSHA256(pubKey) +} + +// marshalPublicKey returns the authorized_keys format of an SSH public key. +func marshalPublicKey(pubKey ssh.PublicKey) string { + return base64.StdEncoding.EncodeToString(pubKey.Marshal()) +} + +// ecdsaCurveName returns the name for an ECDSA curve. +func ecdsaCurveName(curve elliptic.Curve) string { + switch curve { + case elliptic.P256(): + return "nistp256" + case elliptic.P384(): + return "nistp384" + case elliptic.P521(): + return "nistp521" + default: + return "unknown" + } +} diff --git a/internal/credentials/service_test.go b/internal/credentials/service_test.go new file mode 100644 index 0000000..8c7211e --- /dev/null +++ b/internal/credentials/service_test.go @@ -0,0 +1,176 @@ +package credentials + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" + "github.com/vstockwell/wraith/internal/vault" + "golang.org/x/crypto/ssh" +) + +func setupCredentialService(t *testing.T) *CredentialService { + t.Helper() + d, err := db.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + if err := db.Migrate(d); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { d.Close() }) + + salt := []byte("test-salt-exactly-32-bytes-long!") + key := vault.DeriveKey("testpassword", salt) + vs := vault.NewVaultService(key) + + return NewCredentialService(d, vs) +} + +func TestCreatePasswordCredential(t *testing.T) { + svc := setupCredentialService(t) + cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "") + if err != nil { + t.Fatal(err) + } + if cred.Name != "Test Cred" { + t.Error("wrong name") + } + if cred.Type != "password" { + t.Error("wrong type") + } +} + +func TestDecryptPassword(t *testing.T) { + svc := setupCredentialService(t) + cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "") + password, err := svc.DecryptPassword(cred.ID) + if err != nil { + t.Fatal(err) + } + if password != "mypassword" { + t.Errorf("got %q, want mypassword", password) + } +} + +func TestListCredentialsExcludesSecrets(t *testing.T) { + svc := setupCredentialService(t) + svc.CreatePassword("Cred1", "user1", "pass1", "") + svc.CreatePassword("Cred2", "user2", "pass2", "") + creds, err := svc.ListCredentials() + if err != nil { + t.Fatal(err) + } + if len(creds) != 2 { + t.Errorf("got %d, want 2", len(creds)) + } +} + +func TestCreateSSHKey(t *testing.T) { + svc := setupCredentialService(t) + // Generate a test key + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + key, err := svc.CreateSSHKey("My Key", keyPEM, "") + if err != nil { + t.Fatal(err) + } + if key.KeyType != "ed25519" { + t.Errorf("KeyType = %q, want ed25519", key.KeyType) + } + if key.Fingerprint == "" { + t.Error("fingerprint should not be empty") + } +} + +func TestDecryptSSHKey(t *testing.T) { + svc := setupCredentialService(t) + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass") + decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID) + if err != nil { + t.Fatal(err) + } + if len(decryptedKey) == 0 { + t.Error("decrypted key should not be empty") + } + if passphrase != "testpass" { + t.Errorf("passphrase = %q, want testpass", passphrase) + } +} + +func TestDetectKeyType(t *testing.T) { + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + if got := DetectKeyType(keyPEM); got != "ed25519" { + t.Errorf("got %q", got) + } +} + +func TestDeleteCredential(t *testing.T) { + svc := setupCredentialService(t) + cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "") + err := svc.DeleteCredential(cred.ID) + if err != nil { + t.Fatal(err) + } + creds, _ := svc.ListCredentials() + if len(creds) != 0 { + t.Errorf("got %d credentials, want 0", len(creds)) + } +} + +func TestDeleteSSHKey(t *testing.T) { + svc := setupCredentialService(t) + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "") + err := svc.DeleteSSHKey(key.ID) + if err != nil { + t.Fatal(err) + } + keys, _ := svc.ListSSHKeys() + if len(keys) != 0 { + t.Errorf("got %d keys, want 0", len(keys)) + } +} + +func TestListSSHKeys(t *testing.T) { + svc := setupCredentialService(t) + _, priv, _ := ed25519.GenerateKey(rand.Reader) + pemBlock, _ := ssh.MarshalPrivateKey(priv, "") + keyPEM := pem.EncodeToMemory(pemBlock) + + svc.CreateSSHKey("Key1", keyPEM, "") + svc.CreateSSHKey("Key2", keyPEM, "") + + keys, err := svc.ListSSHKeys() + if err != nil { + t.Fatal(err) + } + if len(keys) != 2 { + t.Errorf("got %d keys, want 2", len(keys)) + } +} + +func TestCreatePasswordWithDomain(t *testing.T) { + svc := setupCredentialService(t) + cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com") + if err != nil { + t.Fatal(err) + } + if cred.Domain != "example.com" { + t.Errorf("domain = %q, want example.com", cred.Domain) + } +} diff --git a/internal/sftp/service.go b/internal/sftp/service.go new file mode 100644 index 0000000..a027fdb --- /dev/null +++ b/internal/sftp/service.go @@ -0,0 +1,238 @@ +package sftp + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + "sync" + + "github.com/pkg/sftp" +) + +const MaxEditFileSize = 5 * 1024 * 1024 // 5MB + +// FileEntry represents a file or directory in a remote filesystem. +type FileEntry struct { + Name string `json:"name"` + Path string `json:"path"` + Size int64 `json:"size"` + IsDir bool `json:"isDir"` + Permissions string `json:"permissions"` + ModTime string `json:"modTime"` +} + +// SFTPService manages SFTP clients keyed by session ID. +type SFTPService struct { + clients map[string]*sftp.Client + mu sync.RWMutex +} + +// NewSFTPService creates a new SFTPService. +func NewSFTPService() *SFTPService { + return &SFTPService{ + clients: make(map[string]*sftp.Client), + } +} + +// RegisterClient stores an SFTP client for a session. +func (s *SFTPService) RegisterClient(sessionID string, client *sftp.Client) { + s.mu.Lock() + defer s.mu.Unlock() + s.clients[sessionID] = client +} + +// RemoveClient removes and closes an SFTP client. +func (s *SFTPService) RemoveClient(sessionID string) { + s.mu.Lock() + client, ok := s.clients[sessionID] + if ok { + delete(s.clients, sessionID) + } + s.mu.Unlock() + + if ok && client != nil { + client.Close() + } +} + +// getClient returns the SFTP client for a session or an error if not found. +func (s *SFTPService) getClient(sessionID string) (*sftp.Client, error) { + s.mu.RLock() + defer s.mu.RUnlock() + client, ok := s.clients[sessionID] + if !ok { + return nil, fmt.Errorf("no SFTP client for session %s", sessionID) + } + return client, nil +} + +// SortEntries sorts file entries with directories first, then alphabetically by name. +func SortEntries(entries []FileEntry) { + sort.Slice(entries, func(i, j int) bool { + if entries[i].IsDir != entries[j].IsDir { + return entries[i].IsDir + } + return strings.ToLower(entries[i].Name) < strings.ToLower(entries[j].Name) + }) +} + +// fileInfoToEntry converts an os.FileInfo and its path into a FileEntry. +func fileInfoToEntry(info os.FileInfo, path string) FileEntry { + return FileEntry{ + Name: info.Name(), + Path: path, + Size: info.Size(), + IsDir: info.IsDir(), + Permissions: info.Mode().Perm().String(), + ModTime: info.ModTime().UTC().Format("2006-01-02T15:04:05Z"), + } +} + +// List returns directory contents sorted (dirs first, then files alphabetically). +func (s *SFTPService) List(sessionID string, path string) ([]FileEntry, error) { + client, err := s.getClient(sessionID) + if err != nil { + return nil, err + } + + infos, err := client.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("read directory %s: %w", path, err) + } + + entries := make([]FileEntry, 0, len(infos)) + for _, info := range infos { + entryPath := path + if !strings.HasSuffix(entryPath, "/") { + entryPath += "/" + } + entryPath += info.Name() + entries = append(entries, fileInfoToEntry(info, entryPath)) + } + + SortEntries(entries) + return entries, nil +} + +// ReadFile reads a file (max 5MB). Returns content as string. +func (s *SFTPService) ReadFile(sessionID string, path string) (string, error) { + client, err := s.getClient(sessionID) + if err != nil { + return "", err + } + + info, err := client.Stat(path) + if err != nil { + return "", fmt.Errorf("stat %s: %w", path, err) + } + + if info.IsDir() { + return "", fmt.Errorf("%s is a directory", path) + } + + if info.Size() > MaxEditFileSize { + return "", fmt.Errorf("file %s is %d bytes, exceeds max edit size of %d bytes", path, info.Size(), MaxEditFileSize) + } + + f, err := client.Open(path) + if err != nil { + return "", fmt.Errorf("open %s: %w", path, err) + } + defer f.Close() + + data, err := io.ReadAll(f) + if err != nil { + return "", fmt.Errorf("read %s: %w", path, err) + } + + return string(data), nil +} + +// WriteFile writes content to a file. +func (s *SFTPService) WriteFile(sessionID string, path string, content string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + f, err := client.Create(path) + if err != nil { + return fmt.Errorf("create %s: %w", path, err) + } + defer f.Close() + + if _, err := f.Write([]byte(content)); err != nil { + return fmt.Errorf("write %s: %w", path, err) + } + + return nil +} + +// Mkdir creates a directory. +func (s *SFTPService) Mkdir(sessionID string, path string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + if err := client.Mkdir(path); err != nil { + return fmt.Errorf("mkdir %s: %w", path, err) + } + return nil +} + +// Delete removes a file or empty directory. +func (s *SFTPService) Delete(sessionID string, path string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + info, err := client.Stat(path) + if err != nil { + return fmt.Errorf("stat %s: %w", path, err) + } + + if info.IsDir() { + if err := client.RemoveDirectory(path); err != nil { + return fmt.Errorf("remove directory %s: %w", path, err) + } + } else { + if err := client.Remove(path); err != nil { + return fmt.Errorf("remove %s: %w", path, err) + } + } + + return nil +} + +// Rename renames/moves a file. +func (s *SFTPService) Rename(sessionID string, oldPath, newPath string) error { + client, err := s.getClient(sessionID) + if err != nil { + return err + } + + if err := client.Rename(oldPath, newPath); err != nil { + return fmt.Errorf("rename %s to %s: %w", oldPath, newPath, err) + } + return nil +} + +// Stat returns info about a file/directory. +func (s *SFTPService) Stat(sessionID string, path string) (*FileEntry, error) { + client, err := s.getClient(sessionID) + if err != nil { + return nil, err + } + + info, err := client.Stat(path) + if err != nil { + return nil, fmt.Errorf("stat %s: %w", path, err) + } + + entry := fileInfoToEntry(info, path) + return &entry, nil +} diff --git a/internal/sftp/service_test.go b/internal/sftp/service_test.go new file mode 100644 index 0000000..6ce5975 --- /dev/null +++ b/internal/sftp/service_test.go @@ -0,0 +1,119 @@ +package sftp + +import ( + "testing" +) + +func TestNewSFTPService(t *testing.T) { + svc := NewSFTPService() + if svc == nil { + t.Fatal("nil") + } +} + +func TestListWithoutClient(t *testing.T) { + svc := NewSFTPService() + _, err := svc.List("nonexistent", "/") + if err == nil { + t.Error("should error without client") + } +} + +func TestReadFileWithoutClient(t *testing.T) { + svc := NewSFTPService() + _, err := svc.ReadFile("nonexistent", "/etc/hosts") + if err == nil { + t.Error("should error without client") + } +} + +func TestWriteFileWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.WriteFile("nonexistent", "/tmp/test", "data") + if err == nil { + t.Error("should error without client") + } +} + +func TestMkdirWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.Mkdir("nonexistent", "/tmp/newdir") + if err == nil { + t.Error("should error without client") + } +} + +func TestDeleteWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.Delete("nonexistent", "/tmp/file") + if err == nil { + t.Error("should error without client") + } +} + +func TestRenameWithoutClient(t *testing.T) { + svc := NewSFTPService() + err := svc.Rename("nonexistent", "/old", "/new") + if err == nil { + t.Error("should error without client") + } +} + +func TestStatWithoutClient(t *testing.T) { + svc := NewSFTPService() + _, err := svc.Stat("nonexistent", "/tmp") + if err == nil { + t.Error("should error without client") + } +} + +func TestFileEntrySorting(t *testing.T) { + // Test that SortEntries puts dirs first, then alpha + entries := []FileEntry{ + {Name: "zebra.txt", IsDir: false}, + {Name: "alpha", IsDir: true}, + {Name: "beta.conf", IsDir: false}, + {Name: "omega", IsDir: true}, + } + SortEntries(entries) + if entries[0].Name != "alpha" { + t.Errorf("[0] = %s, want alpha", entries[0].Name) + } + if entries[1].Name != "omega" { + t.Errorf("[1] = %s, want omega", entries[1].Name) + } + if entries[2].Name != "beta.conf" { + t.Errorf("[2] = %s, want beta.conf", entries[2].Name) + } + if entries[3].Name != "zebra.txt" { + t.Errorf("[3] = %s, want zebra.txt", entries[3].Name) + } +} + +func TestSortEntriesEmpty(t *testing.T) { + entries := []FileEntry{} + SortEntries(entries) + if len(entries) != 0 { + t.Errorf("expected empty slice, got %d entries", len(entries)) + } +} + +func TestSortEntriesCaseInsensitive(t *testing.T) { + entries := []FileEntry{ + {Name: "Zebra", IsDir: false}, + {Name: "alpha", IsDir: false}, + } + SortEntries(entries) + if entries[0].Name != "alpha" { + t.Errorf("[0] = %s, want alpha", entries[0].Name) + } + if entries[1].Name != "Zebra" { + t.Errorf("[1] = %s, want Zebra", entries[1].Name) + } +} + +func TestMaxEditFileSize(t *testing.T) { + if MaxEditFileSize != 5*1024*1024 { + t.Errorf("MaxEditFileSize = %d, want %d", MaxEditFileSize, 5*1024*1024) + } +}