wraith/internal/credentials/service.go
Vantz Stockwell b46c20b0d0
All checks were successful
Build & Sign Wraith / Build Windows + Sign (push) Successful in 1m4s
feat: wire all remaining stubs — settings, SFTP, RDP, credentials, FreeRDP callbacks
Four-agent parallel deployment:

1. Settings persistence — all 5 settings wired to SettingsService.Set/Get,
   theme picker persists, update check calls real UpdateService, external
   links use Browser.OpenURL, SFTP file open/save calls real service,
   Quick Connect creates real connection + session, exit uses Wails quit

2. SSH key management — credential dropdown in ConnectionEditDialog,
   collapsible "Add New Credential" panel with password/SSH key modes,
   CredentialService proxied through WraithApp (vault-locked guard),
   new CreateSSHKeyCredential method for atomic key+credential creation

3. RDP frontend wiring — useRdp.ts calls real RDPGetFrame/SendMouse/
   SendKey/SendClipboard via Wails bindings, ConnectRDP on WraithApp
   resolves credentials and builds RDPConfig, session store handles
   RDP protocol, frame pipeline uses polling at 30fps

4. FreeRDP3 callback registration — PostConnect and BitmapUpdate callbacks
   via syscall.NewCallback, GDI mode for automatic frame decoding,
   freerdp_context_new() call added, settings/input/context pointers
   extracted from struct offsets, BGRA→RGBA channel swap in frame copy,
   event loop fixed to pass context not instance

11 files changed. Zero build errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 11:25:03 -04:00

442 lines
12 KiB
Go

