Merge branch 'worktree-agent-a36e902e' into feat/phase1-foundation

This commit is contained in:
Vantz Stockwell 2026-03-17 06:17:52 -04:00
commit 995e81de3b
2 changed files with 595 additions and 0 deletions

View File

@ -0,0 +1,361 @@
package connections
import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
)
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
ParentID *int64 `json:"parentId"`
SortOrder int `json:"sortOrder"`
Icon string `json:"icon"`
CreatedAt time.Time `json:"createdAt"`
Children []Group `json:"children,omitempty"`
}
type Connection struct {
ID int64 `json:"id"`
Name string `json:"name"`
Hostname string `json:"hostname"`
Port int `json:"port"`
Protocol string `json:"protocol"`
GroupID *int64 `json:"groupId"`
CredentialID *int64 `json:"credentialId"`
Color string `json:"color"`
Tags []string `json:"tags"`
Notes string `json:"notes"`
Options string `json:"options"`
SortOrder int `json:"sortOrder"`
LastConnected *time.Time `json:"lastConnected"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type CreateConnectionInput struct {
Name string `json:"name"`
Hostname string `json:"hostname"`
Port int `json:"port"`
Protocol string `json:"protocol"`
GroupID *int64 `json:"groupId"`
CredentialID *int64 `json:"credentialId"`
Color string `json:"color"`
Tags []string `json:"tags"`
Notes string `json:"notes"`
Options string `json:"options"`
}
type UpdateConnectionInput struct {
Name *string `json:"name"`
Hostname *string `json:"hostname"`
Port *int `json:"port"`
GroupID *int64 `json:"groupId"`
CredentialID *int64 `json:"credentialId"`
Color *string `json:"color"`
Tags []string `json:"tags"`
Notes *string `json:"notes"`
Options *string `json:"options"`
}
type ConnectionService struct {
db *sql.DB
}
func NewConnectionService(db *sql.DB) *ConnectionService {
return &ConnectionService{db: db}
}
// ---------- Group CRUD ----------
func (s *ConnectionService) CreateGroup(name string, parentID *int64) (*Group, error) {
result, err := s.db.Exec(
"INSERT INTO groups (name, parent_id) VALUES (?, ?)",
name, parentID,
)
if err != nil {
return nil, fmt.Errorf("create group: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get group id: %w", err)
}
var g Group
var icon sql.NullString
err = s.db.QueryRow(
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups WHERE id = ?", id,
).Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt)
if err != nil {
return nil, fmt.Errorf("get created group: %w", err)
}
if icon.Valid {
g.Icon = icon.String
}
return &g, nil
}
func (s *ConnectionService) ListGroups() ([]Group, error) {
rows, err := s.db.Query(
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups ORDER BY sort_order, name",
)
if err != nil {
return nil, fmt.Errorf("list groups: %w", err)
}
defer rows.Close()
groupMap := make(map[int64]*Group)
var allGroups []*Group
for rows.Next() {
var g Group
var icon sql.NullString
if err := rows.Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt); err != nil {
return nil, fmt.Errorf("scan group: %w", err)
}
if icon.Valid {
g.Icon = icon.String
}
groupMap[g.ID] = &g
allGroups = append(allGroups, &g)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate groups: %w", err)
}
// Build tree: attach children to parents, collect roots
var roots []Group
for _, g := range allGroups {
if g.ParentID != nil {
if parent, ok := groupMap[*g.ParentID]; ok {
parent.Children = append(parent.Children, *g)
}
} else {
roots = append(roots, *g)
}
}
// Re-attach children to root copies (since we copied into roots)
for i := range roots {
if orig, ok := groupMap[roots[i].ID]; ok {
roots[i].Children = orig.Children
}
}
if roots == nil {
roots = []Group{}
}
return roots, nil
}
func (s *ConnectionService) DeleteGroup(id int64) error {
_, err := s.db.Exec("DELETE FROM groups WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
}
// ---------- Connection CRUD ----------
func (s *ConnectionService) CreateConnection(input CreateConnectionInput) (*Connection, error) {
tags, err := json.Marshal(input.Tags)
if err != nil {
return nil, fmt.Errorf("marshal tags: %w", err)
}
if input.Tags == nil {
tags = []byte("[]")
}
options := input.Options
if options == "" {
options = "{}"
}
result, err := s.db.Exec(
`INSERT INTO connections (name, hostname, port, protocol, group_id, credential_id, color, tags, notes, options)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
input.Name, input.Hostname, input.Port, input.Protocol,
input.GroupID, input.CredentialID, input.Color,
string(tags), input.Notes, options,
)
if err != nil {
return nil, fmt.Errorf("create connection: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("get connection id: %w", err)
}
return s.GetConnection(id)
}
func (s *ConnectionService) GetConnection(id int64) (*Connection, error) {
row := s.db.QueryRow(
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
FROM connections WHERE id = ?`, id,
)
var c Connection
var tagsJSON string
var color, notes, options sql.NullString
var lastConnected sql.NullTime
err := row.Scan(
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
&c.GroupID, &c.CredentialID,
&color, &tagsJSON, &notes, &options,
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("get connection: %w", err)
}
if color.Valid {
c.Color = color.String
}
if notes.Valid {
c.Notes = notes.String
}
if options.Valid {
c.Options = options.String
}
if lastConnected.Valid {
c.LastConnected = &lastConnected.Time
}
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
c.Tags = []string{}
}
if c.Tags == nil {
c.Tags = []string{}
}
return &c, nil
}
func (s *ConnectionService) ListConnections() ([]Connection, error) {
rows, err := s.db.Query(
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
FROM connections ORDER BY sort_order, name`,
)
if err != nil {
return nil, fmt.Errorf("list connections: %w", err)
}
defer rows.Close()
return scanConnections(rows)
}
// scanConnections is a shared helper used by ListConnections and (later) Search.
func scanConnections(rows *sql.Rows) ([]Connection, error) {
var conns []Connection
for rows.Next() {
var c Connection
var tagsJSON string
var color, notes, options sql.NullString
var lastConnected sql.NullTime
if err := rows.Scan(
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
&c.GroupID, &c.CredentialID,
&color, &tagsJSON, &notes, &options,
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan connection: %w", err)
}
if color.Valid {
c.Color = color.String
}
if notes.Valid {
c.Notes = notes.String
}
if options.Valid {
c.Options = options.String
}
if lastConnected.Valid {
c.LastConnected = &lastConnected.Time
}
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
c.Tags = []string{}
}
if c.Tags == nil {
c.Tags = []string{}
}
conns = append(conns, c)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate connections: %w", err)
}
if conns == nil {
conns = []Connection{}
}
return conns, nil
}
func (s *ConnectionService) UpdateConnection(id int64, input UpdateConnectionInput) (*Connection, error) {
setClauses := []string{"updated_at = CURRENT_TIMESTAMP"}
args := []interface{}{}
if input.Name != nil {
setClauses = append(setClauses, "name = ?")
args = append(args, *input.Name)
}
if input.Hostname != nil {
setClauses = append(setClauses, "hostname = ?")
args = append(args, *input.Hostname)
}
if input.Port != nil {
setClauses = append(setClauses, "port = ?")
args = append(args, *input.Port)
}
if input.GroupID != nil {
setClauses = append(setClauses, "group_id = ?")
args = append(args, *input.GroupID)
}
if input.CredentialID != nil {
setClauses = append(setClauses, "credential_id = ?")
args = append(args, *input.CredentialID)
}
if input.Tags != nil {
tags, _ := json.Marshal(input.Tags)
setClauses = append(setClauses, "tags = ?")
args = append(args, string(tags))
}
if input.Notes != nil {
setClauses = append(setClauses, "notes = ?")
args = append(args, *input.Notes)
}
if input.Color != nil {
setClauses = append(setClauses, "color = ?")
args = append(args, *input.Color)
}
if input.Options != nil {
setClauses = append(setClauses, "options = ?")
args = append(args, *input.Options)
}
args = append(args, id)
query := fmt.Sprintf("UPDATE connections SET %s WHERE id = ?", strings.Join(setClauses, ", "))
if _, err := s.db.Exec(query, args...); err != nil {
return nil, fmt.Errorf("update connection: %w", err)
}
return s.GetConnection(id)
}
func (s *ConnectionService) DeleteConnection(id int64) error {
_, err := s.db.Exec("DELETE FROM connections WHERE id = ?", id)
if err != nil {
return fmt.Errorf("delete connection: %w", err)
}
return nil
}

