Merge branch 'worktree-agent-a14ceeb8' into feat/phase2-ssh-sftp

# Conflicts:
#	go.mod
#	go.sum
This commit is contained in:
Vantz Stockwell 2026-03-17 06:56:11 -04:00
commit d05639ef4c
6 changed files with 947 additions and 0 deletions

2
go.mod
View File

@ -4,6 +4,7 @@ go 1.26.1
require ( require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/pkg/sftp v1.13.10
github.com/wailsapp/wails/v3 v3.0.0-alpha.74 github.com/wailsapp/wails/v3 v3.0.0-alpha.74
golang.org/x/crypto v0.49.0 golang.org/x/crypto v0.49.0
modernc.org/sqlite v1.46.2 modernc.org/sqlite v1.46.2
@ -31,6 +32,7 @@ require (
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
github.com/kevinburke/ssh_config v1.4.0 // indirect github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
github.com/leaanthony/u v1.1.1 // indirect github.com/leaanthony/u v1.1.1 // indirect
github.com/lmittmann/tint v1.1.2 // indirect github.com/lmittmann/tint v1.1.2 // indirect

4
go.sum
View File

@ -64,6 +64,8 @@ github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PW
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@ -94,6 +96,8 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmd
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=

View File

@ -0,0 +1,408 @@
package credentials
import (
"crypto/ed25519"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/pem"
"fmt"
"github.com/vstockwell/wraith/internal/vault"
"golang.org/x/crypto/ssh"
)
// Credential represents a stored credential (password or SSH key reference).
type Credential struct {
ID int64 `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
Domain string `json:"domain"`
Type string `json:"type"` // "password" or "ssh_key"
SSHKeyID *int64 `json:"sshKeyId"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
// SSHKey represents a stored SSH key.
type SSHKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
KeyType string `json:"keyType"`
Fingerprint string `json:"fingerprint"`
PublicKey string `json:"publicKey"`
CreatedAt string `json:"createdAt"`
}
// CredentialService provides CRUD for credentials with vault encryption.
type CredentialService struct {
db *sql.DB
vault *vault.VaultService
}
// NewCredentialService creates a new CredentialService.
func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService {
return &CredentialService{db: db, vault: vault}
}
// CreatePassword creates a password credential (password encrypted via vault).
func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) {
encrypted, err := s.vault.Encrypt(password)
if err != nil {
return nil, fmt.Errorf("encrypt password: %w", err)
}
result, err := s.db.Exec(
`INSERT INTO credentials (name, username, domain, type, encrypted_value)
VALUES (?, ?, ?, 'password', ?)`,
name, username, domain, encrypted,
)
if err != nil {
return nil, fmt.Errorf("insert credential: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get credential id: %w", err)
}
return s.getCredential(id)
}
// CreateSSHKey imports an SSH key (private key encrypted via vault).
func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) {
// Parse the private key to detect type and extract public key
keyType := DetectKeyType(privateKeyPEM)
// Parse the key to get the public key for fingerprinting.
// Try without passphrase first (handles unencrypted keys even when a
// passphrase is provided for storage), then fall back to using the
// passphrase for encrypted PEM keys.
var signer ssh.Signer
var err error
signer, err = ssh.ParsePrivateKey(privateKeyPEM)
if err != nil && passphrase != "" {
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase))
}
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
pubKey := signer.PublicKey()
fingerprint := ssh.FingerprintSHA256(pubKey)
publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey))
// Encrypt private key via vault
encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM))
if err != nil {
return nil, fmt.Errorf("encrypt private key: %w", err)
}
// Encrypt passphrase via vault (if provided)
var encryptedPassphrase sql.NullString
if passphrase != "" {
ep, err := s.vault.Encrypt(passphrase)
if err != nil {
return nil, fmt.Errorf("encrypt passphrase: %w", err)
}
encryptedPassphrase = sql.NullString{String: ep, Valid: true}
}
result, err := s.db.Exec(
`INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted)
VALUES (?, ?, ?, ?, ?, ?)`,
name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase,
)
if err != nil {
return nil, fmt.Errorf("insert ssh key: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get ssh key id: %w", err)
}
return s.getSSHKey(id)
}
// ListCredentials returns all credentials WITHOUT encrypted values.
func (s *CredentialService) ListCredentials() ([]Credential, error) {
rows, err := s.db.Query(
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
FROM credentials ORDER BY name`,
)
if err != nil {
return nil, fmt.Errorf("list credentials: %w", err)
}
defer rows.Close()
var creds []Credential
for rows.Next() {
var c Credential
var username, domain sql.NullString
if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan credential: %w", err)
}
if username.Valid {
c.Username = username.String
}
if domain.Valid {
c.Domain = domain.String
}
creds = append(creds, c)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate credentials: %w", err)
}
if creds == nil {
creds = []Credential{}
}
return creds, nil
}
// ListSSHKeys returns all SSH keys WITHOUT private key data.
func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) {
rows, err := s.db.Query(
`SELECT id, name, key_type, fingerprint, public_key, created_at
FROM ssh_keys ORDER BY name`,
)
if err != nil {
return nil, fmt.Errorf("list ssh keys: %w", err)
}
defer rows.Close()
var keys []SSHKey
for rows.Next() {
var k SSHKey
var keyType, fingerprint, publicKey sql.NullString
if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil {
return nil, fmt.Errorf("scan ssh key: %w", err)
}
if keyType.Valid {
k.KeyType = keyType.String
}
if fingerprint.Valid {
k.Fingerprint = fingerprint.String
}
if publicKey.Valid {
k.PublicKey = publicKey.String
}
keys = append(keys, k)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate ssh keys: %w", err)
}
if keys == nil {
keys = []SSHKey{}
}
return keys, nil
}
// DecryptPassword returns the decrypted password for a credential.
func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) {
var encrypted sql.NullString
err := s.db.QueryRow(
"SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'",
credentialID,
).Scan(&encrypted)
if err != nil {
return "", fmt.Errorf("get encrypted password: %w", err)
}
if !encrypted.Valid {
return "", fmt.Errorf("no encrypted value for credential %d", credentialID)
}
password, err := s.vault.Decrypt(encrypted.String)
if err != nil {
return "", fmt.Errorf("decrypt password: %w", err)
}
return password, nil
}
// DecryptSSHKey returns the decrypted private key + passphrase.
func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) {
var encryptedKey string
var encryptedPassphrase sql.NullString
err = s.db.QueryRow(
"SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?",
sshKeyID,
).Scan(&encryptedKey, &encryptedPassphrase)
if err != nil {
return nil, "", fmt.Errorf("get encrypted ssh key: %w", err)
}
decryptedKey, err := s.vault.Decrypt(encryptedKey)
if err != nil {
return nil, "", fmt.Errorf("decrypt private key: %w", err)
}
if encryptedPassphrase.Valid {
passphrase, err = s.vault.Decrypt(encryptedPassphrase.String)
if err != nil {
return nil, "", fmt.Errorf("decrypt passphrase: %w", err)
}
}
return []byte(decryptedKey), passphrase, nil
}
// DeleteCredential removes a credential.
func (s *CredentialService) DeleteCredential(id int64) error {
_, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete credential: %w", err)
}
return nil
}
// DeleteSSHKey removes an SSH key.
func (s *CredentialService) DeleteSSHKey(id int64) error {
_, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete ssh key: %w", err)
}
return nil
}
// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa).
func DetectKeyType(pemData []byte) string {
block, _ := pem.Decode(pemData)
if block == nil {
return "unknown"
}
// Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks)
if block.Type == "OPENSSH PRIVATE KEY" {
return detectOpenSSHKeyType(block.Bytes)
}
// Try PKCS8 format
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
switch key.(type) {
case *rsa.PrivateKey:
return "rsa"
case ed25519.PrivateKey:
return "ed25519"
case *ecdsa.PrivateKey:
return "ecdsa"
}
}
// Try RSA PKCS1
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return "rsa"
}
// Try EC
if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return "ecdsa"
}
return "unknown"
}
// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type.
func detectOpenSSHKeyType(data []byte) string {
// OpenSSH private key format: "openssh-key-v1\0" magic, then fields.
// We parse the key using ssh package to determine the type.
// Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer.
pemBlock := &pem.Block{
Type: "OPENSSH PRIVATE KEY",
Bytes: data,
}
pemBytes := pem.EncodeToMemory(pemBlock)
signer, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
return "unknown"
}
return classifyPublicKey(signer.PublicKey())
}
// classifyPublicKey determines the key type from an ssh.PublicKey.
func classifyPublicKey(pub ssh.PublicKey) string {
keyType := pub.Type()
switch keyType {
case "ssh-rsa":
return "rsa"
case "ssh-ed25519":
return "ed25519"
case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521":
return "ecdsa"
default:
return keyType
}
}
// getCredential retrieves a single credential by ID.
func (s *CredentialService) getCredential(id int64) (*Credential, error) {
var c Credential
var username, domain sql.NullString
err := s.db.QueryRow(
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
FROM credentials WHERE id = ?`, id,
).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("get credential: %w", err)
}
if username.Valid {
c.Username = username.String
}
if domain.Valid {
c.Domain = domain.String
}
return &c, nil
}
// getSSHKey retrieves a single SSH key by ID (without private key data).
func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) {
var k SSHKey
var keyType, fingerprint, publicKey sql.NullString
err := s.db.QueryRow(
`SELECT id, name, key_type, fingerprint, public_key, created_at
FROM ssh_keys WHERE id = ?`, id,
).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt)
if err != nil {
return nil, fmt.Errorf("get ssh key: %w", err)
}
if keyType.Valid {
k.KeyType = keyType.String
}
if fingerprint.Valid {
k.Fingerprint = fingerprint.String
}
if publicKey.Valid {
k.PublicKey = publicKey.String
}
return &k, nil
}
// generateFingerprint generates an SSH fingerprint string from a public key.
func generateFingerprint(pubKey ssh.PublicKey) string {
return ssh.FingerprintSHA256(pubKey)
}
// marshalPublicKey returns the authorized_keys format of an SSH public key.
func marshalPublicKey(pubKey ssh.PublicKey) string {
return base64.StdEncoding.EncodeToString(pubKey.Marshal())
}
// ecdsaCurveName returns the name for an ECDSA curve.
func ecdsaCurveName(curve elliptic.Curve) string {
switch curve {
case elliptic.P256():
return "nistp256"
case elliptic.P384():
return "nistp384"
case elliptic.P521():
return "nistp521"
default:
return "unknown"
}
}

