wraith/internal/connections/service.go
Vantz Stockwell 8a096d7f7b
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Has been cancelled
Wraith v0.1.0 — Desktop SSH + RDP + SFTP Client
Go + Wails v3 + Vue 3 + SQLite + FreeRDP3 (purego)
183 tests, 76 source files, 9,910 lines of code

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 08:19:29 -04:00

362 lines
9.4 KiB
Go

package connections
import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
)
// Group is a row of the groups table. Groups nest through ParentID, and
// connections presumably point at one through Connection.GroupID.
type Group struct {
	ID        int64     `json:"id"`
	Name      string    `json:"name"`
	ParentID  *int64    `json:"parentId"` // nil for a top-level group
	SortOrder int       `json:"sortOrder"`
	Icon      string    `json:"icon"` // "" when the stored icon is NULL
	CreatedAt time.Time `json:"createdAt"`
	// Children is populated only by ListGroups; it stays nil (and is omitted
	// from JSON) for leaf groups and for the group returned by CreateGroup.
	Children []Group `json:"children,omitempty"`
}
// Connection is a saved host entry as read from the connections table by
// GetConnection and ListConnections.
type Connection struct {
	ID           int64    `json:"id"`
	Name         string   `json:"name"`
	Hostname     string   `json:"hostname"`
	Port         int      `json:"port"`
	Protocol     string   `json:"protocol"`     // not validated anywhere in this file
	GroupID      *int64   `json:"groupId"`      // nil when group_id is NULL
	CredentialID *int64   `json:"credentialId"` // nil when credential_id is NULL
	Color        string   `json:"color"`        // "" when NULL
	Tags         []string `json:"tags"`         // decoded from the JSON text in the tags column; never nil after a read
	Notes        string   `json:"notes"`        // "" when NULL
	// Options is stored and returned as raw text; CreateConnection writes "{}"
	// when none is given (presumably a JSON object — TODO confirm with callers).
	Options       string     `json:"options"`
	SortOrder     int        `json:"sortOrder"`
	LastConnected *time.Time `json:"lastConnected"` // nil when last_connected is NULL
	CreatedAt     time.Time  `json:"createdAt"`
	UpdatedAt     time.Time  `json:"updatedAt"`
}
// CreateConnectionInput carries the user-supplied fields for
// CreateConnection. ID, sort order and timestamps are not settable here.
type CreateConnectionInput struct {
	Name         string   `json:"name"`
	Hostname     string   `json:"hostname"`
	Port         int      `json:"port"`
	Protocol     string   `json:"protocol"`
	GroupID      *int64   `json:"groupId"`      // nil stores NULL
	CredentialID *int64   `json:"credentialId"` // nil stores NULL
	Color        string   `json:"color"`
	Tags         []string `json:"tags"` // nil is stored as "[]"
	Notes        string   `json:"notes"`
	Options      string   `json:"options"` // "" is stored as "{}"
}
// UpdateConnectionInput describes a partial update for UpdateConnection.
// Every field is optional: a nil pointer (or nil Tags slice) leaves the
// corresponding column unchanged; a non-nil value overwrites it.
//
// NOTE(review): because nil means "unchanged", GroupID and CredentialID
// cannot be cleared back to NULL through this type, and there is no
// Protocol field, so a connection's protocol cannot be changed — confirm
// both are intended.
type UpdateConnectionInput struct {
	Name         *string  `json:"name"`
	Hostname     *string  `json:"hostname"`
	Port         *int     `json:"port"`
	GroupID      *int64   `json:"groupId"`
	CredentialID *int64   `json:"credentialId"`
	Color        *string  `json:"color"`
	Tags         []string `json:"tags"` // a non-nil empty slice replaces the tags with "[]"
	Notes        *string  `json:"notes"`
	Options      *string  `json:"options"`
}
// ConnectionService implements create/read/update/delete operations for
// groups and connections on top of a SQL database handle.
type ConnectionService struct {
	db *sql.DB // shared handle; nothing in this file closes it
}
// NewConnectionService builds a ConnectionService that issues every query
// against db. Nothing in this file closes db, so its lifetime stays with
// the caller.
func NewConnectionService(db *sql.DB) *ConnectionService {
	svc := new(ConnectionService)
	svc.db = db
	return svc
}
// ---------- Group CRUD ----------
// CreateGroup inserts a group called name beneath parentID (nil for a
// top-level group) and returns the row as the database stored it.
//
// Only name and parent_id are written; the row is then re-read so that the
// columns the INSERT leaves to the database (id, sort_order, icon,
// created_at) are reported as stored. Children is left nil.
func (s *ConnectionService) CreateGroup(name string, parentID *int64) (*Group, error) {
	res, err := s.db.Exec("INSERT INTO groups (name, parent_id) VALUES (?, ?)", name, parentID)
	if err != nil {
		return nil, fmt.Errorf("create group: %w", err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("get group id: %w", err)
	}

	created := &Group{}
	var iconCol sql.NullString
	row := s.db.QueryRow(
		"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups WHERE id = ?", newID,
	)
	scanErr := row.Scan(&created.ID, &created.Name, &created.ParentID, &created.SortOrder, &iconCol, &created.CreatedAt)
	if scanErr != nil {
		return nil, fmt.Errorf("get created group: %w", scanErr)
	}

	// NullString.String is "" when the column was NULL, matching the zero Icon.
	created.Icon = iconCol.String
	return created, nil
}
// ListGroups returns every group arranged as a forest.
//
// The returned slice holds the top-level groups (nil ParentID); each group's
// Children holds its direct sub-groups, recursively, to any depth. Siblings
// keep the query order (sort_order, then name). A group whose parent_id
// matches no row is omitted. The result is never nil: with no groups it is
// an empty slice.
func (s *ConnectionService) ListGroups() ([]Group, error) {
	rows, err := s.db.Query(
		"SELECT id, name, parent_id, sort_order, icon, created_at FROM groups ORDER BY sort_order, name",
	)
	if err != nil {
		return nil, fmt.Errorf("list groups: %w", err)
	}
	defer rows.Close()

	var flat []Group
	for rows.Next() {
		var g Group
		var icon sql.NullString
		if err := rows.Scan(&g.ID, &g.Name, &g.ParentID, &g.SortOrder, &icon, &g.CreatedAt); err != nil {
			return nil, fmt.Errorf("scan group: %w", err)
		}
		if icon.Valid {
			g.Icon = icon.String
		}
		flat = append(flat, g)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate groups: %w", err)
	}

	return buildGroupTree(flat), nil
}

// buildGroupTree nests an ordered flat list of groups into a forest and
// returns the roots (never nil).
//
// Each node's Children slice is completed (recursively) before it is stored
// on its parent. This matters because Group values are copied: attaching a
// copy of a child first and adding grandchildren to the original afterwards
// would leave the copy without them. Recursion terminates because IDs are
// unique, so following children downward from a root can never revisit a node.
func buildGroupTree(flat []Group) []Group {
	childrenOf := make(map[int64][]Group)
	roots := []Group{}
	for _, g := range flat {
		if g.ParentID == nil {
			roots = append(roots, g)
			continue
		}
		childrenOf[*g.ParentID] = append(childrenOf[*g.ParentID], g)
	}

	var attach func(nodes []Group)
	attach = func(nodes []Group) {
		for i := range nodes {
			kids := childrenOf[nodes[i].ID]
			if len(kids) == 0 {
				continue // leave Children nil so it is omitted from JSON
			}
			attach(kids) // fill grandchildren in place before linking
			nodes[i].Children = kids
		}
	}
	attach(roots)
	return roots
}
// DeleteGroup removes the group with the given id.
//
// The affected-row count is not checked, so deleting an id that does not
// exist returns nil. What happens to sub-groups and connections that
// reference the group is decided by the schema (not visible here).
func (s *ConnectionService) DeleteGroup(id int64) error {
	if _, execErr := s.db.Exec("DELETE FROM groups WHERE id = ?", id); execErr != nil {
		return fmt.Errorf("delete group: %w", execErr)
	}
	return nil
}
// ---------- Connection CRUD ----------
// CreateConnection inserts a connection built from input and returns the
// stored row as read back by GetConnection.
//
// A nil input.Tags is stored as the JSON text "[]", and an empty
// input.Options is stored as "{}". Sort order, last-connected and the
// timestamps are left for the database to fill.
func (s *ConnectionService) CreateConnection(input CreateConnectionInput) (*Connection, error) {
	tagsJSON := []byte("[]")
	if input.Tags != nil {
		encoded, err := json.Marshal(input.Tags)
		if err != nil {
			return nil, fmt.Errorf("marshal tags: %w", err)
		}
		tagsJSON = encoded
	}

	optionsText := "{}"
	if input.Options != "" {
		optionsText = input.Options
	}

	res, err := s.db.Exec(
		`INSERT INTO connections (name, hostname, port, protocol, group_id, credential_id, color, tags, notes, options)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		input.Name, input.Hostname, input.Port, input.Protocol,
		input.GroupID, input.CredentialID, input.Color,
		string(tagsJSON), input.Notes, optionsText,
	)
	if err != nil {
		return nil, fmt.Errorf("create connection: %w", err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("get connection id: %w", err)
	}
	return s.GetConnection(newID)
}
// GetConnection loads the connection with the given id.
//
// A missing id yields an error wrapping sql.ErrNoRows (check with
// errors.Is). Field decoding is shared with ListConnections via
// scanConnectionRow.
func (s *ConnectionService) GetConnection(id int64) (*Connection, error) {
	row := s.db.QueryRow(
		`SELECT id, name, hostname, port, protocol, group_id, credential_id,
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
FROM connections WHERE id = ?`, id,
	)
	c, err := scanConnectionRow(row)
	if err != nil {
		return nil, fmt.Errorf("get connection: %w", err)
	}
	return &c, nil
}

// ListConnections returns all connections ordered by sort_order, then name.
// The result is never nil; with no rows it is an empty slice.
func (s *ConnectionService) ListConnections() ([]Connection, error) {
	rows, err := s.db.Query(
		`SELECT id, name, hostname, port, protocol, group_id, credential_id,
color, tags, notes, options, sort_order, last_connected, created_at, updated_at
FROM connections ORDER BY sort_order, name`,
	)
	if err != nil {
		return nil, fmt.Errorf("list connections: %w", err)
	}
	defer rows.Close()
	return scanConnections(rows)
}

// scanConnections drains rows into a slice of Connections. It is shared by
// ListConnections and any later query (e.g. search) selecting the same
// columns in the same order. The caller remains responsible for closing
// rows. The result is never nil on success.
func scanConnections(rows *sql.Rows) ([]Connection, error) {
	conns := []Connection{}
	for rows.Next() {
		c, err := scanConnectionRow(rows)
		if err != nil {
			return nil, fmt.Errorf("scan connection: %w", err)
		}
		conns = append(conns, c)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate connections: %w", err)
	}
	return conns, nil
}

// connRowScanner is the Scan method shared by *sql.Row and *sql.Rows, so one
// decoder serves both single-row and multi-row queries.
type connRowScanner interface {
	Scan(dest ...interface{}) error
}

// scanConnectionRow decodes one row whose columns are, in order:
// id, name, hostname, port, protocol, group_id, credential_id, color, tags,
// notes, options, sort_order, last_connected, created_at, updated_at.
//
// NULL color/notes/options become "". tags is read as nullable text so a
// NULL column no longer aborts the read; NULL, malformed JSON, or JSON null
// all yield an empty non-nil slice. The Scan error is returned unwrapped so
// each caller can add its own context.
func scanConnectionRow(sc connRowScanner) (Connection, error) {
	var c Connection
	var color, tags, notes, options sql.NullString
	var lastConnected sql.NullTime
	if err := sc.Scan(
		&c.ID, &c.Name, &c.Hostname, &c.Port, &c.Protocol,
		&c.GroupID, &c.CredentialID,
		&color, &tags, &notes, &options,
		&c.SortOrder, &lastConnected, &c.CreatedAt, &c.UpdatedAt,
	); err != nil {
		return Connection{}, err
	}

	// NullString.String is "" when the column was NULL.
	c.Color = color.String
	c.Notes = notes.String
	c.Options = options.String
	if lastConnected.Valid {
		c.LastConnected = &lastConnected.Time
	}

	// Best-effort: an unreadable tags value must not make the row unreadable.
	if !tags.Valid || json.Unmarshal([]byte(tags.String), &c.Tags) != nil || c.Tags == nil {
		c.Tags = []string{}
	}
	return c, nil
}
// UpdateConnection applies a partial update to connection id and returns
// the resulting row via GetConnection.
//
// Only the fields that are non-nil in input are written. updated_at is
// always set to CURRENT_TIMESTAMP, so an input with every field nil still
// touches the row. The statement is assembled from fixed column-name
// literals and every value is bound as a placeholder, so no caller-supplied
// text is spliced into the SQL.
//
// The UPDATE's affected-row count is not checked: for an unknown id the
// error comes from the follow-up GetConnection (wrapping sql.ErrNoRows).
//
// NOTE(review): nil means "leave unchanged", so GroupID and CredentialID
// cannot be reset to NULL here, and Protocol cannot be updated at all.
func (s *ConnectionService) UpdateConnection(id int64, input UpdateConnectionInput) (*Connection, error) {
	setClauses := []string{"updated_at = CURRENT_TIMESTAMP"}
	var args []interface{}

	// set records one "column = ?" assignment and its bound value.
	set := func(column string, value interface{}) {
		setClauses = append(setClauses, column+" = ?")
		args = append(args, value)
	}

	if input.Name != nil {
		set("name", *input.Name)
	}
	if input.Hostname != nil {
		set("hostname", *input.Hostname)
	}
	if input.Port != nil {
		set("port", *input.Port)
	}
	if input.GroupID != nil {
		set("group_id", *input.GroupID)
	}
	if input.CredentialID != nil {
		set("credential_id", *input.CredentialID)
	}
	if input.Tags != nil {
		tags, err := json.Marshal(input.Tags)
		if err != nil {
			return nil, fmt.Errorf("marshal tags: %w", err)
		}
		set("tags", string(tags))
	}
	if input.Notes != nil {
		set("notes", *input.Notes)
	}
	if input.Color != nil {
		set("color", *input.Color)
	}
	if input.Options != nil {
		set("options", *input.Options)
	}

	args = append(args, id)
	query := fmt.Sprintf("UPDATE connections SET %s WHERE id = ?", strings.Join(setClauses, ", "))
	if _, err := s.db.Exec(query, args...); err != nil {
		return nil, fmt.Errorf("update connection: %w", err)
	}
	return s.GetConnection(id)
}
// DeleteConnection removes the connection with the given id.
//
// The affected-row count is not checked, so deleting an id that does not
// exist returns nil.
func (s *ConnectionService) DeleteConnection(id int64) error {
	if _, execErr := s.db.Exec("DELETE FROM connections WHERE id = ?", id); execErr != nil {
		return fmt.Errorf("delete connection: %w", execErr)
	}
	return nil
}