diff --git a/internal/connections/service.go b/internal/connections/service.go new file mode 100644 index 0000000..5ca5f23 --- /dev/null +++ b/internal/connections/service.go @@ -0,0 +1,361 @@ +package connections + +import ( + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" +) + +type Group struct { + ID int64 `json:"id"` + Name string `json:"name"` + ParentID *int64 `json:"parentId"` + SortOrder int `json:"sortOrder"` + Icon string `json:"icon"` + CreatedAt time.Time `json:"createdAt"` + Children []Group `json:"children,omitempty"` +} + +type Connection struct { + ID int64 `json:"id"` + Name string `json:"name"` + Hostname string `json:"hostname"` + Port int `json:"port"` + Protocol string `json:"protocol"` + GroupID *int64 `json:"groupId"` + CredentialID *int64 `json:"credentialId"` + Color string `json:"color"` + Tags []string `json:"tags"` + Notes string `json:"notes"` + Options string `json:"options"` + SortOrder int `json:"sortOrder"` + LastConnected *time.Time `json:"lastConnected"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +type CreateConnectionInput struct { + Name string `json:"name"` + Hostname string `json:"hostname"` + Port int `json:"port"` + Protocol string `json:"protocol"` + GroupID *int64 `json:"groupId"` + CredentialID *int64 `json:"credentialId"` + Color string `json:"color"` + Tags []string `json:"tags"` + Notes string `json:"notes"` + Options string `json:"options"` +} + +type UpdateConnectionInput struct { + Name *string `json:"name"` + Hostname *string `json:"hostname"` + Port *int `json:"port"` + GroupID *int64 `json:"groupId"` + CredentialID *int64 `json:"credentialId"` + Color *string `json:"color"` + Tags []string `json:"tags"` + Notes *string `json:"notes"` + Options *string `json:"options"` +} + +type ConnectionService struct { + db *sql.DB +} + +func NewConnectionService(db *sql.DB) *ConnectionService { + return &ConnectionService{db: db} +} + +// ---------- Group CRUD ---------- + +func (s *ConnectionService) CreateGroup(name string, parentID *int64) (*Group, error) { + result, err := s.db.Exec( + "INSERT INTO groups (name, parent_id) VALUES (?, ?)", + name, parentID, + ) + if err != nil { + return nil, fmt.Errorf("create group: %w", err) + } + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get group id: %w", err) + } + + var g Group + var icon sql.NullString + err = s.db.QueryRow( + "SELECT id, name, parent_id, sort_order, icon, created_at FROM groups WHERE id = ?", id, + ).Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt) + if err != nil { + return nil, fmt.Errorf("get created group: %w", err) + } + if icon.Valid { + g.Icon = icon.String + } + return &g, nil +} + +func (s *ConnectionService) ListGroups() ([]Group, error) { + rows, err := s.db.Query( + "SELECT id, name, parent_id, sort_order, icon, created_at FROM groups ORDER BY sort_order, name", + ) + if err != nil { + return nil, fmt.Errorf("list groups: %w", err) + } + defer rows.Close() + + groupMap := make(map[int64]*Group) + var allGroups []*Group + + for rows.Next() { + var g Group + var icon sql.NullString + if err := rows.Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt); err != nil { + return nil, fmt.Errorf("scan group: %w", err) + } + if icon.Valid { + g.Icon = icon.String + } + groupMap[g.ID] = &g + allGroups = append(allGroups, &g) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate groups: %w", err) + } + + // Build tree: attach children to parents, collect roots + var roots []Group + for _, g := range allGroups { + if g.ParentID != nil { + if parent, ok := groupMap[*g.ParentID]; ok { + parent.Children = append(parent.Children, *g) + } + } else { + roots = append(roots, *g) + } + } + + // Re-attach children to root copies (since we copied into roots) + for i := range roots { + if orig, ok := groupMap[roots[i].ID]; ok { + roots[i].Children = orig.Children + } + } + + if roots == nil { + roots = []Group{} + } + return roots, nil +} + +func (s *ConnectionService) DeleteGroup(id int64) error { + _, err := s.db.Exec("DELETE FROM groups WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete group: %w", err) + } + return nil +} + +// ---------- Connection CRUD ---------- + +func (s *ConnectionService) CreateConnection(input CreateConnectionInput) (*Connection, error) { + tags, err := json.Marshal(input.Tags) + if err != nil { + return nil, fmt.Errorf("marshal tags: %w", err) + } + if input.Tags == nil { + tags = []byte("[]") + } + + options := input.Options + if options == "" { + options = "{}" + } + + result, err := s.db.Exec( + `INSERT INTO connections (name, hostname, port, protocol, group_id, credential_id, color, tags, notes, options) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + input.Name, input.Hostname, input.Port, input.Protocol, + input.GroupID, input.CredentialID, input.Color, + string(tags), input.Notes, options, + ) + if err != nil { + return nil, fmt.Errorf("create connection: %w", err) + } + id, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("get connection id: %w", err) + } + return s.GetConnection(id) +} + +func (s *ConnectionService) GetConnection(id int64) (*Connection, error) { + row := s.db.QueryRow( + `SELECT id, name, hostname, port, protocol, group_id, credential_id, + color, tags, notes, options, sort_order, last_connected, created_at, updated_at + FROM connections WHERE id = ?`, id, + ) + + var c Connection + var tagsJSON string + var color, notes, options sql.NullString + var lastConnected sql.NullTime + + err := row.Scan( + &c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol, + &c.GroupID, &c.CredentialID, + &color, &tagsJSON, ¬es, &options, + &c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("get connection: %w", err) + } + + if color.Valid { + c.Color = color.String + } + if notes.Valid { + c.Notes = notes.String + } + if options.Valid { + c.Options = options.String + } + if lastConnected.Valid { + c.LastConnected = &lastConnected.Time + } + + if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil { + c.Tags = []string{} + } + if c.Tags == nil { + c.Tags = []string{} + } + + return &c, nil +} + +func (s *ConnectionService) ListConnections() ([]Connection, error) { + rows, err := s.db.Query( + `SELECT id, name, hostname, port, protocol, group_id, credential_id, + color, tags, notes, options, sort_order, last_connected, created_at, updated_at + FROM connections ORDER BY sort_order, name`, + ) + if err != nil { + return nil, fmt.Errorf("list connections: %w", err) + } + defer rows.Close() + + return scanConnections(rows) +} + +// scanConnections is a shared helper used by ListConnections and (later) Search. +func scanConnections(rows *sql.Rows) ([]Connection, error) { + var conns []Connection + + for rows.Next() { + var c Connection + var tagsJSON string + var color, notes, options sql.NullString + var lastConnected sql.NullTime + + if err := rows.Scan( + &c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol, + &c.GroupID, &c.CredentialID, + &color, &tagsJSON, ¬es, &options, + &c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan connection: %w", err) + } + + if color.Valid { + c.Color = color.String + } + if notes.Valid { + c.Notes = notes.String + } + if options.Valid { + c.Options = options.String + } + if lastConnected.Valid { + c.LastConnected = &lastConnected.Time + } + + if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil { + c.Tags = []string{} + } + if c.Tags == nil { + c.Tags = []string{} + } + + conns = append(conns, c) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate connections: %w", err) + } + + if conns == nil { + conns = []Connection{} + } + return conns, nil +} + +func (s *ConnectionService) UpdateConnection(id int64, input UpdateConnectionInput) (*Connection, error) { + setClauses := []string{"updated_at = CURRENT_TIMESTAMP"} + args := []interface{}{} + + if input.Name != nil { + setClauses = append(setClauses, "name = ?") + args = append(args, *input.Name) + } + if input.Hostname != nil { + setClauses = append(setClauses, "hostname = ?") + args = append(args, *input.Hostname) + } + if input.Port != nil { + setClauses = append(setClauses, "port = ?") + args = append(args, *input.Port) + } + if input.GroupID != nil { + setClauses = append(setClauses, "group_id = ?") + args = append(args, *input.GroupID) + } + if input.CredentialID != nil { + setClauses = append(setClauses, "credential_id = ?") + args = append(args, *input.CredentialID) + } + if input.Tags != nil { + tags, _ := json.Marshal(input.Tags) + setClauses = append(setClauses, "tags = ?") + args = append(args, string(tags)) + } + if input.Notes != nil { + setClauses = append(setClauses, "notes = ?") + args = append(args, *input.Notes) + } + if input.Color != nil { + setClauses = append(setClauses, "color = ?") + args = append(args, *input.Color) + } + if input.Options != nil { + setClauses = append(setClauses, "options = ?") + args = append(args, *input.Options) + } + + args = append(args, id) + query := fmt.Sprintf("UPDATE connections SET %s WHERE id = ?", strings.Join(setClauses, ", ")) + if _, err := s.db.Exec(query, args...); err != nil { + return nil, fmt.Errorf("update connection: %w", err) + } + return s.GetConnection(id) +} + +func (s *ConnectionService) DeleteConnection(id int64) error { + _, err := s.db.Exec("DELETE FROM connections WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete connection: %w", err) + } + return nil +} diff --git a/internal/connections/service_test.go b/internal/connections/service_test.go new file mode 100644 index 0000000..7be0d96 --- /dev/null +++ b/internal/connections/service_test.go @@ -0,0 +1,234 @@ +package connections + +import ( + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" +) + +func strPtr(s string) *string { return &s } + +func setupTestService(t *testing.T) *ConnectionService { + t.Helper() + dir := t.TempDir() + database, err := db.Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("db.Open() error: %v", err) + } + if err := db.Migrate(database); err != nil { + t.Fatalf("db.Migrate() error: %v", err) + } + t.Cleanup(func() { database.Close() }) + return NewConnectionService(database) +} + +func TestCreateGroup(t *testing.T) { + svc := setupTestService(t) + + g, err := svc.CreateGroup("Servers", nil) + if err != nil { + t.Fatalf("CreateGroup() error: %v", err) + } + if g.ID == 0 { + t.Error("expected non-zero ID") + } + if g.Name != "Servers" { + t.Errorf("Name = %q, want %q", g.Name, "Servers") + } + if g.ParentID != nil { + t.Errorf("ParentID = %v, want nil", g.ParentID) + } +} + +func TestCreateSubGroup(t *testing.T) { + svc := setupTestService(t) + + parent, err := svc.CreateGroup("Servers", nil) + if err != nil { + t.Fatalf("CreateGroup(parent) error: %v", err) + } + + child, err := svc.CreateGroup("Production", &parent.ID) + if err != nil { + t.Fatalf("CreateGroup(child) error: %v", err) + } + if child.ParentID == nil { + t.Fatal("expected non-nil ParentID") + } + if *child.ParentID != parent.ID { + t.Errorf("ParentID = %d, want %d", *child.ParentID, parent.ID) + } +} + +func TestListGroups(t *testing.T) { + svc := setupTestService(t) + + parent, err := svc.CreateGroup("Servers", nil) + if err != nil { + t.Fatalf("CreateGroup(parent) error: %v", err) + } + if _, err := svc.CreateGroup("Production", &parent.ID); err != nil { + t.Fatalf("CreateGroup(child) error: %v", err) + } + + groups, err := svc.ListGroups() + if err != nil { + t.Fatalf("ListGroups() error: %v", err) + } + if len(groups) != 1 { + t.Fatalf("len(groups) = %d, want 1 (only root groups)", len(groups)) + } + if groups[0].Name != "Servers" { + t.Errorf("groups[0].Name = %q, want %q", groups[0].Name, "Servers") + } + if len(groups[0].Children) != 1 { + t.Fatalf("len(children) = %d, want 1", len(groups[0].Children)) + } + if groups[0].Children[0].Name != "Production" { + t.Errorf("child name = %q, want %q", groups[0].Children[0].Name, "Production") + } +} + +func TestDeleteGroup(t *testing.T) { + svc := setupTestService(t) + + g, err := svc.CreateGroup("ToDelete", nil) + if err != nil { + t.Fatalf("CreateGroup() error: %v", err) + } + + if err := svc.DeleteGroup(g.ID); err != nil { + t.Fatalf("DeleteGroup() error: %v", err) + } + + groups, err := svc.ListGroups() + if err != nil { + t.Fatalf("ListGroups() error: %v", err) + } + if len(groups) != 0 { + t.Errorf("len(groups) = %d, want 0 after delete", len(groups)) + } +} + +func TestCreateConnection(t *testing.T) { + svc := setupTestService(t) + + conn, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Web Server", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + Tags: []string{"Prod", "Linux"}, + Options: `{"keepAliveInterval": 60}`, + }) + if err != nil { + t.Fatalf("CreateConnection() error: %v", err) + } + if conn.ID == 0 { + t.Error("expected non-zero ID") + } + if conn.Name != "Web Server" { + t.Errorf("Name = %q, want %q", conn.Name, "Web Server") + } + if len(conn.Tags) != 2 { + t.Fatalf("len(Tags) = %d, want 2", len(conn.Tags)) + } + if conn.Tags[0] != "Prod" || conn.Tags[1] != "Linux" { + t.Errorf("Tags = %v, want [Prod Linux]", conn.Tags) + } + if conn.Options != `{"keepAliveInterval": 60}` { + t.Errorf("Options = %q, want JSON blob", conn.Options) + } +} + +func TestListConnections(t *testing.T) { + svc := setupTestService(t) + + if _, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Server A", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + }); err != nil { + t.Fatalf("CreateConnection(A) error: %v", err) + } + if _, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Server B", + Hostname: "10.0.0.2", + Port: 3389, + Protocol: "rdp", + }); err != nil { + t.Fatalf("CreateConnection(B) error: %v", err) + } + + conns, err := svc.ListConnections() + if err != nil { + t.Fatalf("ListConnections() error: %v", err) + } + if len(conns) != 2 { + t.Fatalf("len(conns) = %d, want 2", len(conns)) + } +} + +func TestUpdateConnection(t *testing.T) { + svc := setupTestService(t) + + conn, err := svc.CreateConnection(CreateConnectionInput{ + Name: "Old Name", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + Tags: []string{"Dev"}, + }) + if err != nil { + t.Fatalf("CreateConnection() error: %v", err) + } + + updated, err := svc.UpdateConnection(conn.ID, UpdateConnectionInput{ + Name: strPtr("New Name"), + Tags: []string{"Prod", "Linux"}, + }) + if err != nil { + t.Fatalf("UpdateConnection() error: %v", err) + } + if updated.Name != "New Name" { + t.Errorf("Name = %q, want %q", updated.Name, "New Name") + } + if len(updated.Tags) != 2 { + t.Fatalf("len(Tags) = %d, want 2", len(updated.Tags)) + } + if updated.Tags[0] != "Prod" { + t.Errorf("Tags[0] = %q, want %q", updated.Tags[0], "Prod") + } + // Hostname should remain unchanged + if updated.Hostname != "10.0.0.1" { + t.Errorf("Hostname = %q, want %q (unchanged)", updated.Hostname, "10.0.0.1") + } +} + +func TestDeleteConnection(t *testing.T) { + svc := setupTestService(t) + + conn, err := svc.CreateConnection(CreateConnectionInput{ + Name: "ToDelete", + Hostname: "10.0.0.1", + Port: 22, + Protocol: "ssh", + }) + if err != nil { + t.Fatalf("CreateConnection() error: %v", err) + } + + if err := svc.DeleteConnection(conn.ID); err != nil { + t.Fatalf("DeleteConnection() error: %v", err) + } + + conns, err := svc.ListConnections() + if err != nil { + t.Fatalf("ListConnections() error: %v", err) + } + if len(conns) != 0 { + t.Errorf("len(conns) = %d, want 0 after delete", len(conns)) + } +}