View File

@ -0,0 +1,176 @@
package credentials
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
"github.com/vstockwell/wraith/internal/vault"
"golang.org/x/crypto/ssh"
)
func setupCredentialService(t *testing.T) *CredentialService {
t.Helper()
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
if err := db.Migrate(d); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { d.Close() })
salt := []byte("test-salt-exactly-32-bytes-long!")
key := vault.DeriveKey("testpassword", salt)
vs := vault.NewVaultService(key)
return NewCredentialService(d, vs)
}
func TestCreatePasswordCredential(t *testing.T) {
svc := setupCredentialService(t)
cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "")
if err != nil {
t.Fatal(err)
}
if cred.Name != "Test Cred" {
t.Error("wrong name")
}
if cred.Type != "password" {
t.Error("wrong type")
}
}
func TestDecryptPassword(t *testing.T) {
svc := setupCredentialService(t)
cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "")
password, err := svc.DecryptPassword(cred.ID)
if err != nil {
t.Fatal(err)
}
if password != "mypassword" {
t.Errorf("got %q, want mypassword", password)
}
}
func TestListCredentialsExcludesSecrets(t *testing.T) {
svc := setupCredentialService(t)
svc.CreatePassword("Cred1", "user1", "pass1", "")
svc.CreatePassword("Cred2", "user2", "pass2", "")
creds, err := svc.ListCredentials()
if err != nil {
t.Fatal(err)
}
if len(creds) != 2 {
t.Errorf("got %d, want 2", len(creds))
}
}
func TestCreateSSHKey(t *testing.T) {
svc := setupCredentialService(t)
// Generate a test key
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
key, err := svc.CreateSSHKey("My Key", keyPEM, "")
if err != nil {
t.Fatal(err)
}
if key.KeyType != "ed25519" {
t.Errorf("KeyType = %q, want ed25519", key.KeyType)
}
if key.Fingerprint == "" {
t.Error("fingerprint should not be empty")
}
}
func TestDecryptSSHKey(t *testing.T) {
svc := setupCredentialService(t)
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass")
decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID)
if err != nil {
t.Fatal(err)
}
if len(decryptedKey) == 0 {
t.Error("decrypted key should not be empty")
}
if passphrase != "testpass" {
t.Errorf("passphrase = %q, want testpass", passphrase)
}
}
func TestDetectKeyType(t *testing.T) {
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
if got := DetectKeyType(keyPEM); got != "ed25519" {
t.Errorf("got %q", got)
}
}
func TestDeleteCredential(t *testing.T) {
svc := setupCredentialService(t)
cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "")
err := svc.DeleteCredential(cred.ID)
if err != nil {
t.Fatal(err)
}
creds, _ := svc.ListCredentials()
if len(creds) != 0 {
t.Errorf("got %d credentials, want 0", len(creds))
}
}
func TestDeleteSSHKey(t *testing.T) {
svc := setupCredentialService(t)
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "")
err := svc.DeleteSSHKey(key.ID)
if err != nil {
t.Fatal(err)
}
keys, _ := svc.ListSSHKeys()
if len(keys) != 0 {
t.Errorf("got %d keys, want 0", len(keys))
}
}
func TestListSSHKeys(t *testing.T) {
svc := setupCredentialService(t)
_, priv, _ := ed25519.GenerateKey(rand.Reader)
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
keyPEM := pem.EncodeToMemory(pemBlock)
svc.CreateSSHKey("Key1", keyPEM, "")
svc.CreateSSHKey("Key2", keyPEM, "")
keys, err := svc.ListSSHKeys()
if err != nil {
t.Fatal(err)
}
if len(keys) != 2 {
t.Errorf("got %d keys, want 2", len(keys))
}
}
func TestCreatePasswordWithDomain(t *testing.T) {
svc := setupCredentialService(t)
cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com")
if err != nil {
t.Fatal(err)
}
if cred.Domain != "example.com" {
t.Errorf("domain = %q, want example.com", cred.Domain)
}
}

