package credentials import ( "crypto/ed25519" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" "crypto/x509" "database/sql" "encoding/base64" "encoding/pem" "fmt" "github.com/vstockwell/wraith/internal/vault" "golang.org/x/crypto/ssh" ) // Credential represents a stored credential (password or SSH key reference). type Credential struct { ID int64 `json:"id"` Name string `json:"name"` Username string `json:"username"` Domain string `json:"domain"` Type string `json:"type"` // "password" or "ssh_key" SSHKeyID *int64 `json:"sshKeyId"` CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` } // SSHKey represents a stored SSH key. type SSHKey struct { ID int64 `json:"id"` Name string `json:"name"` KeyType string `json:"keyType"` Fingerprint string `json:"fingerprint"` PublicKey string `json:"publicKey"` CreatedAt string `json:"createdAt"` } // CredentialService provides CRUD for credentials with vault encryption. type CredentialService struct { db *sql.DB vault *vault.VaultService } // NewCredentialService creates a new CredentialService. func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService { return &CredentialService{db: db, vault: vault} } // CreatePassword creates a password credential (password encrypted via vault). func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) { encrypted, err := s.vault.Encrypt(password) if err != nil { return nil, fmt.Errorf("encrypt password: %w", err) } result, err := s.db.Exec( `INSERT INTO credentials (name, username, domain, type, encrypted_value) VALUES (?, ?, ?, 'password', ?)`, name, username, domain, encrypted, ) if err != nil { return nil, fmt.Errorf("insert credential: %w", err) } id, err := result.LastInsertId() if err != nil { return nil, fmt.Errorf("get credential id: %w", err) } return s.getCredential(id) } // CreateSSHKey imports an SSH key (private key encrypted via vault). func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) { // Parse the private key to detect type and extract public key keyType := DetectKeyType(privateKeyPEM) // Parse the key to get the public key for fingerprinting. // Try without passphrase first (handles unencrypted keys even when a // passphrase is provided for storage), then fall back to using the // passphrase for encrypted PEM keys. var signer ssh.Signer var err error signer, err = ssh.ParsePrivateKey(privateKeyPEM) if err != nil && passphrase != "" { signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase)) } if err != nil { return nil, fmt.Errorf("parse private key: %w", err) } pubKey := signer.PublicKey() fingerprint := ssh.FingerprintSHA256(pubKey) publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey)) // Encrypt private key via vault encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM)) if err != nil { return nil, fmt.Errorf("encrypt private key: %w", err) } // Encrypt passphrase via vault (if provided) var encryptedPassphrase sql.NullString if passphrase != "" { ep, err := s.vault.Encrypt(passphrase) if err != nil { return nil, fmt.Errorf("encrypt passphrase: %w", err) } encryptedPassphrase = sql.NullString{String: ep, Valid: true} } result, err := s.db.Exec( `INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted) VALUES (?, ?, ?, ?, ?, ?)`, name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase, ) if err != nil { return nil, fmt.Errorf("insert ssh key: %w", err) } id, err := result.LastInsertId() if err != nil { return nil, fmt.Errorf("get ssh key id: %w", err) } return s.getSSHKey(id) } // ListCredentials returns all credentials WITHOUT encrypted values. func (s *CredentialService) ListCredentials() ([]Credential, error) { rows, err := s.db.Query( `SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at FROM credentials ORDER BY name`, ) if err != nil { return nil, fmt.Errorf("list credentials: %w", err) } defer rows.Close() var creds []Credential for rows.Next() { var c Credential var username, domain sql.NullString if err := rows.Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt); err != nil { return nil, fmt.Errorf("scan credential: %w", err) } if username.Valid { c.Username = username.String } if domain.Valid { c.Domain = domain.String } creds = append(creds, c) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate credentials: %w", err) } if creds == nil { creds = []Credential{} } return creds, nil } // ListSSHKeys returns all SSH keys WITHOUT private key data. func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) { rows, err := s.db.Query( `SELECT id, name, key_type, fingerprint, public_key, created_at FROM ssh_keys ORDER BY name`, ) if err != nil { return nil, fmt.Errorf("list ssh keys: %w", err) } defer rows.Close() var keys []SSHKey for rows.Next() { var k SSHKey var keyType, fingerprint, publicKey sql.NullString if err := rows.Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt); err != nil { return nil, fmt.Errorf("scan ssh key: %w", err) } if keyType.Valid { k.KeyType = keyType.String } if fingerprint.Valid { k.Fingerprint = fingerprint.String } if publicKey.Valid { k.PublicKey = publicKey.String } keys = append(keys, k) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate ssh keys: %w", err) } if keys == nil { keys = []SSHKey{} } return keys, nil } // DecryptPassword returns the decrypted password for a credential. func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) { var encrypted sql.NullString err := s.db.QueryRow( "SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'", credentialID, ).Scan(&encrypted) if err != nil { return "", fmt.Errorf("get encrypted password: %w", err) } if !encrypted.Valid { return "", fmt.Errorf("no encrypted value for credential %d", credentialID) } password, err := s.vault.Decrypt(encrypted.String) if err != nil { return "", fmt.Errorf("decrypt password: %w", err) } return password, nil } // DecryptSSHKey returns the decrypted private key + passphrase. func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) { var encryptedKey string var encryptedPassphrase sql.NullString err = s.db.QueryRow( "SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?", sshKeyID, ).Scan(&encryptedKey, &encryptedPassphrase) if err != nil { return nil, "", fmt.Errorf("get encrypted ssh key: %w", err) } decryptedKey, err := s.vault.Decrypt(encryptedKey) if err != nil { return nil, "", fmt.Errorf("decrypt private key: %w", err) } if encryptedPassphrase.Valid { passphrase, err = s.vault.Decrypt(encryptedPassphrase.String) if err != nil { return nil, "", fmt.Errorf("decrypt passphrase: %w", err) } } return []byte(decryptedKey), passphrase, nil } // DeleteCredential removes a credential. func (s *CredentialService) DeleteCredential(id int64) error { _, err := s.db.Exec("DELETE FROM credentials WHERE id = ?", id) if err != nil { return fmt.Errorf("delete credential: %w", err) } return nil } // DeleteSSHKey removes an SSH key. func (s *CredentialService) DeleteSSHKey(id int64) error { _, err := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id) if err != nil { return fmt.Errorf("delete ssh key: %w", err) } return nil } // DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa). func DetectKeyType(pemData []byte) string { block, _ := pem.Decode(pemData) if block == nil { return "unknown" } // Try OpenSSH format first (ssh.MarshalPrivateKey produces OPENSSH PRIVATE KEY blocks) if block.Type == "OPENSSH PRIVATE KEY" { return detectOpenSSHKeyType(block.Bytes) } // Try PKCS8 format if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { switch key.(type) { case *rsa.PrivateKey: return "rsa" case ed25519.PrivateKey: return "ed25519" case *ecdsa.PrivateKey: return "ecdsa" } } // Try RSA PKCS1 if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { return "rsa" } // Try EC if _, err := x509.ParseECPrivateKey(block.Bytes); err == nil { return "ecdsa" } return "unknown" } // detectOpenSSHKeyType parses the OpenSSH private key format to determine key type. func detectOpenSSHKeyType(data []byte) string { // OpenSSH private key format: "openssh-key-v1\0" magic, then fields. // We parse the key using ssh package to determine the type. // Re-encode to PEM to use ssh.ParsePrivateKey which gives us the signer. pemBlock := &pem.Block{ Type: "OPENSSH PRIVATE KEY", Bytes: data, } pemBytes := pem.EncodeToMemory(pemBlock) signer, err := ssh.ParsePrivateKey(pemBytes) if err != nil { return "unknown" } return classifyPublicKey(signer.PublicKey()) } // classifyPublicKey determines the key type from an ssh.PublicKey. func classifyPublicKey(pub ssh.PublicKey) string { keyType := pub.Type() switch keyType { case "ssh-rsa": return "rsa" case "ssh-ed25519": return "ed25519" case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521": return "ecdsa" default: return keyType } } // getCredential retrieves a single credential by ID. // GetCredential retrieves a single credential by ID. func (s *CredentialService) GetCredential(id int64) (*Credential, error) { return s.getCredential(id) } func (s *CredentialService) getCredential(id int64) (*Credential, error) { var c Credential var username, domain sql.NullString err := s.db.QueryRow( `SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at FROM credentials WHERE id = ?`, id, ).Scan(&c.ID, &c.Name, &username, &domain, &c.Type, &c.SSHKeyID, &c.CreatedAt, &c.UpdatedAt) if err != nil { return nil, fmt.Errorf("get credential: %w", err) } if username.Valid { c.Username = username.String } if domain.Valid { c.Domain = domain.String } return &c, nil } // getSSHKey retrieves a single SSH key by ID (without private key data). func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) { var k SSHKey var keyType, fingerprint, publicKey sql.NullString err := s.db.QueryRow( `SELECT id, name, key_type, fingerprint, public_key, created_at FROM ssh_keys WHERE id = ?`, id, ).Scan(&k.ID, &k.Name, &keyType, &fingerprint, &publicKey, &k.CreatedAt) if err != nil { return nil, fmt.Errorf("get ssh key: %w", err) } if keyType.Valid { k.KeyType = keyType.String } if fingerprint.Valid { k.Fingerprint = fingerprint.String } if publicKey.Valid { k.PublicKey = publicKey.String } return &k, nil } // generateFingerprint generates an SSH fingerprint string from a public key. func generateFingerprint(pubKey ssh.PublicKey) string { return ssh.FingerprintSHA256(pubKey) } // marshalPublicKey returns the authorized_keys format of an SSH public key. func marshalPublicKey(pubKey ssh.PublicKey) string { return base64.StdEncoding.EncodeToString(pubKey.Marshal()) } // ecdsaCurveName returns the name for an ECDSA curve. func ecdsaCurveName(curve elliptic.Curve) string { switch curve { case elliptic.P256(): return "nistp256" case elliptic.P384(): return "nistp384" case elliptic.P521(): return "nistp521" default: return "unknown" } }