Merge branch 'worktree-agent-a36e902e' into feat/phase1-foundation
This commit is contained in:
commit
995e81de3b
361
internal/connections/service.go
Normal file
361
internal/connections/service.go
Normal file
@ -0,0 +1,361 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ParentID *int64 `json:"parentId"`
|
||||
SortOrder int `json:"sortOrder"`
|
||||
Icon string `json:"icon"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Children []Group `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Hostname string `json:"hostname"`
|
||||
Port int `json:"port"`
|
||||
Protocol string `json:"protocol"`
|
||||
GroupID *int64 `json:"groupId"`
|
||||
CredentialID *int64 `json:"credentialId"`
|
||||
Color string `json:"color"`
|
||||
Tags []string `json:"tags"`
|
||||
Notes string `json:"notes"`
|
||||
Options string `json:"options"`
|
||||
SortOrder int `json:"sortOrder"`
|
||||
LastConnected *time.Time `json:"lastConnected"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type CreateConnectionInput struct {
|
||||
Name string `json:"name"`
|
||||
Hostname string `json:"hostname"`
|
||||
Port int `json:"port"`
|
||||
Protocol string `json:"protocol"`
|
||||
GroupID *int64 `json:"groupId"`
|
||||
CredentialID *int64 `json:"credentialId"`
|
||||
Color string `json:"color"`
|
||||
Tags []string `json:"tags"`
|
||||
Notes string `json:"notes"`
|
||||
Options string `json:"options"`
|
||||
}
|
||||
|
||||
type UpdateConnectionInput struct {
|
||||
Name *string `json:"name"`
|
||||
Hostname *string `json:"hostname"`
|
||||
Port *int `json:"port"`
|
||||
GroupID *int64 `json:"groupId"`
|
||||
CredentialID *int64 `json:"credentialId"`
|
||||
Color *string `json:"color"`
|
||||
Tags []string `json:"tags"`
|
||||
Notes *string `json:"notes"`
|
||||
Options *string `json:"options"`
|
||||
}
|
||||
|
||||
type ConnectionService struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewConnectionService(db *sql.DB) *ConnectionService {
|
||||
return &ConnectionService{db: db}
|
||||
}
|
||||
|
||||
// ---------- Group CRUD ----------
|
||||
|
||||
func (s *ConnectionService) CreateGroup(name string, parentID *int64) (*Group, error) {
|
||||
result, err := s.db.Exec(
|
||||
"INSERT INTO groups (name, parent_id) VALUES (?, ?)",
|
||||
name, parentID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create group: %w", err)
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group id: %w", err)
|
||||
}
|
||||
|
||||
var g Group
|
||||
var icon sql.NullString
|
||||
err = s.db.QueryRow(
|
||||
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups WHERE id = ?", id,
|
||||
).Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get created group: %w", err)
|
||||
}
|
||||
if icon.Valid {
|
||||
g.Icon = icon.String
|
||||
}
|
||||
return &g, nil
|
||||
}
|
||||
|
||||
func (s *ConnectionService) ListGroups() ([]Group, error) {
|
||||
rows, err := s.db.Query(
|
||||
"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups ORDER BY sort_order, name",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list groups: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
groupMap := make(map[int64]*Group)
|
||||
var allGroups []*Group
|
||||
|
||||
for rows.Next() {
|
||||
var g Group
|
||||
var icon sql.NullString
|
||||
if err := rows.Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan group: %w", err)
|
||||
}
|
||||
if icon.Valid {
|
||||
g.Icon = icon.String
|
||||
}
|
||||
groupMap[g.ID] = &g
|
||||
allGroups = append(allGroups, &g)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate groups: %w", err)
|
||||
}
|
||||
|
||||
// Build tree: attach children to parents, collect roots
|
||||
var roots []Group
|
||||
for _, g := range allGroups {
|
||||
if g.ParentID != nil {
|
||||
if parent, ok := groupMap[*g.ParentID]; ok {
|
||||
parent.Children = append(parent.Children, *g)
|
||||
}
|
||||
} else {
|
||||
roots = append(roots, *g)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-attach children to root copies (since we copied into roots)
|
||||
for i := range roots {
|
||||
if orig, ok := groupMap[roots[i].ID]; ok {
|
||||
roots[i].Children = orig.Children
|
||||
}
|
||||
}
|
||||
|
||||
if roots == nil {
|
||||
roots = []Group{}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
func (s *ConnectionService) DeleteGroup(id int64) error {
|
||||
_, err := s.db.Exec("DELETE FROM groups WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete group: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------- Connection CRUD ----------
|
||||
|
||||
func (s *ConnectionService) CreateConnection(input CreateConnectionInput) (*Connection, error) {
|
||||
tags, err := json.Marshal(input.Tags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal tags: %w", err)
|
||||
}
|
||||
if input.Tags == nil {
|
||||
tags = []byte("[]")
|
||||
}
|
||||
|
||||
options := input.Options
|
||||
if options == "" {
|
||||
options = "{}"
|
||||
}
|
||||
|
||||
result, err := s.db.Exec(
|
||||
`INSERT INTO connections (name, hostname, port, protocol, group_id, credential_id, color, tags, notes, options)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
input.Name, input.Hostname, input.Port, input.Protocol,
|
||||
input.GroupID, input.CredentialID, input.Color,
|
||||
string(tags), input.Notes, options,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create connection: %w", err)
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get connection id: %w", err)
|
||||
}
|
||||
return s.GetConnection(id)
|
||||
}
|
||||
|
||||
func (s *ConnectionService) GetConnection(id int64) (*Connection, error) {
|
||||
row := s.db.QueryRow(
|
||||
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
|
||||
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
|
||||
FROM connections WHERE id = ?`, id,
|
||||
)
|
||||
|
||||
var c Connection
|
||||
var tagsJSON string
|
||||
var color, notes, options sql.NullString
|
||||
var lastConnected sql.NullTime
|
||||
|
||||
err := row.Scan(
|
||||
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
|
||||
&c.GroupID, &c.CredentialID,
|
||||
&color, &tagsJSON, ¬es, &options,
|
||||
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get connection: %w", err)
|
||||
}
|
||||
|
||||
if color.Valid {
|
||||
c.Color = color.String
|
||||
}
|
||||
if notes.Valid {
|
||||
c.Notes = notes.String
|
||||
}
|
||||
if options.Valid {
|
||||
c.Options = options.String
|
||||
}
|
||||
if lastConnected.Valid {
|
||||
c.LastConnected = &lastConnected.Time
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
|
||||
c.Tags = []string{}
|
||||
}
|
||||
if c.Tags == nil {
|
||||
c.Tags = []string{}
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (s *ConnectionService) ListConnections() ([]Connection, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, name, hostname, port, protocol, group_id, credential_id,
|
||||
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
|
||||
FROM connections ORDER BY sort_order, name`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list connections: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanConnections(rows)
|
||||
}
|
||||
|
||||
// scanConnections is a shared helper used by ListConnections and (later) Search.
|
||||
func scanConnections(rows *sql.Rows) ([]Connection, error) {
|
||||
var conns []Connection
|
||||
|
||||
for rows.Next() {
|
||||
var c Connection
|
||||
var tagsJSON string
|
||||
var color, notes, options sql.NullString
|
||||
var lastConnected sql.NullTime
|
||||
|
||||
if err := rows.Scan(
|
||||
&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
|
||||
&c.GroupID, &c.CredentialID,
|
||||
&color, &tagsJSON, ¬es, &options,
|
||||
&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan connection: %w", err)
|
||||
}
|
||||
|
||||
if color.Valid {
|
||||
c.Color = color.String
|
||||
}
|
||||
if notes.Valid {
|
||||
c.Notes = notes.String
|
||||
}
|
||||
if options.Valid {
|
||||
c.Options = options.String
|
||||
}
|
||||
if lastConnected.Valid {
|
||||
c.LastConnected = &lastConnected.Time
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
|
||||
c.Tags = []string{}
|
||||
}
|
||||
if c.Tags == nil {
|
||||
c.Tags = []string{}
|
||||
}
|
||||
|
||||
conns = append(conns, c)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate connections: %w", err)
|
||||
}
|
||||
|
||||
if conns == nil {
|
||||
conns = []Connection{}
|
||||
}
|
||||
return conns, nil
|
||||
}
|
||||
|
||||
func (s *ConnectionService) UpdateConnection(id int64, input UpdateConnectionInput) (*Connection, error) {
|
||||
setClauses := []string{"updated_at = CURRENT_TIMESTAMP"}
|
||||
args := []interface{}{}
|
||||
|
||||
if input.Name != nil {
|
||||
setClauses = append(setClauses, "name = ?")
|
||||
args = append(args, *input.Name)
|
||||
}
|
||||
if input.Hostname != nil {
|
||||
setClauses = append(setClauses, "hostname = ?")
|
||||
args = append(args, *input.Hostname)
|
||||
}
|
||||
if input.Port != nil {
|
||||
setClauses = append(setClauses, "port = ?")
|
||||
args = append(args, *input.Port)
|
||||
}
|
||||
if input.GroupID != nil {
|
||||
setClauses = append(setClauses, "group_id = ?")
|
||||
args = append(args, *input.GroupID)
|
||||
}
|
||||
if input.CredentialID != nil {
|
||||
setClauses = append(setClauses, "credential_id = ?")
|
||||
args = append(args, *input.CredentialID)
|
||||
}
|
||||
if input.Tags != nil {
|
||||
tags, _ := json.Marshal(input.Tags)
|
||||
setClauses = append(setClauses, "tags = ?")
|
||||
args = append(args, string(tags))
|
||||
}
|
||||
if input.Notes != nil {
|
||||
setClauses = append(setClauses, "notes = ?")
|
||||
args = append(args, *input.Notes)
|
||||
}
|
||||
if input.Color != nil {
|
||||
setClauses = append(setClauses, "color = ?")
|
||||
args = append(args, *input.Color)
|
||||
}
|
||||
if input.Options != nil {
|
||||
setClauses = append(setClauses, "options = ?")
|
||||
args = append(args, *input.Options)
|
||||
}
|
||||
|
||||
args = append(args, id)
|
||||
query := fmt.Sprintf("UPDATE connections SET %s WHERE id = ?", strings.Join(setClauses, ", "))
|
||||
if _, err := s.db.Exec(query, args...); err != nil {
|
||||
return nil, fmt.Errorf("update connection: %w", err)
|
||||
}
|
||||
return s.GetConnection(id)
|
||||
}
|
||||
|
||||
func (s *ConnectionService) DeleteConnection(id int64) error {
|
||||
_, err := s.db.Exec("DELETE FROM connections WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete connection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
234
internal/connections/service_test.go
Normal file
234
internal/connections/service_test.go
Normal file
@ -0,0 +1,234 @@
|
||||
package connections
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/vstockwell/wraith/internal/db"
|
||||
)
|
||||
|
||||
func strPtr(s string) *string { return &s }
|
||||
|
||||
func setupTestService(t *testing.T) *ConnectionService {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("db.Open() error: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("db.Migrate() error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { database.Close() })
|
||||
return NewConnectionService(database)
|
||||
}
|
||||
|
||||
func TestCreateGroup(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
g, err := svc.CreateGroup("Servers", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateGroup() error: %v", err)
|
||||
}
|
||||
if g.ID == 0 {
|
||||
t.Error("expected non-zero ID")
|
||||
}
|
||||
if g.Name != "Servers" {
|
||||
t.Errorf("Name = %q, want %q", g.Name, "Servers")
|
||||
}
|
||||
if g.ParentID != nil {
|
||||
t.Errorf("ParentID = %v, want nil", g.ParentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSubGroup(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
parent, err := svc.CreateGroup("Servers", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateGroup(parent) error: %v", err)
|
||||
}
|
||||
|
||||
child, err := svc.CreateGroup("Production", &parent.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateGroup(child) error: %v", err)
|
||||
}
|
||||
if child.ParentID == nil {
|
||||
t.Fatal("expected non-nil ParentID")
|
||||
}
|
||||
if *child.ParentID != parent.ID {
|
||||
t.Errorf("ParentID = %d, want %d", *child.ParentID, parent.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListGroups(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
parent, err := svc.CreateGroup("Servers", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateGroup(parent) error: %v", err)
|
||||
}
|
||||
if _, err := svc.CreateGroup("Production", &parent.ID); err != nil {
|
||||
t.Fatalf("CreateGroup(child) error: %v", err)
|
||||
}
|
||||
|
||||
groups, err := svc.ListGroups()
|
||||
if err != nil {
|
||||
t.Fatalf("ListGroups() error: %v", err)
|
||||
}
|
||||
if len(groups) != 1 {
|
||||
t.Fatalf("len(groups) = %d, want 1 (only root groups)", len(groups))
|
||||
}
|
||||
if groups[0].Name != "Servers" {
|
||||
t.Errorf("groups[0].Name = %q, want %q", groups[0].Name, "Servers")
|
||||
}
|
||||
if len(groups[0].Children) != 1 {
|
||||
t.Fatalf("len(children) = %d, want 1", len(groups[0].Children))
|
||||
}
|
||||
if groups[0].Children[0].Name != "Production" {
|
||||
t.Errorf("child name = %q, want %q", groups[0].Children[0].Name, "Production")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteGroup(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
g, err := svc.CreateGroup("ToDelete", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateGroup() error: %v", err)
|
||||
}
|
||||
|
||||
if err := svc.DeleteGroup(g.ID); err != nil {
|
||||
t.Fatalf("DeleteGroup() error: %v", err)
|
||||
}
|
||||
|
||||
groups, err := svc.ListGroups()
|
||||
if err != nil {
|
||||
t.Fatalf("ListGroups() error: %v", err)
|
||||
}
|
||||
if len(groups) != 0 {
|
||||
t.Errorf("len(groups) = %d, want 0 after delete", len(groups))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConnection(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
conn, err := svc.CreateConnection(CreateConnectionInput{
|
||||
Name: "Web Server",
|
||||
Hostname: "10.0.0.1",
|
||||
Port: 22,
|
||||
Protocol: "ssh",
|
||||
Tags: []string{"Prod", "Linux"},
|
||||
Options: `{"keepAliveInterval": 60}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConnection() error: %v", err)
|
||||
}
|
||||
if conn.ID == 0 {
|
||||
t.Error("expected non-zero ID")
|
||||
}
|
||||
if conn.Name != "Web Server" {
|
||||
t.Errorf("Name = %q, want %q", conn.Name, "Web Server")
|
||||
}
|
||||
if len(conn.Tags) != 2 {
|
||||
t.Fatalf("len(Tags) = %d, want 2", len(conn.Tags))
|
||||
}
|
||||
if conn.Tags[0] != "Prod" || conn.Tags[1] != "Linux" {
|
||||
t.Errorf("Tags = %v, want [Prod Linux]", conn.Tags)
|
||||
}
|
||||
if conn.Options != `{"keepAliveInterval": 60}` {
|
||||
t.Errorf("Options = %q, want JSON blob", conn.Options)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConnections(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
if _, err := svc.CreateConnection(CreateConnectionInput{
|
||||
Name: "Server A",
|
||||
Hostname: "10.0.0.1",
|
||||
Port: 22,
|
||||
Protocol: "ssh",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateConnection(A) error: %v", err)
|
||||
}
|
||||
if _, err := svc.CreateConnection(CreateConnectionInput{
|
||||
Name: "Server B",
|
||||
Hostname: "10.0.0.2",
|
||||
Port: 3389,
|
||||
Protocol: "rdp",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateConnection(B) error: %v", err)
|
||||
}
|
||||
|
||||
conns, err := svc.ListConnections()
|
||||
if err != nil {
|
||||
t.Fatalf("ListConnections() error: %v", err)
|
||||
}
|
||||
if len(conns) != 2 {
|
||||
t.Fatalf("len(conns) = %d, want 2", len(conns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConnection(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
conn, err := svc.CreateConnection(CreateConnectionInput{
|
||||
Name: "Old Name",
|
||||
Hostname: "10.0.0.1",
|
||||
Port: 22,
|
||||
Protocol: "ssh",
|
||||
Tags: []string{"Dev"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConnection() error: %v", err)
|
||||
}
|
||||
|
||||
updated, err := svc.UpdateConnection(conn.ID, UpdateConnectionInput{
|
||||
Name: strPtr("New Name"),
|
||||
Tags: []string{"Prod", "Linux"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateConnection() error: %v", err)
|
||||
}
|
||||
if updated.Name != "New Name" {
|
||||
t.Errorf("Name = %q, want %q", updated.Name, "New Name")
|
||||
}
|
||||
if len(updated.Tags) != 2 {
|
||||
t.Fatalf("len(Tags) = %d, want 2", len(updated.Tags))
|
||||
}
|
||||
if updated.Tags[0] != "Prod" {
|
||||
t.Errorf("Tags[0] = %q, want %q", updated.Tags[0], "Prod")
|
||||
}
|
||||
// Hostname should remain unchanged
|
||||
if updated.Hostname != "10.0.0.1" {
|
||||
t.Errorf("Hostname = %q, want %q (unchanged)", updated.Hostname, "10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteConnection(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
|
||||
conn, err := svc.CreateConnection(CreateConnectionInput{
|
||||
Name: "ToDelete",
|
||||
Hostname: "10.0.0.1",
|
||||
Port: 22,
|
||||
Protocol: "ssh",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConnection() error: %v", err)
|
||||
}
|
||||
|
||||
if err := svc.DeleteConnection(conn.ID); err != nil {
|
||||
t.Fatalf("DeleteConnection() error: %v", err)
|
||||
}
|
||||
|
||||
conns, err := svc.ListConnections()
|
||||
if err != nil {
|
||||
t.Fatalf("ListConnections() error: %v", err)
|
||||
}
|
||||
if len(conns) != 0 {
|
||||
t.Errorf("len(conns) = %d, want 0 after delete", len(conns))
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user