238
internal/sftp/service.go Normal file
View File

@ -0,0 +1,238 @@
package sftp
import (
"fmt"
"io"
"os"
"sort"
"strings"
"sync"
"github.com/pkg/sftp"
)
const MaxEditFileSize = 5 * 1024 * 1024 // 5MB
// FileEntry represents a file or directory in a remote filesystem.
type FileEntry struct {
Name string `json:"name"`
Path string `json:"path"`
Size int64 `json:"size"`
IsDir bool `json:"isDir"`
Permissions string `json:"permissions"`
ModTime string `json:"modTime"`
}
// SFTPService manages SFTP clients keyed by session ID.
type SFTPService struct {
clients map[string]*sftp.Client
mu sync.RWMutex
}
// NewSFTPService creates a new SFTPService.
func NewSFTPService() *SFTPService {
return &SFTPService{
clients: make(map[string]*sftp.Client),
}
}
// RegisterClient stores an SFTP client for a session.
func (s *SFTPService) RegisterClient(sessionID string, client *sftp.Client) {
s.mu.Lock()
defer s.mu.Unlock()
s.clients[sessionID] = client
}
// RemoveClient removes and closes an SFTP client.
func (s *SFTPService) RemoveClient(sessionID string) {
s.mu.Lock()
client, ok := s.clients[sessionID]
if ok {
delete(s.clients, sessionID)
}
s.mu.Unlock()
if ok && client != nil {
client.Close()
}
}
// getClient returns the SFTP client for a session or an error if not found.
func (s *SFTPService) getClient(sessionID string) (*sftp.Client, error) {
s.mu.RLock()
defer s.mu.RUnlock()
client, ok := s.clients[sessionID]
if !ok {
return nil, fmt.Errorf("no SFTP client for session %s", sessionID)
}
return client, nil
}
// SortEntries sorts file entries with directories first, then alphabetically by name.
func SortEntries(entries []FileEntry) {
sort.Slice(entries, func(i, j int) bool {
if entries[i].IsDir != entries[j].IsDir {
return entries[i].IsDir
}
return strings.ToLower(entries[i].Name) < strings.ToLower(entries[j].Name)
})
}
// fileInfoToEntry converts an os.FileInfo and its path into a FileEntry.
func fileInfoToEntry(info os.FileInfo, path string) FileEntry {
return FileEntry{
Name: info.Name(),
Path: path,
Size: info.Size(),
IsDir: info.IsDir(),
Permissions: info.Mode().Perm().String(),
ModTime: info.ModTime().UTC().Format("2006-01-02T15:04:05Z"),
}
}
// List returns directory contents sorted (dirs first, then files alphabetically).
func (s *SFTPService) List(sessionID string, path string) ([]FileEntry, error) {
client, err := s.getClient(sessionID)
if err != nil {
return nil, err
}
infos, err := client.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("read directory %s: %w", path, err)
}
entries := make([]FileEntry, 0, len(infos))
for _, info := range infos {
entryPath := path
if !strings.HasSuffix(entryPath, "/") {
entryPath += "/"
}
entryPath += info.Name()
entries = append(entries, fileInfoToEntry(info, entryPath))
}
SortEntries(entries)
return entries, nil
}
// ReadFile reads a file (max 5MB). Returns content as string.
func (s *SFTPService) ReadFile(sessionID string, path string) (string, error) {
client, err := s.getClient(sessionID)
if err != nil {
return "", err
}
info, err := client.Stat(path)
if err != nil {
return "", fmt.Errorf("stat %s: %w", path, err)
}
if info.IsDir() {
return "", fmt.Errorf("%s is a directory", path)
}
if info.Size() > MaxEditFileSize {
return "", fmt.Errorf("file %s is %d bytes, exceeds max edit size of %d bytes", path, info.Size(), MaxEditFileSize)
}
f, err := client.Open(path)
if err != nil {
return "", fmt.Errorf("open %s: %w", path, err)
}
defer f.Close()
data, err := io.ReadAll(f)
if err != nil {
return "", fmt.Errorf("read %s: %w", path, err)
}
return string(data), nil
}
// WriteFile writes content to a file.
func (s *SFTPService) WriteFile(sessionID string, path string, content string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
f, err := client.Create(path)
if err != nil {
return fmt.Errorf("create %s: %w", path, err)
}
defer f.Close()
if _, err := f.Write([]byte(content)); err != nil {
return fmt.Errorf("write %s: %w", path, err)
}
return nil
}
// Mkdir creates a directory.
func (s *SFTPService) Mkdir(sessionID string, path string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
if err := client.Mkdir(path); err != nil {
return fmt.Errorf("mkdir %s: %w", path, err)
}
return nil
}
// Delete removes a file or empty directory.
func (s *SFTPService) Delete(sessionID string, path string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
info, err := client.Stat(path)
if err != nil {
return fmt.Errorf("stat %s: %w", path, err)
}
if info.IsDir() {
if err := client.RemoveDirectory(path); err != nil {
return fmt.Errorf("remove directory %s: %w", path, err)
}
} else {
if err := client.Remove(path); err != nil {
return fmt.Errorf("remove %s: %w", path, err)
}
}
return nil
}
// Rename renames/moves a file.
func (s *SFTPService) Rename(sessionID string, oldPath, newPath string) error {
client, err := s.getClient(sessionID)
if err != nil {
return err
}
if err := client.Rename(oldPath, newPath); err != nil {
return fmt.Errorf("rename %s to %s: %w", oldPath, newPath, err)
}
return nil
}
// Stat returns info about a file/directory.
func (s *SFTPService) Stat(sessionID string, path string) (*FileEntry, error) {
client, err := s.getClient(sessionID)
if err != nil {
return nil, err
}
info, err := client.Stat(path)
if err != nil {
return nil, fmt.Errorf("stat %s: %w", path, err)
}
entry := fileInfoToEntry(info, path)
return &entry, nil
}

