package credentials import ( "crypto/ed25519" "crypto/rand" "encoding/pem" "path/filepath" "testing" "github.com/vstockwell/wraith/internal/db" "github.com/vstockwell/wraith/internal/vault" "golang.org/x/crypto/ssh" ) func setupCredentialService(t *testing.T) *CredentialService { t.Helper() d, err := db.Open(filepath.Join(t.TempDir(), "test.db")) if err != nil { t.Fatal(err) } if err := db.Migrate(d); err != nil { t.Fatal(err) } t.Cleanup(func() { d.Close() }) salt := []byte("test-salt-exactly-32-bytes-long!") key := vault.DeriveKey("testpassword", salt) vs := vault.NewVaultService(key) return NewCredentialService(d, vs) } func TestCreatePasswordCredential(t *testing.T) { svc := setupCredentialService(t) cred, err := svc.CreatePassword("Test Cred", "admin", "secret123", "") if err != nil { t.Fatal(err) } if cred.Name != "Test Cred" { t.Error("wrong name") } if cred.Type != "password" { t.Error("wrong type") } } func TestDecryptPassword(t *testing.T) { svc := setupCredentialService(t) cred, _ := svc.CreatePassword("Test", "admin", "mypassword", "") password, err := svc.DecryptPassword(cred.ID) if err != nil { t.Fatal(err) } if password != "mypassword" { t.Errorf("got %q, want mypassword", password) } } func TestListCredentialsExcludesSecrets(t *testing.T) { svc := setupCredentialService(t) svc.CreatePassword("Cred1", "user1", "pass1", "") svc.CreatePassword("Cred2", "user2", "pass2", "") creds, err := svc.ListCredentials() if err != nil { t.Fatal(err) } if len(creds) != 2 { t.Errorf("got %d, want 2", len(creds)) } } func TestCreateSSHKey(t *testing.T) { svc := setupCredentialService(t) // Generate a test key _, priv, _ := ed25519.GenerateKey(rand.Reader) pemBlock, _ := ssh.MarshalPrivateKey(priv, "") keyPEM := pem.EncodeToMemory(pemBlock) key, err := svc.CreateSSHKey("My Key", keyPEM, "") if err != nil { t.Fatal(err) } if key.KeyType != "ed25519" { t.Errorf("KeyType = %q, want ed25519", key.KeyType) } if key.Fingerprint == "" { t.Error("fingerprint should not be empty") } } func TestDecryptSSHKey(t *testing.T) { svc := setupCredentialService(t) _, priv, _ := ed25519.GenerateKey(rand.Reader) pemBlock, _ := ssh.MarshalPrivateKey(priv, "") keyPEM := pem.EncodeToMemory(pemBlock) key, _ := svc.CreateSSHKey("My Key", keyPEM, "testpass") decryptedKey, passphrase, err := svc.DecryptSSHKey(key.ID) if err != nil { t.Fatal(err) } if len(decryptedKey) == 0 { t.Error("decrypted key should not be empty") } if passphrase != "testpass" { t.Errorf("passphrase = %q, want testpass", passphrase) } } func TestDetectKeyType(t *testing.T) { _, priv, _ := ed25519.GenerateKey(rand.Reader) pemBlock, _ := ssh.MarshalPrivateKey(priv, "") keyPEM := pem.EncodeToMemory(pemBlock) if got := DetectKeyType(keyPEM); got != "ed25519" { t.Errorf("got %q", got) } } func TestDeleteCredential(t *testing.T) { svc := setupCredentialService(t) cred, _ := svc.CreatePassword("ToDelete", "user", "pass", "") err := svc.DeleteCredential(cred.ID) if err != nil { t.Fatal(err) } creds, _ := svc.ListCredentials() if len(creds) != 0 { t.Errorf("got %d credentials, want 0", len(creds)) } } func TestDeleteSSHKey(t *testing.T) { svc := setupCredentialService(t) _, priv, _ := ed25519.GenerateKey(rand.Reader) pemBlock, _ := ssh.MarshalPrivateKey(priv, "") keyPEM := pem.EncodeToMemory(pemBlock) key, _ := svc.CreateSSHKey("ToDelete", keyPEM, "") err := svc.DeleteSSHKey(key.ID) if err != nil { t.Fatal(err) } keys, _ := svc.ListSSHKeys() if len(keys) != 0 { t.Errorf("got %d keys, want 0", len(keys)) } } func TestListSSHKeys(t *testing.T) { svc := setupCredentialService(t) _, priv, _ := ed25519.GenerateKey(rand.Reader) pemBlock, _ := ssh.MarshalPrivateKey(priv, "") keyPEM := pem.EncodeToMemory(pemBlock) svc.CreateSSHKey("Key1", keyPEM, "") svc.CreateSSHKey("Key2", keyPEM, "") keys, err := svc.ListSSHKeys() if err != nil { t.Fatal(err) } if len(keys) != 2 { t.Errorf("got %d keys, want 2", len(keys)) } } func TestCreatePasswordWithDomain(t *testing.T) { svc := setupCredentialService(t) cred, err := svc.CreatePassword("Domain Cred", "admin", "secret", "example.com") if err != nil { t.Fatal(err) } if cred.Domain != "example.com" { t.Errorf("domain = %q, want example.com", cred.Domain) } }