Merge branch 'worktree-agent-a14ceeb8' into feat/phase2-ssh-sftp
# Conflicts: # go.mod # go.sum
This commit is contained in:
commit
d05639ef4c
2
go.mod
2
go.mod
@ -4,6 +4,7 @@ go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/pkg/sftp v1.13.10
|
||||
github.com/wailsapp/wails/v3 v3.0.0-alpha.74
|
||||
golang.org/x/crypto v0.49.0
|
||||
modernc.org/sqlite v1.46.2
|
||||
@ -31,6 +32,7 @@ require (
|
||||
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
|
||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
|
||||
github.com/leaanthony/u v1.1.1 // indirect
|
||||
github.com/lmittmann/tint v1.1.2 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@ -64,6 +64,8 @@ github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PW
|
||||
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
@ -94,6 +96,8 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmd
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
|
||||
github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
|
||||
408
internal/credentials/service.go
Normal file
408
internal/credentials/service.go
Normal file
@ -0,0 +1,408 @@
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
||||
"github.com/vstockwell/wraith/internal/vault"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Credential represents a stored credential (password or SSH key reference).
|
||||
type Credential struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Domain string `json:"domain"`
|
||||
Type string `json:"type"` // "password" or "ssh_key"
|
||||
SSHKeyID *int64 `json:"sshKeyId"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// SSHKey represents a stored SSH key.
|
||||
type SSHKey struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
KeyType string `json:"keyType"`
|
||||
Fingerprint string `json:"fingerprint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}
|
||||
|
||||
// CredentialService provides CRUD for credentials with vault encryption.
|
||||
type CredentialService struct {
|
||||
db *sql.DB
|
||||
vault *vault.VaultService
|
||||
}
|
||||
|
||||
// NewCredentialService creates a new CredentialService.
|
||||
func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService {
|
||||
return &CredentialService{db: db, vault: vault}
|
||||
}
|
||||
|
||||
// CreatePassword creates a password credential (password encrypted via vault).
|
||||
func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) {
|
||||
encrypted, err := s.vault.Encrypt(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt password: %w", err)
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO credentials (name, username, domain, type, encrypted_value)
|
||||
VALUES (?, ?, ?, 'password', ?)`,
|
||||
name, username, domain, encrypted,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert credential: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get credential id: %w", err)
|
||||
}
|
||||
|
||||
return s.getCredential(id)
|
||||
}
|
||||
|
||||
// CreateSSHKey imports an SSH key (private key encrypted via vault).
|
||||
func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) {
|
||||
// Parse the private key to detect type and extract public key
|
||||
keyType := DetectKeyType(privateKeyPEM)
|
||||
|
||||
// Parse the key to get the public key for fingerprinting.
|
||||
// Try without passphrase first (handles unencrypted keys even when a
|
||||
// passphrase is provided for storage), then fall back to using the
|
||||
// passphrase for encrypted PEM keys.
|
||||
var signer ssh.Signer
|
||||
var err error
|
||||
signer, err = ssh.ParsePrivateKey(privateKeyPEM)
|
||||
if err != nil && passphrase != "" {
|
||||
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
pubKey := signer.PublicKey()
|
||||
fingerprint := ssh.FingerprintSHA256(pubKey)
|
||||
publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey))
|
||||
|
||||
// Encrypt private key via vault
|
||||
encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt private key: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt passphrase via vault (if provided)
|
||||
var encryptedPassphrase sql.NullString
|
||||
if passphrase != "" {
|
||||
ep, err := s.vault.Encrypt(passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt passphrase: %w", err)
|
||||
}
|
||||
encryptedPassphrase = sql.NullString{String: ep, Valid: true}
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert ssh key: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ssh key id: %w", err)
|
||||
}
|
||||
|
||||
return s.getSSHKey(id)
|
||||
}
|
||||
|
||||
// ListCredentials returns all credentials WITHOUT encrypted values.
|
||||
func (s *CredentialService) ListCredentials() ([]Credential, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
|
||||
FROM credentials ORDER BY name`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list credentials: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var creds []Credential
|
||||
for rows.Next() {
|
||||
var c Credential
|
||||
var username, domain sql.NullString
|
||||
if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan credential: %w", err)
|
||||
}
|
||||
if username.Valid {
|
||||
c.Username = username.String
|
||||
}
|
||||
if domain.Valid {
|
||||
c.Domain = domain.String
|
||||
}
|
||||
creds = append(creds, c)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate credentials: %w", err)
|
||||
}
|
||||
|
||||
if creds == nil {
|
||||
creds = []Credential{}
|
||||
}
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// ListSSHKeys returns all SSH keys WITHOUT private key data.
|
||||
func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, name, key_type, fingerprint, public_key, created_at
|
||||
FROM ssh_keys ORDER BY name`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list ssh keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []SSHKey
|
||||
for rows.Next() {
|
||||
var k SSHKey
|
||||
var keyType, fingerprint, publicKey sql.NullString
|
||||
if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan ssh key: %w", err)
|
||||
}
|
||||
if keyType.Valid {
|
||||
k.KeyType = keyType.String
|
||||
}
|
||||
if fingerprint.Valid {
|
||||
k.Fingerprint = fingerprint.String
|
||||
}
|
||||
if publicKey.Valid {
|
||||
k.PublicKey = publicKey.String
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate ssh keys: %w", err)
|
||||
}
|
||||
|
||||
if keys == nil {
|
||||
keys = []SSHKey{}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// DecryptPassword returns the decrypted password for a credential.
|
||||
func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) {
|
||||
var encrypted sql.NullString
|
||||
err := s.db.QueryRow(
|
||||
"SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'",
|
||||
credentialID,
|
||||
).Scan(&encrypted)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get encrypted password: %w", err)
|
||||
}
|
||||
if !encrypted.Valid {
|
||||
return "", fmt.Errorf("no encrypted value for credential %d", credentialID)
|
||||
}
|
||||
|
||||
password, err := s.vault.Decrypt(encrypted.String)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt password: %w", err)
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
|
||||
// DecryptSSHKey returns the decrypted private key + passphrase.
|
||||
func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) {
|
||||
var encryptedKey string
|
||||
var encryptedPassphrase sql.NullString
|
||||
err = s.db.QueryRow(
|
||||
"SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?",
|
||||
sshKeyID,
|
||||
).Scan(&encryptedKey, &encryptedPassphrase)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("get encrypted ssh key: %w", err)
|
||||
}
|
||||
|
||||
decryptedKey, err := s.vault.Decrypt(encryptedKey)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("decrypt private key: %w", err)
|
||||
}
|
||||
|
||||
if encryptedPassphrase.Valid {
|
||||
passphrase, err = s.vault.Decrypt(encryptedPassphrase.String)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("decrypt passphrase: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(decryptedKey), passphrase, nil
|
||||
}
|
||||
|
||||
// DeleteCredential removes a credential.
|
||||
func (s *CredentialService) DeleteCredential(id int64) error {
|
||||
_, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete credential: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSSHKey removes an SSH key.
|
||||
func (s *CredentialService) DeleteSSHKey(id int64) error {
|
||||
_, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete ssh key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa).
|
||||
func DetectKeyType(pemData []byte) string {
|
||||
block, _ := pem.Decode(pemData)
|
||||
if block == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks)
|
||||
if block.Type == "OPENSSH PRIVATE KEY" {
|
||||
return detectOpenSSHKeyType(block.Bytes)
|
||||
}
|
||||
|
||||
// Try PKCS8 format
|
||||
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
|
||||
switch key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return "rsa"
|
||||
case ed25519.PrivateKey:
|
||||
return "ed25519"
|
||||
case *ecdsa.PrivateKey:
|
||||
return "ecdsa"
|
||||
}
|
||||
}
|
||||
|
||||
// Try RSA PKCS1
|
||||
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
return "rsa"
|
||||
}
|
||||
|
||||
// Try EC
|
||||
if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||
return "ecdsa"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type.
|
||||
func detectOpenSSHKeyType(data []byte) string {
|
||||
// OpenSSH private key format: "openssh-key-v1\0" magic, then fields.
|
||||
// We parse the key using ssh package to determine the type.
|
||||
// Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer.
|
||||
pemBlock := &pem.Block{
|
||||
Type: "OPENSSH PRIVATE KEY",
|
||||
Bytes: data,
|
||||
}
|
||||
pemBytes := pem.EncodeToMemory(pemBlock)
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
return classifyPublicKey(signer.PublicKey())
|
||||
}
|
||||
|
||||
// classifyPublicKey determines the key type from an ssh.PublicKey.
|
||||
func classifyPublicKey(pub ssh.PublicKey) string {
|
||||
keyType := pub.Type()
|
||||
switch keyType {
|
||||
case "ssh-rsa":
|
||||
return "rsa"
|
||||
case "ssh-ed25519":
|
||||
return "ed25519"
|
||||
case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521":
|
||||
return "ecdsa"
|
||||
default:
|
||||
return keyType
|
||||
}
|
||||
}
|
||||
|
||||
// getCredential retrieves a single credential by ID.
|
||||
func (s *CredentialService) getCredential(id int64) (*Credential, error) {
|
||||
var c Credential
|
||||
var username, domain sql.NullString
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
|
||||
FROM credentials WHERE id = ?`, id,
|
||||
).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get credential: %w", err)
|
||||
}
|
||||
if username.Valid {
|
||||
c.Username = username.String
|
||||
}
|
||||
if domain.Valid {
|
||||
c.Domain = domain.String
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// getSSHKey retrieves a single SSH key by ID (without private key data).
|
||||
func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) {
|
||||
var k SSHKey
|
||||
var keyType, fingerprint, publicKey sql.NullString
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, name, key_type, fingerprint, public_key, created_at
|
||||
FROM ssh_keys WHERE id = ?`, id,
|
||||
).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ssh key: %w", err)
|
||||
}
|
||||
if keyType.Valid {
|
||||
k.KeyType = keyType.String
|
||||
}
|
||||
if fingerprint.Valid {
|
||||
k.Fingerprint = fingerprint.String
|
||||
}
|
||||
if publicKey.Valid {
|
||||
k.PublicKey = publicKey.String
|
||||
}
|
||||
return &k, nil
|
||||
}
|
||||
|
||||
// generateFingerprint generates an SSH fingerprint string from a public key.
|
||||
func generateFingerprint(pubKey ssh.PublicKey) string {
|
||||
return ssh.FingerprintSHA256(pubKey)
|
||||
}
|
||||
|
||||
// marshalPublicKey returns the authorized_keys format of an SSH public key.
|
||||
func marshalPublicKey(pubKey ssh.PublicKey) string {
|
||||
return base64.StdEncoding.EncodeToString(pubKey.Marshal())
|
||||
}
|
||||
|
||||
// ecdsaCurveName returns the name for an ECDSA curve.
|
||||
func ecdsaCurveName(curve elliptic.Curve) string {
|
||||
switch curve {
|
||||
case elliptic.P256():
|
||||
return "nistp256"
|
||||
case elliptic.P384():
|
||||
return "nistp384"
|
||||
case elliptic.P521():
|
||||
return "nistp521"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
176
internal/credentials/service_test.go
Normal file
176
internal/credentials/service_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/vstockwell/wraith/internal/db"
|
||||
"github.com/vstockwell/wraith/internal/vault"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func setupCredentialService(t *testing.T) *CredentialService {
|
||||
t.Helper()
|
||||
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.Migrate(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { d.Close() })
|
||||
|
||||
salt := []byte("test-salt-exactly-32-bytes-long!")
|
||||
key := vault.DeriveKey("testpassword", salt)
|
||||
vs := vault.NewVaultService(key)
|
||||
|
||||
return NewCredentialService(d, vs)
|
||||
}
|
||||
|
||||
func TestCreatePasswordCredential(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cred.Name != "Test Cred" {
|
||||
t.Error("wrong name")
|
||||
}
|
||||
if cred.Type != "password" {
|
||||
t.Error("wrong type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptPassword(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "")
|
||||
password, err := svc.DecryptPassword(cred.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if password != "mypassword" {
|
||||
t.Errorf("got %q, want mypassword", password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCredentialsExcludesSecrets(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
svc.CreatePassword("Cred1", "user1", "pass1", "")
|
||||
svc.CreatePassword("Cred2", "user2", "pass2", "")
|
||||
creds, err := svc.ListCredentials()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(creds) != 2 {
|
||||
t.Errorf("got %d, want 2", len(creds))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSSHKey(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
// Generate a test key
|
||||
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||
|
||||
key, err := svc.CreateSSHKey("My Key", keyPEM, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if key.KeyType != "ed25519" {
|
||||
t.Errorf("KeyType = %q, want ed25519", key.KeyType)
|
||||
}
|
||||
if key.Fingerprint == "" {
|
||||
t.Error("fingerprint should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptSSHKey(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||
|
||||
key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass")
|
||||
decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(decryptedKey) == 0 {
|
||||
t.Error("decrypted key should not be empty")
|
||||
}
|
||||
if passphrase != "testpass" {
|
||||
t.Errorf("passphrase = %q, want testpass", passphrase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectKeyType(t *testing.T) {
|
||||
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||
if got := DetectKeyType(keyPEM); got != "ed25519" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteCredential(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "")
|
||||
err := svc.DeleteCredential(cred.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
creds, _ := svc.ListCredentials()
|
||||
if len(creds) != 0 {
|
||||
t.Errorf("got %d credentials, want 0", len(creds))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSSHKey(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||
|
||||
key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "")
|
||||
err := svc.DeleteSSHKey(key.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keys, _ := svc.ListSSHKeys()
|
||||
if len(keys) != 0 {
|
||||
t.Errorf("got %d keys, want 0", len(keys))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSSHKeys(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||
|
||||
svc.CreateSSHKey("Key1", keyPEM, "")
|
||||
svc.CreateSSHKey("Key2", keyPEM, "")
|
||||
|
||||
keys, err := svc.ListSSHKeys()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(keys) != 2 {
|
||||
t.Errorf("got %d keys, want 2", len(keys))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePasswordWithDomain(t *testing.T) {
|
||||
svc := setupCredentialService(t)
|
||||
cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cred.Domain != "example.com" {
|
||||
t.Errorf("domain = %q, want example.com", cred.Domain)
|
||||
}
|
||||
}
|
||||
238
internal/sftp/service.go
Normal file
238
internal/sftp/service.go
Normal file
@ -0,0 +1,238 @@
|
||||
package sftp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
)
|
||||
|
||||
const MaxEditFileSize = 5 * 1024 * 1024 // 5MB
|
||||
|
||||
// FileEntry represents a file or directory in a remote filesystem.
|
||||
type FileEntry struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Size int64 `json:"size"`
|
||||
IsDir bool `json:"isDir"`
|
||||
Permissions string `json:"permissions"`
|
||||
ModTime string `json:"modTime"`
|
||||
}
|
||||
|
||||
// SFTPService manages SFTP clients keyed by session ID.
|
||||
type SFTPService struct {
|
||||
clients map[string]*sftp.Client
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSFTPService creates a new SFTPService.
|
||||
func NewSFTPService() *SFTPService {
|
||||
return &SFTPService{
|
||||
clients: make(map[string]*sftp.Client),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClient stores an SFTP client for a session.
|
||||
func (s *SFTPService) RegisterClient(sessionID string, client *sftp.Client) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.clients[sessionID] = client
|
||||
}
|
||||
|
||||
// RemoveClient removes and closes an SFTP client.
|
||||
func (s *SFTPService) RemoveClient(sessionID string) {
|
||||
s.mu.Lock()
|
||||
client, ok := s.clients[sessionID]
|
||||
if ok {
|
||||
delete(s.clients, sessionID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if ok && client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// getClient returns the SFTP client for a session or an error if not found.
|
||||
func (s *SFTPService) getClient(sessionID string) (*sftp.Client, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
client, ok := s.clients[sessionID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no SFTP client for session %s", sessionID)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// SortEntries sorts file entries with directories first, then alphabetically by name.
|
||||
func SortEntries(entries []FileEntry) {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
if entries[i].IsDir != entries[j].IsDir {
|
||||
return entries[i].IsDir
|
||||
}
|
||||
return strings.ToLower(entries[i].Name) < strings.ToLower(entries[j].Name)
|
||||
})
|
||||
}
|
||||
|
||||
// fileInfoToEntry converts an os.FileInfo and its path into a FileEntry.
|
||||
func fileInfoToEntry(info os.FileInfo, path string) FileEntry {
|
||||
return FileEntry{
|
||||
Name: info.Name(),
|
||||
Path: path,
|
||||
Size: info.Size(),
|
||||
IsDir: info.IsDir(),
|
||||
Permissions: info.Mode().Perm().String(),
|
||||
ModTime: info.ModTime().UTC().Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
}
|
||||
|
||||
// List returns directory contents sorted (dirs first, then files alphabetically).
|
||||
func (s *SFTPService) List(sessionID string, path string) ([]FileEntry, error) {
|
||||
client, err := s.getClient(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
infos, err := client.ReadDir(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read directory %s: %w", path, err)
|
||||
}
|
||||
|
||||
entries := make([]FileEntry, 0, len(infos))
|
||||
for _, info := range infos {
|
||||
entryPath := path
|
||||
if !strings.HasSuffix(entryPath, "/") {
|
||||
entryPath += "/"
|
||||
}
|
||||
entryPath += info.Name()
|
||||
entries = append(entries, fileInfoToEntry(info, entryPath))
|
||||
}
|
||||
|
||||
SortEntries(entries)
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// ReadFile reads a file (max 5MB). Returns content as string.
|
||||
func (s *SFTPService) ReadFile(sessionID string, path string) (string, error) {
|
||||
client, err := s.getClient(sessionID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
info, err := client.Stat(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stat %s: %w", path, err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return "", fmt.Errorf("%s is a directory", path)
|
||||
}
|
||||
|
||||
if info.Size() > MaxEditFileSize {
|
||||
return "", fmt.Errorf("file %s is %d bytes, exceeds max edit size of %d bytes", path, info.Size(), MaxEditFileSize)
|
||||
}
|
||||
|
||||
f, err := client.Open(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("open %s: %w", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read %s: %w", path, err)
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// WriteFile writes content to a file.
|
||||
func (s *SFTPService) WriteFile(sessionID string, path string, content string) error {
|
||||
client, err := s.getClient(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := client.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create %s: %w", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.Write([]byte(content)); err != nil {
|
||||
return fmt.Errorf("write %s: %w", path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mkdir creates a directory.
|
||||
func (s *SFTPService) Mkdir(sessionID string, path string) error {
|
||||
client, err := s.getClient(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := client.Mkdir(path); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file or empty directory.
|
||||
func (s *SFTPService) Delete(sessionID string, path string) error {
|
||||
client, err := s.getClient(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := client.Stat(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat %s: %w", path, err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
if err := client.RemoveDirectory(path); err != nil {
|
||||
return fmt.Errorf("remove directory %s: %w", path, err)
|
||||
}
|
||||
} else {
|
||||
if err := client.Remove(path); err != nil {
|
||||
return fmt.Errorf("remove %s: %w", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rename renames/moves a file.
|
||||
func (s *SFTPService) Rename(sessionID string, oldPath, newPath string) error {
|
||||
client, err := s.getClient(sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := client.Rename(oldPath, newPath); err != nil {
|
||||
return fmt.Errorf("rename %s to %s: %w", oldPath, newPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stat returns info about a file/directory.
|
||||
func (s *SFTPService) Stat(sessionID string, path string) (*FileEntry, error) {
|
||||
client, err := s.getClient(sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info, err := client.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat %s: %w", path, err)
|
||||
}
|
||||
|
||||
entry := fileInfoToEntry(info, path)
|
||||
return &entry, nil
|
||||
}
|
||||
119
internal/sftp/service_test.go
Normal file
119
internal/sftp/service_test.go
Normal file
@ -0,0 +1,119 @@
|
||||
package sftp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewSFTPService(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
if svc == nil {
|
||||
t.Fatal("nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListWithoutClient(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
_, err := svc.List("nonexistent", "/")
|
||||
if err == nil {
|
||||
t.Error("should error without client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileWithoutClient(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
_, err := svc.ReadFile("nonexistent", "/etc/hosts")
|
||||
if err == nil {
|
||||
t.Error("should error without client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFileWithoutClient(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
err := svc.WriteFile("nonexistent", "/tmp/test", "data")
|
||||
if err == nil {
|
||||
t.Error("should error without client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMkdirWithoutClient(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
err := svc.Mkdir("nonexistent", "/tmp/newdir")
|
||||
if err == nil {
|
||||
t.Error("should error without client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteWithoutClient(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
err := svc.Delete("nonexistent", "/tmp/file")
|
||||
if err == nil {
|
||||
t.Error("should error without client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameWithoutClient(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
err := svc.Rename("nonexistent", "/old", "/new")
|
||||
if err == nil {
|
||||
t.Error("should error without client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatWithoutClient(t *testing.T) {
|
||||
svc := NewSFTPService()
|
||||
_, err := svc.Stat("nonexistent", "/tmp")
|
||||
if err == nil {
|
||||
t.Error("should error without client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileEntrySorting(t *testing.T) {
|
||||
// Test that SortEntries puts dirs first, then alpha
|
||||
entries := []FileEntry{
|
||||
{Name: "zebra.txt", IsDir: false},
|
||||
{Name: "alpha", IsDir: true},
|
||||
{Name: "beta.conf", IsDir: false},
|
||||
{Name: "omega", IsDir: true},
|
||||
}
|
||||
SortEntries(entries)
|
||||
if entries[0].Name != "alpha" {
|
||||
t.Errorf("[0] = %s, want alpha", entries[0].Name)
|
||||
}
|
||||
if entries[1].Name != "omega" {
|
||||
t.Errorf("[1] = %s, want omega", entries[1].Name)
|
||||
}
|
||||
if entries[2].Name != "beta.conf" {
|
||||
t.Errorf("[2] = %s, want beta.conf", entries[2].Name)
|
||||
}
|
||||
if entries[3].Name != "zebra.txt" {
|
||||
t.Errorf("[3] = %s, want zebra.txt", entries[3].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortEntriesEmpty(t *testing.T) {
|
||||
entries := []FileEntry{}
|
||||
SortEntries(entries)
|
||||
if len(entries) != 0 {
|
||||
t.Errorf("expected empty slice, got %d entries", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortEntriesCaseInsensitive(t *testing.T) {
|
||||
entries := []FileEntry{
|
||||
{Name: "Zebra", IsDir: false},
|
||||
{Name: "alpha", IsDir: false},
|
||||
}
|
||||
SortEntries(entries)
|
||||
if entries[0].Name != "alpha" {
|
||||
t.Errorf("[0] = %s, want alpha", entries[0].Name)
|
||||
}
|
||||
if entries[1].Name != "Zebra" {
|
||||
t.Errorf("[1] = %s, want Zebra", entries[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxEditFileSize(t *testing.T) {
|
||||
if MaxEditFileSize != 5*1024*1024 {
|
||||
t.Errorf("MaxEditFileSize = %d, want %d", MaxEditFileSize, 5*1024*1024)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user