Merge branch 'worktree-agent-a14ceeb8' into feat/phase2-ssh-sftp
# Conflicts: # go.mod # go.sum
This commit is contained in:
commit
d05639ef4c
2
go.mod
2
go.mod
@ -4,6 +4,7 @@ go 1.26.1
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/pkg/sftp v1.13.10
|
||||||
github.com/wailsapp/wails/v3 v3.0.0-alpha.74
|
github.com/wailsapp/wails/v3 v3.0.0-alpha.74
|
||||||
golang.org/x/crypto v0.49.0
|
golang.org/x/crypto v0.49.0
|
||||||
modernc.org/sqlite v1.46.2
|
modernc.org/sqlite v1.46.2
|
||||||
@ -31,6 +32,7 @@ require (
|
|||||||
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
|
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
|
||||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
github.com/kr/fs v0.1.0 // indirect
|
||||||
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
|
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
|
||||||
github.com/leaanthony/u v1.1.1 // indirect
|
github.com/leaanthony/u v1.1.1 // indirect
|
||||||
github.com/lmittmann/tint v1.1.2 // indirect
|
github.com/lmittmann/tint v1.1.2 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@ -64,6 +64,8 @@ github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PW
|
|||||||
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
|
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||||
|
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
@ -94,6 +96,8 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmd
|
|||||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU=
|
||||||
|
github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
|
|||||||
408
internal/credentials/service.go
Normal file
408
internal/credentials/service.go
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/vault"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Credential represents a stored credential (password or SSH key reference).
|
||||||
|
type Credential struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
Type string `json:"type"` // "password" or "ssh_key"
|
||||||
|
SSHKeyID *int64 `json:"sshKeyId"`
|
||||||
|
CreatedAt string `json:"createdAt"`
|
||||||
|
UpdatedAt string `json:"updatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHKey represents a stored SSH key.
|
||||||
|
type SSHKey struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
KeyType string `json:"keyType"`
|
||||||
|
Fingerprint string `json:"fingerprint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
CreatedAt string `json:"createdAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CredentialService provides CRUD for credentials with vault encryption.
|
||||||
|
type CredentialService struct {
|
||||||
|
db *sql.DB
|
||||||
|
vault *vault.VaultService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCredentialService creates a new CredentialService.
|
||||||
|
func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService {
|
||||||
|
return &CredentialService{db: db, vault: vault}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePassword creates a password credential (password encrypted via vault).
|
||||||
|
func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) {
|
||||||
|
encrypted, err := s.vault.Encrypt(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encrypt password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
`INSERT INTO credentials (name, username, domain, type, encrypted_value)
|
||||||
|
VALUES (?, ?, ?, 'password', ?)`,
|
||||||
|
name, username, domain, encrypted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("insert credential: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get credential id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.getCredential(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSSHKey imports an SSH key (private key encrypted via vault).
|
||||||
|
func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) {
|
||||||
|
// Parse the private key to detect type and extract public key
|
||||||
|
keyType := DetectKeyType(privateKeyPEM)
|
||||||
|
|
||||||
|
// Parse the key to get the public key for fingerprinting.
|
||||||
|
// Try without passphrase first (handles unencrypted keys even when a
|
||||||
|
// passphrase is provided for storage), then fall back to using the
|
||||||
|
// passphrase for encrypted PEM keys.
|
||||||
|
var signer ssh.Signer
|
||||||
|
var err error
|
||||||
|
signer, err = ssh.ParsePrivateKey(privateKeyPEM)
|
||||||
|
if err != nil && passphrase != "" {
|
||||||
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := signer.PublicKey()
|
||||||
|
fingerprint := ssh.FingerprintSHA256(pubKey)
|
||||||
|
publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey))
|
||||||
|
|
||||||
|
// Encrypt private key via vault
|
||||||
|
encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encrypt private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt passphrase via vault (if provided)
|
||||||
|
var encryptedPassphrase sql.NullString
|
||||||
|
if passphrase != "" {
|
||||||
|
ep, err := s.vault.Encrypt(passphrase)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encrypt passphrase: %w", err)
|
||||||
|
}
|
||||||
|
encryptedPassphrase = sql.NullString{String: ep, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
`INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||||
|
name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("insert ssh key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get ssh key id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.getSSHKey(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListCredentials returns all credentials WITHOUT encrypted values.
|
||||||
|
func (s *CredentialService) ListCredentials() ([]Credential, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
|
||||||
|
FROM credentials ORDER BY name`,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list credentials: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var creds []Credential
|
||||||
|
for rows.Next() {
|
||||||
|
var c Credential
|
||||||
|
var username, domain sql.NullString
|
||||||
|
if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan credential: %w", err)
|
||||||
|
}
|
||||||
|
if username.Valid {
|
||||||
|
c.Username = username.String
|
||||||
|
}
|
||||||
|
if domain.Valid {
|
||||||
|
c.Domain = domain.String
|
||||||
|
}
|
||||||
|
creds = append(creds, c)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if creds == nil {
|
||||||
|
creds = []Credential{}
|
||||||
|
}
|
||||||
|
return creds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSSHKeys returns all SSH keys WITHOUT private key data.
|
||||||
|
func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
`SELECT id, name, key_type, fingerprint, public_key, created_at
|
||||||
|
FROM ssh_keys ORDER BY name`,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list ssh keys: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var keys []SSHKey
|
||||||
|
for rows.Next() {
|
||||||
|
var k SSHKey
|
||||||
|
var keyType, fingerprint, publicKey sql.NullString
|
||||||
|
if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan ssh key: %w", err)
|
||||||
|
}
|
||||||
|
if keyType.Valid {
|
||||||
|
k.KeyType = keyType.String
|
||||||
|
}
|
||||||
|
if fingerprint.Valid {
|
||||||
|
k.Fingerprint = fingerprint.String
|
||||||
|
}
|
||||||
|
if publicKey.Valid {
|
||||||
|
k.PublicKey = publicKey.String
|
||||||
|
}
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate ssh keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keys == nil {
|
||||||
|
keys = []SSHKey{}
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptPassword returns the decrypted password for a credential.
|
||||||
|
func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) {
|
||||||
|
var encrypted sql.NullString
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
"SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'",
|
||||||
|
credentialID,
|
||||||
|
).Scan(&encrypted)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get encrypted password: %w", err)
|
||||||
|
}
|
||||||
|
if !encrypted.Valid {
|
||||||
|
return "", fmt.Errorf("no encrypted value for credential %d", credentialID)
|
||||||
|
}
|
||||||
|
|
||||||
|
password, err := s.vault.Decrypt(encrypted.String)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decrypt password: %w", err)
|
||||||
|
}
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptSSHKey returns the decrypted private key + passphrase.
|
||||||
|
func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) {
|
||||||
|
var encryptedKey string
|
||||||
|
var encryptedPassphrase sql.NullString
|
||||||
|
err = s.db.QueryRow(
|
||||||
|
"SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?",
|
||||||
|
sshKeyID,
|
||||||
|
).Scan(&encryptedKey, &encryptedPassphrase)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("get encrypted ssh key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptedKey, err := s.vault.Decrypt(encryptedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("decrypt private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if encryptedPassphrase.Valid {
|
||||||
|
passphrase, err = s.vault.Decrypt(encryptedPassphrase.String)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("decrypt passphrase: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(decryptedKey), passphrase, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteCredential removes a credential.
|
||||||
|
func (s *CredentialService) DeleteCredential(id int64) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete credential: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSSHKey removes an SSH key.
|
||||||
|
func (s *CredentialService) DeleteSSHKey(id int64) error {
|
||||||
|
_, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete ssh key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa).
|
||||||
|
func DetectKeyType(pemData []byte) string {
|
||||||
|
block, _ := pem.Decode(pemData)
|
||||||
|
if block == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks)
|
||||||
|
if block.Type == "OPENSSH PRIVATE KEY" {
|
||||||
|
return detectOpenSSHKeyType(block.Bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try PKCS8 format
|
||||||
|
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
|
||||||
|
switch key.(type) {
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
return "rsa"
|
||||||
|
case ed25519.PrivateKey:
|
||||||
|
return "ed25519"
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
return "ecdsa"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try RSA PKCS1
|
||||||
|
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||||
|
return "rsa"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try EC
|
||||||
|
if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||||
|
return "ecdsa"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type.
|
||||||
|
func detectOpenSSHKeyType(data []byte) string {
|
||||||
|
// OpenSSH private key format: "openssh-key-v1\0" magic, then fields.
|
||||||
|
// We parse the key using ssh package to determine the type.
|
||||||
|
// Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer.
|
||||||
|
pemBlock := &pem.Block{
|
||||||
|
Type: "OPENSSH PRIVATE KEY",
|
||||||
|
Bytes: data,
|
||||||
|
}
|
||||||
|
pemBytes := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
signer, err := ssh.ParsePrivateKey(pemBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
return classifyPublicKey(signer.PublicKey())
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifyPublicKey determines the key type from an ssh.PublicKey.
|
||||||
|
func classifyPublicKey(pub ssh.PublicKey) string {
|
||||||
|
keyType := pub.Type()
|
||||||
|
switch keyType {
|
||||||
|
case "ssh-rsa":
|
||||||
|
return "rsa"
|
||||||
|
case "ssh-ed25519":
|
||||||
|
return "ed25519"
|
||||||
|
case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521":
|
||||||
|
return "ecdsa"
|
||||||
|
default:
|
||||||
|
return keyType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCredential retrieves a single credential by ID.
|
||||||
|
func (s *CredentialService) getCredential(id int64) (*Credential, error) {
|
||||||
|
var c Credential
|
||||||
|
var username, domain sql.NullString
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
|
||||||
|
FROM credentials WHERE id = ?`, id,
|
||||||
|
).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get credential: %w", err)
|
||||||
|
}
|
||||||
|
if username.Valid {
|
||||||
|
c.Username = username.String
|
||||||
|
}
|
||||||
|
if domain.Valid {
|
||||||
|
c.Domain = domain.String
|
||||||
|
}
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSSHKey retrieves a single SSH key by ID (without private key data).
|
||||||
|
func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) {
|
||||||
|
var k SSHKey
|
||||||
|
var keyType, fingerprint, publicKey sql.NullString
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
`SELECT id, name, key_type, fingerprint, public_key, created_at
|
||||||
|
FROM ssh_keys WHERE id = ?`, id,
|
||||||
|
).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get ssh key: %w", err)
|
||||||
|
}
|
||||||
|
if keyType.Valid {
|
||||||
|
k.KeyType = keyType.String
|
||||||
|
}
|
||||||
|
if fingerprint.Valid {
|
||||||
|
k.Fingerprint = fingerprint.String
|
||||||
|
}
|
||||||
|
if publicKey.Valid {
|
||||||
|
k.PublicKey = publicKey.String
|
||||||
|
}
|
||||||
|
return &k, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateFingerprint generates an SSH fingerprint string from a public key.
|
||||||
|
func generateFingerprint(pubKey ssh.PublicKey) string {
|
||||||
|
return ssh.FingerprintSHA256(pubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalPublicKey returns the authorized_keys format of an SSH public key.
|
||||||
|
func marshalPublicKey(pubKey ssh.PublicKey) string {
|
||||||
|
return base64.StdEncoding.EncodeToString(pubKey.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ecdsaCurveName returns the name for an ECDSA curve.
|
||||||
|
func ecdsaCurveName(curve elliptic.Curve) string {
|
||||||
|
switch curve {
|
||||||
|
case elliptic.P256():
|
||||||
|
return "nistp256"
|
||||||
|
case elliptic.P384():
|
||||||
|
return "nistp384"
|
||||||
|
case elliptic.P521():
|
||||||
|
return "nistp521"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
176
internal/credentials/service_test.go
Normal file
176
internal/credentials/service_test.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/pem"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
"github.com/vstockwell/wraith/internal/vault"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupCredentialService(t *testing.T) *CredentialService {
|
||||||
|
t.Helper()
|
||||||
|
d, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(d); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { d.Close() })
|
||||||
|
|
||||||
|
salt := []byte("test-salt-exactly-32-bytes-long!")
|
||||||
|
key := vault.DeriveKey("testpassword", salt)
|
||||||
|
vs := vault.NewVaultService(key)
|
||||||
|
|
||||||
|
return NewCredentialService(d, vs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreatePasswordCredential(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cred.Name != "Test Cred" {
|
||||||
|
t.Error("wrong name")
|
||||||
|
}
|
||||||
|
if cred.Type != "password" {
|
||||||
|
t.Error("wrong type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptPassword(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "")
|
||||||
|
password, err := svc.DecryptPassword(cred.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if password != "mypassword" {
|
||||||
|
t.Errorf("got %q, want mypassword", password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListCredentialsExcludesSecrets(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
svc.CreatePassword("Cred1", "user1", "pass1", "")
|
||||||
|
svc.CreatePassword("Cred2", "user2", "pass2", "")
|
||||||
|
creds, err := svc.ListCredentials()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(creds) != 2 {
|
||||||
|
t.Errorf("got %d, want 2", len(creds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSSHKey(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
// Generate a test key
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
key, err := svc.CreateSSHKey("My Key", keyPEM, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if key.KeyType != "ed25519" {
|
||||||
|
t.Errorf("KeyType = %q, want ed25519", key.KeyType)
|
||||||
|
}
|
||||||
|
if key.Fingerprint == "" {
|
||||||
|
t.Error("fingerprint should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptSSHKey(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass")
|
||||||
|
decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(decryptedKey) == 0 {
|
||||||
|
t.Error("decrypted key should not be empty")
|
||||||
|
}
|
||||||
|
if passphrase != "testpass" {
|
||||||
|
t.Errorf("passphrase = %q, want testpass", passphrase)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectKeyType(t *testing.T) {
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
if got := DetectKeyType(keyPEM); got != "ed25519" {
|
||||||
|
t.Errorf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCredential(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "")
|
||||||
|
err := svc.DeleteCredential(cred.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
creds, _ := svc.ListCredentials()
|
||||||
|
if len(creds) != 0 {
|
||||||
|
t.Errorf("got %d credentials, want 0", len(creds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteSSHKey(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "")
|
||||||
|
err := svc.DeleteSSHKey(key.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
keys, _ := svc.ListSSHKeys()
|
||||||
|
if len(keys) != 0 {
|
||||||
|
t.Errorf("got %d keys, want 0", len(keys))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListSSHKeys(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
_, priv, _ := ed25519.GenerateKey(rand.Reader)
|
||||||
|
pemBlock, _ := ssh.MarshalPrivateKey(priv, "")
|
||||||
|
keyPEM := pem.EncodeToMemory(pemBlock)
|
||||||
|
|
||||||
|
svc.CreateSSHKey("Key1", keyPEM, "")
|
||||||
|
svc.CreateSSHKey("Key2", keyPEM, "")
|
||||||
|
|
||||||
|
keys, err := svc.ListSSHKeys()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(keys) != 2 {
|
||||||
|
t.Errorf("got %d keys, want 2", len(keys))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreatePasswordWithDomain(t *testing.T) {
|
||||||
|
svc := setupCredentialService(t)
|
||||||
|
cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cred.Domain != "example.com" {
|
||||||
|
t.Errorf("domain = %q, want example.com", cred.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
238
internal/sftp/service.go
Normal file
238
internal/sftp/service.go
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
package sftp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkg/sftp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxEditFileSize = 5 * 1024 * 1024 // 5MB
|
||||||
|
|
||||||
|
// FileEntry represents a file or directory in a remote filesystem.
|
||||||
|
type FileEntry struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
IsDir bool `json:"isDir"`
|
||||||
|
Permissions string `json:"permissions"`
|
||||||
|
ModTime string `json:"modTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFTPService manages SFTP clients keyed by session ID.
|
||||||
|
type SFTPService struct {
|
||||||
|
clients map[string]*sftp.Client
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSFTPService creates a new SFTPService.
|
||||||
|
func NewSFTPService() *SFTPService {
|
||||||
|
return &SFTPService{
|
||||||
|
clients: make(map[string]*sftp.Client),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterClient stores an SFTP client for a session.
|
||||||
|
func (s *SFTPService) RegisterClient(sessionID string, client *sftp.Client) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.clients[sessionID] = client
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveClient removes and closes an SFTP client.
|
||||||
|
func (s *SFTPService) RemoveClient(sessionID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
client, ok := s.clients[sessionID]
|
||||||
|
if ok {
|
||||||
|
delete(s.clients, sessionID)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if ok && client != nil {
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClient returns the SFTP client for a session or an error if not found.
|
||||||
|
func (s *SFTPService) getClient(sessionID string) (*sftp.Client, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
client, ok := s.clients[sessionID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no SFTP client for session %s", sessionID)
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SortEntries sorts file entries with directories first, then alphabetically by name.
|
||||||
|
func SortEntries(entries []FileEntry) {
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
if entries[i].IsDir != entries[j].IsDir {
|
||||||
|
return entries[i].IsDir
|
||||||
|
}
|
||||||
|
return strings.ToLower(entries[i].Name) < strings.ToLower(entries[j].Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileInfoToEntry converts an os.FileInfo and its path into a FileEntry.
|
||||||
|
func fileInfoToEntry(info os.FileInfo, path string) FileEntry {
|
||||||
|
return FileEntry{
|
||||||
|
Name: info.Name(),
|
||||||
|
Path: path,
|
||||||
|
Size: info.Size(),
|
||||||
|
IsDir: info.IsDir(),
|
||||||
|
Permissions: info.Mode().Perm().String(),
|
||||||
|
ModTime: info.ModTime().UTC().Format("2006-01-02T15:04:05Z"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns directory contents sorted (dirs first, then files alphabetically).
|
||||||
|
func (s *SFTPService) List(sessionID string, path string) ([]FileEntry, error) {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
infos, err := client.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read directory %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]FileEntry, 0, len(infos))
|
||||||
|
for _, info := range infos {
|
||||||
|
entryPath := path
|
||||||
|
if !strings.HasSuffix(entryPath, "/") {
|
||||||
|
entryPath += "/"
|
||||||
|
}
|
||||||
|
entryPath += info.Name()
|
||||||
|
entries = append(entries, fileInfoToEntry(info, entryPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
SortEntries(entries)
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFile reads a file (max 5MB). Returns content as string.
|
||||||
|
func (s *SFTPService) ReadFile(sessionID string, path string) (string, error) {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := client.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("stat %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
return "", fmt.Errorf("%s is a directory", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Size() > MaxEditFileSize {
|
||||||
|
return "", fmt.Errorf("file %s is %d bytes, exceeds max edit size of %d bytes", path, info.Size(), MaxEditFileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := client.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("open %s: %w", path, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFile writes content to a file.
|
||||||
|
func (s *SFTPService) WriteFile(sessionID string, path string, content string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := client.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create %s: %w", path, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := f.Write([]byte(content)); err != nil {
|
||||||
|
return fmt.Errorf("write %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mkdir creates a directory.
|
||||||
|
func (s *SFTPService) Mkdir(sessionID string, path string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Mkdir(path); err != nil {
|
||||||
|
return fmt.Errorf("mkdir %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a file or empty directory.
|
||||||
|
func (s *SFTPService) Delete(sessionID string, path string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := client.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stat %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
if err := client.RemoveDirectory(path); err != nil {
|
||||||
|
return fmt.Errorf("remove directory %s: %w", path, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := client.Remove(path); err != nil {
|
||||||
|
return fmt.Errorf("remove %s: %w", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rename renames/moves a file.
|
||||||
|
func (s *SFTPService) Rename(sessionID string, oldPath, newPath string) error {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Rename(oldPath, newPath); err != nil {
|
||||||
|
return fmt.Errorf("rename %s to %s: %w", oldPath, newPath, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stat returns info about a file/directory.
|
||||||
|
func (s *SFTPService) Stat(sessionID string, path string) (*FileEntry, error) {
|
||||||
|
client, err := s.getClient(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := client.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("stat %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := fileInfoToEntry(info, path)
|
||||||
|
return &entry, nil
|
||||||
|
}
|
||||||
119
internal/sftp/service_test.go
Normal file
119
internal/sftp/service_test.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package sftp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSFTPService(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
if svc == nil {
|
||||||
|
t.Fatal("nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
_, err := svc.List("nonexistent", "/")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFileWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
_, err := svc.ReadFile("nonexistent", "/etc/hosts")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteFileWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.WriteFile("nonexistent", "/tmp/test", "data")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMkdirWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.Mkdir("nonexistent", "/tmp/newdir")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.Delete("nonexistent", "/tmp/file")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenameWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
err := svc.Rename("nonexistent", "/old", "/new")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatWithoutClient(t *testing.T) {
|
||||||
|
svc := NewSFTPService()
|
||||||
|
_, err := svc.Stat("nonexistent", "/tmp")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should error without client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileEntrySorting(t *testing.T) {
|
||||||
|
// Test that SortEntries puts dirs first, then alpha
|
||||||
|
entries := []FileEntry{
|
||||||
|
{Name: "zebra.txt", IsDir: false},
|
||||||
|
{Name: "alpha", IsDir: true},
|
||||||
|
{Name: "beta.conf", IsDir: false},
|
||||||
|
{Name: "omega", IsDir: true},
|
||||||
|
}
|
||||||
|
SortEntries(entries)
|
||||||
|
if entries[0].Name != "alpha" {
|
||||||
|
t.Errorf("[0] = %s, want alpha", entries[0].Name)
|
||||||
|
}
|
||||||
|
if entries[1].Name != "omega" {
|
||||||
|
t.Errorf("[1] = %s, want omega", entries[1].Name)
|
||||||
|
}
|
||||||
|
if entries[2].Name != "beta.conf" {
|
||||||
|
t.Errorf("[2] = %s, want beta.conf", entries[2].Name)
|
||||||
|
}
|
||||||
|
if entries[3].Name != "zebra.txt" {
|
||||||
|
t.Errorf("[3] = %s, want zebra.txt", entries[3].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortEntriesEmpty(t *testing.T) {
|
||||||
|
entries := []FileEntry{}
|
||||||
|
SortEntries(entries)
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Errorf("expected empty slice, got %d entries", len(entries))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSortEntriesCaseInsensitive(t *testing.T) {
|
||||||
|
entries := []FileEntry{
|
||||||
|
{Name: "Zebra", IsDir: false},
|
||||||
|
{Name: "alpha", IsDir: false},
|
||||||
|
}
|
||||||
|
SortEntries(entries)
|
||||||
|
if entries[0].Name != "alpha" {
|
||||||
|
t.Errorf("[0] = %s, want alpha", entries[0].Name)
|
||||||
|
}
|
||||||
|
if entries[1].Name != "Zebra" {
|
||||||
|
t.Errorf("[1] = %s, want Zebra", entries[1].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxEditFileSize(t *testing.T) {
|
||||||
|
if MaxEditFileSize != 5*1024*1024 {
|
||||||
|
t.Errorf("MaxEditFileSize = %d, want %d", MaxEditFileSize, 5*1024*1024)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user