package credentials
import (
"crypto/ed25519"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/pem"
"fmt"
"github.com/vstockwell/wraith/internal/vault"
"golang.org/x/crypto/ssh"
)
// Credential represents a stored credential (password or SSH key reference).
//
// It is the metadata-only view of a row in the credentials table: the queries
// in this file that populate it never select the encrypted_value column, so no
// secret material is ever carried in this struct. Use DecryptPassword or
// DecryptSSHKey to obtain secrets.
type Credential struct {
ID int64 `json:"id"`
Name string `json:"name"`
// Username and Domain are empty strings when the underlying column is NULL.
Username string `json:"username"`
Domain string `json:"domain"`
Type string `json:"type"` // "password" or "ssh_key"
// SSHKeyID references ssh_keys.id for "ssh_key" credentials (see
// CreateSSHKeyCredential). It is nil when the column is NULL; CreatePassword
// does not set it.
SSHKeyID *int64 `json:"sshKeyId"`
// CreatedAt/UpdatedAt hold the database timestamps scanned into strings.
// NOTE(review): exact format depends on the schema/driver — not visible here.
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}
// SSHKey represents a stored SSH key.
//
// It is the public, metadata-only view of a row in the ssh_keys table. The
// encrypted private key and passphrase columns are never loaded into it; use
// DecryptSSHKey to obtain them.
type SSHKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
// KeyType is a short label such as "rsa", "ed25519" or "ecdsa" (empty if NULL).
KeyType string `json:"keyType"`
// Fingerprint is the ssh.FingerprintSHA256 value of the public key (empty if NULL).
Fingerprint string `json:"fingerprint"`
// PublicKey is the key as produced by ssh.MarshalAuthorizedKey (empty if NULL).
PublicKey string `json:"publicKey"`
CreatedAt string `json:"createdAt"`
}
// CredentialService provides CRUD for credentials with vault encryption.
//
// It manages two tables, credentials and ssh_keys. Secret values (passwords,
// private keys, passphrases) are encrypted with the vault before they are
// written and are only decrypted on explicit request.
type CredentialService struct {
// db is the database holding the credentials and ssh_keys tables.
db *sql.DB
// vault encrypts and decrypts all secret material.
vault *vault.VaultService
}
// NewCredentialService wires a CredentialService to the database that holds
// the credentials/ssh_keys tables and to the vault used for encrypting and
// decrypting secret material.
func NewCredentialService(db *sql.DB, vault *vault.VaultService) *CredentialService {
	svc := &CredentialService{
		db:    db,
		vault: vault,
	}
	return svc
}
// CreatePassword stores a new "password" credential.
//
// The plaintext password is encrypted with the vault first, so only
// ciphertext reaches the credentials table. The inserted row is then re-read
// and returned as a metadata-only Credential.
func (s *CredentialService) CreatePassword(name, username, password, domain string) (*Credential, error) {
	ciphertext, encErr := s.vault.Encrypt(password)
	if encErr != nil {
		return nil, fmt.Errorf("encrypt password: %w", encErr)
	}

	res, insErr := s.db.Exec(
		`INSERT INTO credentials (name, username, domain, type, encrypted_value)
VALUES (?, ?, ?, 'password', ?)`,
		name, username, domain, ciphertext,
	)
	if insErr != nil {
		return nil, fmt.Errorf("insert credential: %w", insErr)
	}

	newID, idErr := res.LastInsertId()
	if idErr != nil {
		return nil, fmt.Errorf("get credential id: %w", idErr)
	}
	return s.getCredential(newID)
}
// CreateSSHKey imports an SSH key (private key encrypted via vault).
//
// The key is parsed first so that its public half can be recorded: the
// algorithm label, the SHA-256 fingerprint and the authorized_keys line.
// Parsing is attempted without the passphrase first, which handles
// unencrypted keys even when a passphrase is supplied for storage. If that
// fails and a passphrase is supplied, the key is parsed again with it, for
// encrypted keys.
//
// The private key bytes and, if non-empty, the passphrase are encrypted with
// the vault before being written to ssh_keys. The inserted row is read back
// (without any private material) and returned.
func (s *CredentialService) CreateSSHKey(name string, privateKeyPEM []byte, passphrase string) (*SSHKey, error) {
	signer, err := ssh.ParsePrivateKey(privateKeyPEM)
	if err != nil && passphrase != "" {
		signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKeyPEM, []byte(passphrase))
	}
	if err != nil {
		return nil, fmt.Errorf("parse private key: %w", err)
	}

	pubKey := signer.PublicKey()
	// Classify from the parsed public key instead of DetectKeyType(privateKeyPEM):
	// DetectKeyType parses the PEM without the passphrase, so it reports
	// "unknown" for passphrase-protected keys that were parsed successfully above.
	keyType := classifyPublicKey(pubKey)
	fingerprint := ssh.FingerprintSHA256(pubKey)
	publicKeyStr := string(ssh.MarshalAuthorizedKey(pubKey))

	// Encrypt private key via vault.
	encryptedKey, err := s.vault.Encrypt(string(privateKeyPEM))
	if err != nil {
		return nil, fmt.Errorf("encrypt private key: %w", err)
	}

	// Encrypt passphrase via vault (if provided); the column stays NULL otherwise.
	var encryptedPassphrase sql.NullString
	if passphrase != "" {
		ep, encErr := s.vault.Encrypt(passphrase)
		if encErr != nil {
			return nil, fmt.Errorf("encrypt passphrase: %w", encErr)
		}
		encryptedPassphrase = sql.NullString{String: ep, Valid: true}
	}

	result, err := s.db.Exec(
		`INSERT INTO ssh_keys (name, key_type, fingerprint, public_key, encrypted_private_key, passphrase_encrypted)
VALUES (?, ?, ?, ?, ?, ?)`,
		name, keyType, fingerprint, publicKeyStr, encryptedKey, encryptedPassphrase,
	)
	if err != nil {
		return nil, fmt.Errorf("insert ssh key: %w", err)
	}
	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("get ssh key id: %w", err)
	}
	return s.getSSHKey(id)
}
// ListCredentials returns all credentials WITHOUT encrypted values.
//
// Rows are ordered by name, and only metadata columns are selected. NULL
// username/domain columns come back as empty strings. An empty table yields a
// non-nil empty slice, so it serializes to [] rather than null.
func (s *CredentialService) ListCredentials() ([]Credential, error) {
	const query = `SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
FROM credentials ORDER BY name`

	rows, err := s.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("list credentials: %w", err)
	}
	defer rows.Close()

	result := []Credential{}
	for rows.Next() {
		var cred Credential
		var user, dom sql.NullString
		if scanErr := rows.Scan(&cred.ID, &cred.Name, &user, &dom, &cred.Type, &cred.SSHKeyID, &cred.CreatedAt, &cred.UpdatedAt); scanErr != nil {
			return nil, fmt.Errorf("scan credential: %w", scanErr)
		}
		// NullString.String is "" when the column was NULL.
		cred.Username = user.String
		cred.Domain = dom.String
		result = append(result, cred)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate credentials: %w", iterErr)
	}
	return result, nil
}
// ListSSHKeys returns all SSH keys WITHOUT private key data.
//
// Rows are ordered by name, and only public metadata is selected; the
// encrypted private key and passphrase columns are never read. NULL
// key_type/fingerprint/public_key columns come back as empty strings. An
// empty table yields a non-nil empty slice.
func (s *CredentialService) ListSSHKeys() ([]SSHKey, error) {
	const query = `SELECT id, name, key_type, fingerprint, public_key, created_at
FROM ssh_keys ORDER BY name`

	rows, err := s.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("list ssh keys: %w", err)
	}
	defer rows.Close()

	out := []SSHKey{}
	for rows.Next() {
		var key SSHKey
		var kind, fp, authorized sql.NullString
		if scanErr := rows.Scan(&key.ID, &key.Name, &kind, &fp, &authorized, &key.CreatedAt); scanErr != nil {
			return nil, fmt.Errorf("scan ssh key: %w", scanErr)
		}
		// NullString.String is "" when the column was NULL.
		key.KeyType = kind.String
		key.Fingerprint = fp.String
		key.PublicKey = authorized.String
		out = append(out, key)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate ssh keys: %w", iterErr)
	}
	return out, nil
}
// DecryptPassword returns the decrypted password for a credential.
//
// Only rows of type 'password' match. It fails in three cases: no such row
// exists (the wrapped error is sql.ErrNoRows), the stored value is NULL, or
// the vault cannot decrypt it.
func (s *CredentialService) DecryptPassword(credentialID int64) (string, error) {
	var ciphertext sql.NullString
	row := s.db.QueryRow(
		"SELECT encrypted_value FROM credentials WHERE id = ? AND type = 'password'",
		credentialID,
	)
	if scanErr := row.Scan(&ciphertext); scanErr != nil {
		return "", fmt.Errorf("get encrypted password: %w", scanErr)
	}
	if !ciphertext.Valid {
		return "", fmt.Errorf("no encrypted value for credential %d", credentialID)
	}

	plain, decErr := s.vault.Decrypt(ciphertext.String)
	if decErr != nil {
		return "", fmt.Errorf("decrypt password: %w", decErr)
	}
	return plain, nil
}
// DecryptSSHKey returns the decrypted private key + passphrase.
//
// The passphrase is "" when none was stored (NULL passphrase_encrypted). Any
// lookup or decryption failure returns a nil key, an empty passphrase and a
// wrapped error.
func (s *CredentialService) DecryptSSHKey(sshKeyID int64) (privateKey []byte, passphrase string, err error) {
	var keyCipher string
	var passCipher sql.NullString
	row := s.db.QueryRow(
		"SELECT encrypted_private_key, passphrase_encrypted FROM ssh_keys WHERE id = ?",
		sshKeyID,
	)
	if scanErr := row.Scan(&keyCipher, &passCipher); scanErr != nil {
		return nil, "", fmt.Errorf("get encrypted ssh key: %w", scanErr)
	}

	plainKey, keyErr := s.vault.Decrypt(keyCipher)
	if keyErr != nil {
		return nil, "", fmt.Errorf("decrypt private key: %w", keyErr)
	}

	// No stored passphrase: return the key alone.
	if !passCipher.Valid {
		return []byte(plainKey), "", nil
	}

	plainPass, passErr := s.vault.Decrypt(passCipher.String)
	if passErr != nil {
		return nil, "", fmt.Errorf("decrypt passphrase: %w", passErr)
	}
	return []byte(plainKey), plainPass, nil
}
// CreateSSHKeyCredential imports an SSH key and creates a matching "ssh_key"
// credentials row that references it. It returns the Credential record, which
// the frontend can immediately use as a credentialId on a connection. The
// credential uses the same name as the key; its domain column is not set.
//
// The two inserts are NOT performed in a single database transaction.
// CreateSSHKey stores the key first as an independent statement. If inserting
// the credentials row then fails, the just-created ssh_keys row is deleted as
// a best-effort compensating action so it is not left orphaned. If that
// cleanup also fails, both errors are reported.
func (s *CredentialService) CreateSSHKeyCredential(name, username string, privateKeyPEM []byte, passphrase string) (*Credential, error) {
	sshKey, err := s.CreateSSHKey(name, privateKeyPEM, passphrase)
	if err != nil {
		return nil, err
	}

	result, err := s.db.Exec(
		`INSERT INTO credentials (name, username, type, ssh_key_id)
VALUES (?, ?, 'ssh_key', ?)`,
		name, username, sshKey.ID,
	)
	if err != nil {
		// Best-effort cleanup of the orphaned ssh_key row. A cleanup failure is
		// surfaced rather than dropped, so the leftover row is not silently lost.
		if _, delErr := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", sshKey.ID); delErr != nil {
			return nil, fmt.Errorf("insert ssh_key credential: %w (cleanup of ssh key %d failed: %v)", err, sshKey.ID, delErr)
		}
		return nil, fmt.Errorf("insert ssh_key credential: %w", err)
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("get credential id: %w", err)
	}
	return s.getCredential(id)
}
// DeleteCredential removes a credential.
//
// It deletes the credentials row with the given id. The affected-row count is
// not inspected, so deleting an id that does not exist is not an error.
func (s *CredentialService) DeleteCredential(id int64) error {
	if _, execErr := s.db.Exec("DELETE FROM credentials WHERE id = ?", id); execErr != nil {
		return fmt.Errorf("delete credential: %w", execErr)
	}
	return nil
}
// DeleteSSHKey removes an SSH key.
//
// Only the ssh_keys row is deleted, and a missing id is not an error.
// NOTE(review): credentials rows that reference this key via ssh_key_id are
// not touched here. Whether they cascade, are nulled, or block the delete
// depends on the schema's foreign-key definition, which is not visible in
// this file.
func (s *CredentialService) DeleteSSHKey(id int64) error {
	if _, execErr := s.db.Exec("DELETE FROM ssh_keys WHERE id = ?", id); execErr != nil {
		return fmt.Errorf("delete ssh key: %w", execErr)
	}
	return nil
}
// DetectKeyType parses a PEM key and returns its type (rsa, ed25519, ecdsa).
//
// Only the first PEM block in pemData is examined. "OPENSSH PRIVATE KEY"
// blocks are delegated to detectOpenSSHKeyType. Any other block is tried, in
// order, as PKCS#8, then PKCS#1 RSA, then SEC 1 EC. "unknown" is returned
// when there is no PEM block or no parser recognizes the body; the x509
// parsers do not decrypt, so encrypted bodies are not recognized.
func DetectKeyType(pemData []byte) string {
	const unknown = "unknown"

	block, _ := pem.Decode(pemData)
	switch {
	case block == nil:
		return unknown
	case block.Type == "OPENSSH PRIVATE KEY":
		return detectOpenSSHKeyType(block.Bytes)
	}

	if parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes); pkcs8Err == nil {
		switch parsed.(type) {
		case *rsa.PrivateKey:
			return "rsa"
		case ed25519.PrivateKey:
			return "ed25519"
		case *ecdsa.PrivateKey:
			return "ecdsa"
		}
		// Other PKCS#8 key kinds fall through to the remaining parsers.
	}

	if _, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
		return "rsa"
	}
	if _, ecErr := x509.ParseECPrivateKey(block.Bytes); ecErr == nil {
		return "ecdsa"
	}
	return unknown
}
// detectOpenSSHKeyType parses the OpenSSH private key format to determine key type.
//
// data is the body of an "OPENSSH PRIVATE KEY" PEM block. It is wrapped back
// into PEM so ssh.ParsePrivateKey can parse it, and the resulting signer's
// public key is then classified. No passphrase is supplied, so keys that fail
// to parse (including passphrase-protected ones) yield "unknown".
func detectOpenSSHKeyType(data []byte) string {
	rewrapped := pem.EncodeToMemory(&pem.Block{Type: "OPENSSH PRIVATE KEY", Bytes: data})

	signer, parseErr := ssh.ParsePrivateKey(rewrapped)
	if parseErr != nil {
		return "unknown"
	}
	return classifyPublicKey(signer.PublicKey())
}
// classifyPublicKey determines the key type from an ssh.PublicKey.
//
// It maps the key's SSH algorithm name to a short label. "ssh-rsa" becomes
// "rsa", "ssh-ed25519" becomes "ed25519", and the three NIST ECDSA algorithms
// (nistp256/384/521) become "ecdsa". Any other algorithm name is returned
// unchanged.
func classifyPublicKey(pub ssh.PublicKey) string {
	switch algo := pub.Type(); algo {
	case "ssh-rsa":
		return "rsa"
	case "ssh-ed25519":
		return "ed25519"
	case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521":
		return "ecdsa"
	default:
		return algo
	}
}
// GetCredential retrieves a single credential by ID.
//
// It returns the same metadata-only record as the unexported lookup (the
// encrypted value is never selected). The lookup's error is passed through
// unchanged; it wraps sql.ErrNoRows when the id does not exist.
func (s *CredentialService) GetCredential(id int64) (*Credential, error) {
	return s.getCredential(id)
}
// getCredential loads one credentials row by id, selecting metadata columns
// only (never encrypted_value). NULL username/domain columns become empty
// strings. A missing row surfaces as a wrapped sql.ErrNoRows.
func (s *CredentialService) getCredential(id int64) (*Credential, error) {
	var cred Credential
	var user, dom sql.NullString

	row := s.db.QueryRow(
		`SELECT id, name, username, domain, type, ssh_key_id, created_at, updated_at
FROM credentials WHERE id = ?`, id,
	)
	if scanErr := row.Scan(&cred.ID, &cred.Name, &user, &dom, &cred.Type, &cred.SSHKeyID, &cred.CreatedAt, &cred.UpdatedAt); scanErr != nil {
		return nil, fmt.Errorf("get credential: %w", scanErr)
	}

	// NullString.String is "" when the column was NULL.
	cred.Username = user.String
	cred.Domain = dom.String
	return &cred, nil
}
// getSSHKey retrieves a single SSH key by ID (without private key data).
//
// Only public metadata columns are selected. NULL key_type, fingerprint and
// public_key columns become empty strings. A missing row surfaces as a
// wrapped sql.ErrNoRows.
func (s *CredentialService) getSSHKey(id int64) (*SSHKey, error) {
	var key SSHKey
	var kind, fp, authorized sql.NullString

	row := s.db.QueryRow(
		`SELECT id, name, key_type, fingerprint, public_key, created_at
FROM ssh_keys WHERE id = ?`, id,
	)
	if scanErr := row.Scan(&key.ID, &key.Name, &kind, &fp, &authorized, &key.CreatedAt); scanErr != nil {
		return nil, fmt.Errorf("get ssh key: %w", scanErr)
	}

	// NullString.String is "" when the column was NULL.
	key.KeyType = kind.String
	key.Fingerprint = fp.String
	key.PublicKey = authorized.String
	return &key, nil
}
// generateFingerprint generates an SSH fingerprint string from a public key.
//
// It is a direct wrapper around ssh.FingerprintSHA256, the same function
// CreateSSHKey uses for the stored fingerprint column.
// NOTE(review): not referenced in this file; confirm whether other files in
// the package use it before removing.
func generateFingerprint(pubKey ssh.PublicKey) string {
return ssh.FingerprintSHA256(pubKey)
}
// marshalPublicKey returns the standard base64 encoding of pubKey's SSH wire
// format (pubKey.Marshal()).
//
// This is only the key-blob field of an authorized_keys entry. It has no
// leading algorithm name, no comment and no trailing newline. Use
// ssh.MarshalAuthorizedKey (as CreateSSHKey does) for a complete
// authorized_keys line.
func marshalPublicKey(pubKey ssh.PublicKey) string {
	return base64.StdEncoding.EncodeToString(pubKey.Marshal())
}
// ecdsaCurveName returns the name for an ECDSA curve.
//
// It maps the NIST P-256, P-384 and P-521 curves to the suffixes used in the
// SSH ECDSA algorithm names ("nistp256", "nistp384", "nistp521"). Any other
// curve yields "unknown".
func ecdsaCurveName(curve elliptic.Curve) string {
	if curve == elliptic.P256() {
		return "nistp256"
	}
	if curve == elliptic.P384() {
		return "nistp384"
	}
	if curve == elliptic.P521() {
		return "nistp521"
	}
	return "unknown"
}