View File

@ -0,0 +1,234 @@
package connections
import (
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
)
func strPtr(s string) *string { return &s }
func setupTestService(t *testing.T) *ConnectionService {
t.Helper()
dir := t.TempDir()
database, err := db.Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("db.Open() error: %v", err)
}
if err := db.Migrate(database); err != nil {
t.Fatalf("db.Migrate() error: %v", err)
}
t.Cleanup(func() { database.Close() })
return NewConnectionService(database)
}
func TestCreateGroup(t *testing.T) {
svc := setupTestService(t)
g, err := svc.CreateGroup("Servers", nil)
if err != nil {
t.Fatalf("CreateGroup() error: %v", err)
}
if g.ID == 0 {
t.Error("expected non-zero ID")
}
if g.Name != "Servers" {
t.Errorf("Name = %q, want %q", g.Name, "Servers")
}
if g.ParentID != nil {
t.Errorf("ParentID = %v, want nil", g.ParentID)
}
}
func TestCreateSubGroup(t *testing.T) {
svc := setupTestService(t)
parent, err := svc.CreateGroup("Servers", nil)
if err != nil {
t.Fatalf("CreateGroup(parent) error: %v", err)
}
child, err := svc.CreateGroup("Production", &parent.ID)
if err != nil {
t.Fatalf("CreateGroup(child) error: %v", err)
}
if child.ParentID == nil {
t.Fatal("expected non-nil ParentID")
}
if *child.ParentID != parent.ID {
t.Errorf("ParentID = %d, want %d", *child.ParentID, parent.ID)
}
}
func TestListGroups(t *testing.T) {
svc := setupTestService(t)
parent, err := svc.CreateGroup("Servers", nil)
if err != nil {
t.Fatalf("CreateGroup(parent) error: %v", err)
}
if _, err := svc.CreateGroup("Production", &parent.ID); err != nil {
t.Fatalf("CreateGroup(child) error: %v", err)
}
groups, err := svc.ListGroups()
if err != nil {
t.Fatalf("ListGroups() error: %v", err)
}
if len(groups) != 1 {
t.Fatalf("len(groups) = %d, want 1 (only root groups)", len(groups))
}
if groups[0].Name != "Servers" {
t.Errorf("groups[0].Name = %q, want %q", groups[0].Name, "Servers")
}
if len(groups[0].Children) != 1 {
t.Fatalf("len(children) = %d, want 1", len(groups[0].Children))
}
if groups[0].Children[0].Name != "Production" {
t.Errorf("child name = %q, want %q", groups[0].Children[0].Name, "Production")
}
}
func TestDeleteGroup(t *testing.T) {
svc := setupTestService(t)
g, err := svc.CreateGroup("ToDelete", nil)
if err != nil {
t.Fatalf("CreateGroup() error: %v", err)
}
if err := svc.DeleteGroup(g.ID); err != nil {
t.Fatalf("DeleteGroup() error: %v", err)
}
groups, err := svc.ListGroups()
if err != nil {
t.Fatalf("ListGroups() error: %v", err)
}
if len(groups) != 0 {
t.Errorf("len(groups) = %d, want 0 after delete", len(groups))
}
}
func TestCreateConnection(t *testing.T) {
svc := setupTestService(t)
conn, err := svc.CreateConnection(CreateConnectionInput{
Name: "Web Server",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
Tags: []string{"Prod", "Linux"},
Options: `{"keepAliveInterval": 60}`,
})
if err != nil {
t.Fatalf("CreateConnection() error: %v", err)
}
if conn.ID == 0 {
t.Error("expected non-zero ID")
}
if conn.Name != "Web Server" {
t.Errorf("Name = %q, want %q", conn.Name, "Web Server")
}
if len(conn.Tags) != 2 {
t.Fatalf("len(Tags) = %d, want 2", len(conn.Tags))
}
if conn.Tags[0] != "Prod" || conn.Tags[1] != "Linux" {
t.Errorf("Tags = %v, want [Prod Linux]", conn.Tags)
}
if conn.Options != `{"keepAliveInterval": 60}` {
t.Errorf("Options = %q, want JSON blob", conn.Options)
}
}
func TestListConnections(t *testing.T) {
svc := setupTestService(t)
if _, err := svc.CreateConnection(CreateConnectionInput{
Name: "Server A",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
}); err != nil {
t.Fatalf("CreateConnection(A) error: %v", err)
}
if _, err := svc.CreateConnection(CreateConnectionInput{
Name: "Server B",
Hostname: "10.0.0.2",
Port: 3389,
Protocol: "rdp",
}); err != nil {
t.Fatalf("CreateConnection(B) error: %v", err)
}
conns, err := svc.ListConnections()
if err != nil {
t.Fatalf("ListConnections() error: %v", err)
}
if len(conns) != 2 {
t.Fatalf("len(conns) = %d, want 2", len(conns))
}
}
func TestUpdateConnection(t *testing.T) {
svc := setupTestService(t)
conn, err := svc.CreateConnection(CreateConnectionInput{
Name: "Old Name",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
Tags: []string{"Dev"},
})
if err != nil {
t.Fatalf("CreateConnection() error: %v", err)
}
updated, err := svc.UpdateConnection(conn.ID, UpdateConnectionInput{
Name: strPtr("New Name"),
Tags: []string{"Prod", "Linux"},
})
if err != nil {
t.Fatalf("UpdateConnection() error: %v", err)
}
if updated.Name != "New Name" {
t.Errorf("Name = %q, want %q", updated.Name, "New Name")
}
if len(updated.Tags) != 2 {
t.Fatalf("len(Tags) = %d, want 2", len(updated.Tags))
}
if updated.Tags[0] != "Prod" {
t.Errorf("Tags[0] = %q, want %q", updated.Tags[0], "Prod")
}
// Hostname should remain unchanged
if updated.Hostname != "10.0.0.1" {
t.Errorf("Hostname = %q, want %q (unchanged)", updated.Hostname, "10.0.0.1")
}
}
func TestDeleteConnection(t *testing.T) {
svc := setupTestService(t)
conn, err := svc.CreateConnection(CreateConnectionInput{
Name: "ToDelete",
Hostname: "10.0.0.1",
Port: 22,
Protocol: "ssh",
})
if err != nil {
t.Fatalf("CreateConnection() error: %v", err)
}
if err := svc.DeleteConnection(conn.ID); err != nil {
t.Fatalf("DeleteConnection() error: %v", err)
}
conns, err := svc.ListConnections()
if err != nil {
t.Fatalf("ListConnections() error: %v", err)
}
if len(conns) != 0 {
t.Errorf("len(conns) = %d, want 0 after delete", len(conns))
}
}