View File

@ -0,0 +1,119 @@
package sftp
import (
"testing"
)
func TestNewSFTPService(t *testing.T) {
svc := NewSFTPService()
if svc == nil {
t.Fatal("nil")
}
}
func TestListWithoutClient(t *testing.T) {
svc := NewSFTPService()
_, err := svc.List("nonexistent", "/")
if err == nil {
t.Error("should error without client")
}
}
func TestReadFileWithoutClient(t *testing.T) {
svc := NewSFTPService()
_, err := svc.ReadFile("nonexistent", "/etc/hosts")
if err == nil {
t.Error("should error without client")
}
}
func TestWriteFileWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.WriteFile("nonexistent", "/tmp/test", "data")
if err == nil {
t.Error("should error without client")
}
}
func TestMkdirWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.Mkdir("nonexistent", "/tmp/newdir")
if err == nil {
t.Error("should error without client")
}
}
func TestDeleteWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.Delete("nonexistent", "/tmp/file")
if err == nil {
t.Error("should error without client")
}
}
func TestRenameWithoutClient(t *testing.T) {
svc := NewSFTPService()
err := svc.Rename("nonexistent", "/old", "/new")
if err == nil {
t.Error("should error without client")
}
}
func TestStatWithoutClient(t *testing.T) {
svc := NewSFTPService()
_, err := svc.Stat("nonexistent", "/tmp")
if err == nil {
t.Error("should error without client")
}
}
func TestFileEntrySorting(t *testing.T) {
// Test that SortEntries puts dirs first, then alpha
entries := []FileEntry{
{Name: "zebra.txt", IsDir: false},
{Name: "alpha", IsDir: true},
{Name: "beta.conf", IsDir: false},
{Name: "omega", IsDir: true},
}
SortEntries(entries)
if entries[0].Name != "alpha" {
t.Errorf("[0] = %s, want alpha", entries[0].Name)
}
if entries[1].Name != "omega" {
t.Errorf("[1] = %s, want omega", entries[1].Name)
}
if entries[2].Name != "beta.conf" {
t.Errorf("[2] = %s, want beta.conf", entries[2].Name)
}
if entries[3].Name != "zebra.txt" {
t.Errorf("[3] = %s, want zebra.txt", entries[3].Name)
}
}
func TestSortEntriesEmpty(t *testing.T) {
entries := []FileEntry{}
SortEntries(entries)
if len(entries) != 0 {
t.Errorf("expected empty slice, got %d entries", len(entries))
}
}
func TestSortEntriesCaseInsensitive(t *testing.T) {
entries := []FileEntry{
{Name: "Zebra", IsDir: false},
{Name: "alpha", IsDir: false},
}
SortEntries(entries)
if entries[0].Name != "alpha" {
t.Errorf("[0] = %s, want alpha", entries[0].Name)
}
if entries[1].Name != "Zebra" {
t.Errorf("[1] = %s, want Zebra", entries[1].Name)
}
}
func TestMaxEditFileSize(t *testing.T) {
if MaxEditFileSize != 5*1024*1024 {
t.Errorf("MaxEditFileSize = %d, want %d", MaxEditFileSize, 5*1024*1024)
}
}