diff --git a/internal/ai/client.go b/internal/ai/client.go new file mode 100644 index 0000000..316960b --- /dev/null +++ b/internal/ai/client.go @@ -0,0 +1,256 @@ +package ai + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" +) + +const ( + defaultModel = "claude-sonnet-4-20250514" + anthropicVersion = "2023-06-01" + apiBaseURL = "https://api.anthropic.com/v1/messages" +) + +// ClaudeClient is the HTTP client for the Anthropic Messages API. +type ClaudeClient struct { + oauth *OAuthManager + apiKey string // fallback for non-OAuth users + model string + httpClient *http.Client +} + +// NewClaudeClient creates a client that authenticates via OAuth (primary) or API key (fallback). +func NewClaudeClient(oauth *OAuthManager, model string) *ClaudeClient { + if model == "" { + model = defaultModel + } + return &ClaudeClient{ + oauth: oauth, + model: model, + httpClient: &http.Client{}, + } +} + +// SetAPIKey sets a fallback API key for users without Max subscriptions. +func (c *ClaudeClient) SetAPIKey(key string) { + c.apiKey = key +} + +// apiRequest is the JSON body sent to the Messages API. +type apiRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream"` + System string `json:"system,omitempty"` + Tools []Tool `json:"tools,omitempty"` +} + +// SendMessage sends a request to the Messages API with streaming enabled. +// It returns a channel of StreamEvents that the caller reads until the channel is closed. +func (c *ClaudeClient) SendMessage(messages []Message, tools []Tool, systemPrompt string) (<-chan StreamEvent, error) { + reqBody := apiRequest{ + Model: c.model, + Messages: messages, + MaxTokens: 8192, + Stream: true, + System: systemPrompt, + Tools: tools, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + req, err := http.NewRequest("POST", apiBaseURL, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("anthropic-version", anthropicVersion) + + if err := c.setAuthHeader(req); err != nil { + return nil, fmt.Errorf("set auth header: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, fmt.Errorf("API returned %d: %s", resp.StatusCode, string(body)) + } + + ch := make(chan StreamEvent, 64) + go c.parseSSEStream(resp.Body, ch) + return ch, nil +} + +// setAuthHeader sets the appropriate authorization header on the request. +func (c *ClaudeClient) setAuthHeader(req *http.Request) error { + // Try OAuth first + if c.oauth != nil && c.oauth.IsAuthenticated() { + token, err := c.oauth.GetAccessToken() + if err == nil { + req.Header.Set("Authorization", "Bearer "+token) + return nil + } + slog.Warn("oauth token failed, falling back to api key", "error", err) + } + + // Fallback to API key + if c.apiKey != "" { + req.Header.Set("x-api-key", c.apiKey) + return nil + } + + return fmt.Errorf("no authentication method available — log in via OAuth or set an API key") +} + +// parseSSEStream reads the SSE response body and sends StreamEvents on the channel. +func (c *ClaudeClient) parseSSEStream(body io.ReadCloser, ch chan<- StreamEvent) { + defer body.Close() + defer close(ch) + + scanner := bufio.NewScanner(body) + + var currentEventType string + var currentToolID string + var currentToolName string + + for scanner.Scan() { + line := scanner.Text() + + if line == "" { + continue + } + + if strings.HasPrefix(line, "event: ") { + currentEventType = strings.TrimPrefix(line, "event: ") + continue + } + + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + switch currentEventType { + case "content_block_start": + var block struct { + Index int `json:"index"` + ContentBlock struct { + Type string `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + Text string `json:"text"` + Input string `json:"input"` + } `json:"content_block"` + } + if err := json.Unmarshal([]byte(data), &block); err != nil { + slog.Warn("failed to parse content_block_start", "error", err) + continue + } + + if block.ContentBlock.Type == "tool_use" { + currentToolID = block.ContentBlock.ID + currentToolName = block.ContentBlock.Name + ch <- StreamEvent{ + Type: "tool_use_start", + ToolID: currentToolID, + ToolName: currentToolName, + } + } else if block.ContentBlock.Type == "text" && block.ContentBlock.Text != "" { + ch <- StreamEvent{Type: "text_delta", Data: block.ContentBlock.Text} + } + + case "content_block_delta": + var delta struct { + Delta struct { + Type string `json:"type"` + Text string `json:"text"` + PartialJSON string `json:"partial_json"` + } `json:"delta"` + } + if err := json.Unmarshal([]byte(data), &delta); err != nil { + slog.Warn("failed to parse content_block_delta", "error", err) + continue + } + + if delta.Delta.Type == "text_delta" { + ch <- StreamEvent{Type: "text_delta", Data: delta.Delta.Text} + } else if delta.Delta.Type == "input_json_delta" { + ch <- StreamEvent{ + Type: "tool_use_delta", + Data: delta.Delta.PartialJSON, + ToolID: currentToolID, + ToolName: currentToolName, + } + } + + case "content_block_stop": + // nothing to do; tool input is accumulated by the caller + + case "message_delta": + // contains stop_reason and usage; emit usage info + ch <- StreamEvent{Type: "done", Data: data} + + case "message_stop": + // end of message stream + return + + case "message_start": + // contains message metadata; could extract usage but we handle it at message_delta + + case "ping": + // heartbeat; ignore + + case "error": + ch <- StreamEvent{Type: "error", Data: data} + return + } + } + + if err := scanner.Err(); err != nil { + ch <- StreamEvent{Type: "error", Data: err.Error()} + } +} + +// BuildRequestBody creates the JSON request body for testing purposes. +func BuildRequestBody(messages []Message, tools []Tool, systemPrompt, model string) ([]byte, error) { + if model == "" { + model = defaultModel + } + req := apiRequest{ + Model: model, + Messages: messages, + MaxTokens: 8192, + Stream: true, + System: systemPrompt, + Tools: tools, + } + return json.Marshal(req) +} + +// ParseSSELine parses a single SSE data line. Returns the event type and data payload. +func ParseSSELine(line string) (eventType, data string) { + if strings.HasPrefix(line, "event: ") { + return "event", strings.TrimPrefix(line, "event: ") + } + if strings.HasPrefix(line, "data: ") { + return "data", strings.TrimPrefix(line, "data: ") + } + return "", "" +} diff --git a/internal/ai/client_test.go b/internal/ai/client_test.go new file mode 100644 index 0000000..d28ec4c --- /dev/null +++ b/internal/ai/client_test.go @@ -0,0 +1,123 @@ +package ai + +import ( + "encoding/json" + "net/http" + "testing" +) + +func TestBuildRequestBody(t *testing.T) { + messages := []Message{ + { + Role: "user", + Content: []ContentBlock{ + {Type: "text", Text: "Hello"}, + }, + }, + } + tools := []Tool{ + { + Name: "test_tool", + Description: "A test tool", + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), + }, + } + + body, err := BuildRequestBody(messages, tools, "You are helpful.", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var parsed map[string]interface{} + if err := json.Unmarshal(body, &parsed); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + // Check model + if m, ok := parsed["model"].(string); !ok || m != defaultModel { + t.Errorf("expected model %s, got %v", defaultModel, parsed["model"]) + } + + // Check stream + if s, ok := parsed["stream"].(bool); !ok || !s { + t.Errorf("expected stream true, got %v", parsed["stream"]) + } + + // Check system prompt + if s, ok := parsed["system"].(string); !ok || s != "You are helpful." { + t.Errorf("expected system prompt, got %v", parsed["system"]) + } + + // Check max_tokens + if mt, ok := parsed["max_tokens"].(float64); !ok || int(mt) != 8192 { + t.Errorf("expected max_tokens 8192, got %v", parsed["max_tokens"]) + } + + // Check messages array exists + msgs, ok := parsed["messages"].([]interface{}) + if !ok || len(msgs) != 1 { + t.Errorf("expected 1 message, got %v", parsed["messages"]) + } + + // Check tools array exists + tls, ok := parsed["tools"].([]interface{}) + if !ok || len(tls) != 1 { + t.Errorf("expected 1 tool, got %v", parsed["tools"]) + } +} + +func TestParseSSELine(t *testing.T) { + tests := []struct { + input string + wantType string + wantData string + }{ + {"event: content_block_delta", "event", "content_block_delta"}, + {"data: {\"delta\":{\"text\":\"hello\"}}", "data", "{\"delta\":{\"text\":\"hello\"}}"}, + {"", "", ""}, + {"random line", "", ""}, + {"event: message_stop", "event", "message_stop"}, + {"data: [DONE]", "data", "[DONE]"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + gotType, gotData := ParseSSELine(tt.input) + if gotType != tt.wantType { + t.Errorf("ParseSSELine(%q) type = %q, want %q", tt.input, gotType, tt.wantType) + } + if gotData != tt.wantData { + t.Errorf("ParseSSELine(%q) data = %q, want %q", tt.input, gotData, tt.wantData) + } + }) + } +} + +func TestAuthHeader(t *testing.T) { + // Test with API key only (no OAuth) + client := NewClaudeClient(nil, "") + client.SetAPIKey("sk-test-12345") + + req, _ := http.NewRequest("POST", "https://example.com", nil) + if err := client.setAuthHeader(req); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := req.Header.Get("x-api-key"); got != "sk-test-12345" { + t.Errorf("expected x-api-key sk-test-12345, got %q", got) + } + if got := req.Header.Get("Authorization"); got != "" { + t.Errorf("expected no Authorization header, got %q", got) + } +} + +func TestAuthHeaderNoAuth(t *testing.T) { + // No OAuth, no API key — should return error + client := NewClaudeClient(nil, "") + + req, _ := http.NewRequest("POST", "https://example.com", nil) + err := client.setAuthHeader(req) + if err == nil { + t.Error("expected error when no auth method available") + } +} diff --git a/internal/ai/conversation.go b/internal/ai/conversation.go new file mode 100644 index 0000000..dea0249 --- /dev/null +++ b/internal/ai/conversation.go @@ -0,0 +1,169 @@ +package ai + +import ( + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" +) + +// ConversationManager handles CRUD operations for AI conversations stored in SQLite. +type ConversationManager struct { + db *sql.DB +} + +// NewConversationManager creates a manager backed by the given database. +func NewConversationManager(db *sql.DB) *ConversationManager { + return &ConversationManager{db: db} +} + +// Create starts a new conversation and returns its summary. +func (m *ConversationManager) Create(model string) (*ConversationSummary, error) { + id := uuid.NewString() + now := time.Now() + + _, err := m.db.Exec( + `INSERT INTO conversations (id, title, model, messages, tokens_in, tokens_out, created_at, updated_at) + VALUES (?, ?, ?, '[]', 0, 0, ?, ?)`, + id, "New conversation", model, now, now, + ) + if err != nil { + return nil, fmt.Errorf("create conversation: %w", err) + } + + return &ConversationSummary{ + ID: id, + Title: "New conversation", + Model: model, + CreatedAt: now, + TokensIn: 0, + TokensOut: 0, + }, nil +} + +// AddMessage appends a message to the conversation's message list. +func (m *ConversationManager) AddMessage(convId string, msg Message) error { + // Get existing messages + messages, err := m.GetMessages(convId) + if err != nil { + return err + } + + messages = append(messages, msg) + data, err := json.Marshal(messages) + if err != nil { + return fmt.Errorf("marshal messages: %w", err) + } + + _, err = m.db.Exec( + "UPDATE conversations SET messages = ?, updated_at = ? WHERE id = ?", + string(data), time.Now(), convId, + ) + if err != nil { + return fmt.Errorf("update messages: %w", err) + } + + // Auto-title from first user message + if len(messages) == 1 && msg.Role == "user" { + title := extractTitle(msg) + if title != "" { + m.db.Exec("UPDATE conversations SET title = ? WHERE id = ?", title, convId) + } + } + + return nil +} + +// GetMessages returns all messages in a conversation. +func (m *ConversationManager) GetMessages(convId string) ([]Message, error) { + var messagesJSON string + err := m.db.QueryRow("SELECT messages FROM conversations WHERE id = ?", convId).Scan(&messagesJSON) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("conversation %s not found", convId) + } + if err != nil { + return nil, fmt.Errorf("get messages: %w", err) + } + + var messages []Message + if err := json.Unmarshal([]byte(messagesJSON), &messages); err != nil { + return nil, fmt.Errorf("unmarshal messages: %w", err) + } + return messages, nil +} + +// List returns all conversations ordered by most recent. +func (m *ConversationManager) List() ([]ConversationSummary, error) { + rows, err := m.db.Query( + `SELECT id, title, model, tokens_in, tokens_out, created_at + FROM conversations ORDER BY updated_at DESC`, + ) + if err != nil { + return nil, fmt.Errorf("list conversations: %w", err) + } + defer rows.Close() + + var summaries []ConversationSummary + for rows.Next() { + var s ConversationSummary + var title sql.NullString + if err := rows.Scan(&s.ID, &title, &s.Model, &s.TokensIn, &s.TokensOut, &s.CreatedAt); err != nil { + return nil, fmt.Errorf("scan conversation: %w", err) + } + if title.Valid { + s.Title = title.String + } + summaries = append(summaries, s) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate conversations: %w", err) + } + + if summaries == nil { + summaries = []ConversationSummary{} + } + return summaries, nil +} + +// Delete removes a conversation and all its messages. +func (m *ConversationManager) Delete(convId string) error { + result, err := m.db.Exec("DELETE FROM conversations WHERE id = ?", convId) + if err != nil { + return fmt.Errorf("delete conversation: %w", err) + } + affected, _ := result.RowsAffected() + if affected == 0 { + return fmt.Errorf("conversation %s not found", convId) + } + return nil +} + +// UpdateTokenUsage adds to the token counters for a conversation. +func (m *ConversationManager) UpdateTokenUsage(convId string, tokensIn, tokensOut int) error { + _, err := m.db.Exec( + `UPDATE conversations + SET tokens_in = tokens_in + ?, tokens_out = tokens_out + ?, updated_at = ? + WHERE id = ?`, + tokensIn, tokensOut, time.Now(), convId, + ) + if err != nil { + return fmt.Errorf("update token usage: %w", err) + } + return nil +} + +// extractTitle generates a title from the first user message (truncated to 80 chars). +func extractTitle(msg Message) string { + for _, block := range msg.Content { + if block.Type == "text" && block.Text != "" { + title := block.Text + if len(title) > 80 { + title = title[:77] + "..." + } + return title + } + } + return "" +} diff --git a/internal/ai/conversation_test.go b/internal/ai/conversation_test.go new file mode 100644 index 0000000..bc433cc --- /dev/null +++ b/internal/ai/conversation_test.go @@ -0,0 +1,220 @@ +package ai + +import ( + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" +) + +func setupConversationManager(t *testing.T) *ConversationManager { + t.Helper() + database, err := db.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatalf("open db: %v", err) + } + if err := db.Migrate(database); err != nil { + t.Fatalf("migrate: %v", err) + } + // Create the conversations table (002 migration) + _, err = database.Exec(`CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, title TEXT, model TEXT NOT NULL, + messages TEXT NOT NULL DEFAULT '[]', + tokens_in INTEGER DEFAULT 0, tokens_out INTEGER DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP)`) + if err != nil { + t.Fatalf("create conversations table: %v", err) + } + t.Cleanup(func() { database.Close() }) + return NewConversationManager(database) +} + +func TestCreateConversation(t *testing.T) { + mgr := setupConversationManager(t) + + conv, err := mgr.Create("claude-sonnet-4-20250514") + if err != nil { + t.Fatalf("create: %v", err) + } + + if conv.ID == "" { + t.Error("expected non-empty ID") + } + if conv.Model != "claude-sonnet-4-20250514" { + t.Errorf("expected model claude-sonnet-4-20250514, got %s", conv.Model) + } + if conv.Title != "New conversation" { + t.Errorf("expected title 'New conversation', got %s", conv.Title) + } + if conv.TokensIn != 0 || conv.TokensOut != 0 { + t.Errorf("expected zero tokens, got in=%d out=%d", conv.TokensIn, conv.TokensOut) + } +} + +func TestAddAndGetMessages(t *testing.T) { + mgr := setupConversationManager(t) + + conv, err := mgr.Create("test-model") + if err != nil { + t.Fatalf("create: %v", err) + } + + // Add a user message + userMsg := Message{ + Role: "user", + Content: []ContentBlock{ + {Type: "text", Text: "What is running on port 8080?"}, + }, + } + if err := mgr.AddMessage(conv.ID, userMsg); err != nil { + t.Fatalf("add user message: %v", err) + } + + // Add an assistant message + assistantMsg := Message{ + Role: "assistant", + Content: []ContentBlock{ + {Type: "text", Text: "Let me check that for you."}, + }, + } + if err := mgr.AddMessage(conv.ID, assistantMsg); err != nil { + t.Fatalf("add assistant message: %v", err) + } + + // Get messages + messages, err := mgr.GetMessages(conv.ID) + if err != nil { + t.Fatalf("get messages: %v", err) + } + if len(messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(messages)) + } + if messages[0].Role != "user" { + t.Errorf("expected first message role 'user', got %s", messages[0].Role) + } + if messages[1].Role != "assistant" { + t.Errorf("expected second message role 'assistant', got %s", messages[1].Role) + } + if messages[0].Content[0].Text != "What is running on port 8080?" { + t.Errorf("unexpected message text: %s", messages[0].Content[0].Text) + } +} + +func TestListConversations(t *testing.T) { + mgr := setupConversationManager(t) + + // Create multiple conversations + _, err := mgr.Create("model-a") + if err != nil { + t.Fatalf("create 1: %v", err) + } + _, err = mgr.Create("model-b") + if err != nil { + t.Fatalf("create 2: %v", err) + } + + list, err := mgr.List() + if err != nil { + t.Fatalf("list: %v", err) + } + if len(list) != 2 { + t.Errorf("expected 2 conversations, got %d", len(list)) + } +} + +func TestDeleteConversation(t *testing.T) { + mgr := setupConversationManager(t) + + conv, err := mgr.Create("test-model") + if err != nil { + t.Fatalf("create: %v", err) + } + + if err := mgr.Delete(conv.ID); err != nil { + t.Fatalf("delete: %v", err) + } + + // Verify it's gone + list, err := mgr.List() + if err != nil { + t.Fatalf("list: %v", err) + } + if len(list) != 0 { + t.Errorf("expected 0 conversations after delete, got %d", len(list)) + } + + // Delete non-existent should error + if err := mgr.Delete("nonexistent"); err == nil { + t.Error("expected error deleting non-existent conversation") + } +} + +func TestTokenUsageTracking(t *testing.T) { + mgr := setupConversationManager(t) + + conv, err := mgr.Create("test-model") + if err != nil { + t.Fatalf("create: %v", err) + } + + // Update token usage multiple times + if err := mgr.UpdateTokenUsage(conv.ID, 100, 50); err != nil { + t.Fatalf("update tokens 1: %v", err) + } + if err := mgr.UpdateTokenUsage(conv.ID, 200, 100); err != nil { + t.Fatalf("update tokens 2: %v", err) + } + + // Verify totals + list, err := mgr.List() + if err != nil { + t.Fatalf("list: %v", err) + } + if len(list) != 1 { + t.Fatalf("expected 1 conversation, got %d", len(list)) + } + if list[0].TokensIn != 300 { + t.Errorf("expected 300 tokens in, got %d", list[0].TokensIn) + } + if list[0].TokensOut != 150 { + t.Errorf("expected 150 tokens out, got %d", list[0].TokensOut) + } +} + +func TestGetMessagesNonExistent(t *testing.T) { + mgr := setupConversationManager(t) + + _, err := mgr.GetMessages("nonexistent-id") + if err == nil { + t.Error("expected error for non-existent conversation") + } +} + +func TestAutoTitle(t *testing.T) { + mgr := setupConversationManager(t) + + conv, err := mgr.Create("test-model") + if err != nil { + t.Fatalf("create: %v", err) + } + + msg := Message{ + Role: "user", + Content: []ContentBlock{ + {Type: "text", Text: "Check disk usage on server-01"}, + }, + } + if err := mgr.AddMessage(conv.ID, msg); err != nil { + t.Fatalf("add message: %v", err) + } + + // Verify the title was auto-set + list, err := mgr.List() + if err != nil { + t.Fatalf("list: %v", err) + } + if list[0].Title != "Check disk usage on server-01" { + t.Errorf("expected auto-title, got %q", list[0].Title) + } +} diff --git a/internal/ai/oauth.go b/internal/ai/oauth.go new file mode 100644 index 0000000..75eb7be --- /dev/null +++ b/internal/ai/oauth.go @@ -0,0 +1,371 @@ +package ai + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "strconv" + "sync" + "time" + + "github.com/vstockwell/wraith/internal/settings" + "github.com/vstockwell/wraith/internal/vault" +) + +const ( + oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + oauthAuthorizeURL = "https://claude.ai/oauth/authorize" + oauthTokenURL = "https://platform.claude.com/v1/oauth/token" +) + +// OAuthManager handles OAuth PKCE authentication for Claude Max subscriptions. +type OAuthManager struct { + settings *settings.SettingsService + vault *vault.VaultService + clientID string + authorizeURL string + tokenURL string + mu sync.RWMutex +} + +// tokenResponse is the JSON body returned by the token endpoint. +type tokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` +} + +// NewOAuthManager creates an OAuthManager wired to settings and vault services. +func NewOAuthManager(s *settings.SettingsService, v *vault.VaultService) *OAuthManager { + return &OAuthManager{ + settings: s, + vault: v, + clientID: oauthClientID, + authorizeURL: oauthAuthorizeURL, + tokenURL: oauthTokenURL, + } +} + +// StartLogin begins the OAuth PKCE flow. It starts a local HTTP server, +// opens the authorization URL in the user's browser, and returns a channel +// that receives nil on success or an error on failure. +func (o *OAuthManager) StartLogin(openURL func(string) error) (<-chan error, error) { + verifier, err := generateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("generate code verifier: %w", err) + } + challenge := generateCodeChallenge(verifier) + state, err := generateState() + if err != nil { + return nil, fmt.Errorf("generate state: %w", err) + } + + // Start a local HTTP server on a random port for the callback + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("listen for callback: %w", err) + } + port := listener.Addr().(*net.TCPAddr).Port + redirectURI := fmt.Sprintf("http://127.0.0.1:%d/callback", port) + + done := make(chan error, 1) + + mux := http.NewServeMux() + server := &http.Server{Handler: mux} + + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + defer func() { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + server.Shutdown(ctx) + }() + }() + + // Verify state + if r.URL.Query().Get("state") != state { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, "State mismatch — please try logging in again.") + done <- fmt.Errorf("state mismatch") + return + } + + if errParam := r.URL.Query().Get("error"); errParam != "" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "Authorization error: %s", errParam) + done <- fmt.Errorf("authorization error: %s", errParam) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, "No authorization code received.") + done <- fmt.Errorf("no authorization code") + return + } + + // Exchange code for tokens + if err := o.exchangeCode(code, verifier, redirectURI); err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, "Failed to exchange authorization code.") + done <- fmt.Errorf("exchange code: %w", err) + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `

Authenticated!

You can close this window and return to Wraith.

`) + done <- nil + }) + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + slog.Error("oauth callback server error", "error", err) + done <- fmt.Errorf("callback server: %w", err) + } + }() + + // Build the authorize URL + params := url.Values{ + "response_type": {"code"}, + "client_id": {o.clientID}, + "redirect_uri": {redirectURI}, + "scope": {"user:inference"}, + "state": {state}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + } + authURL := o.authorizeURL + "?" + params.Encode() + + if openURL != nil { + if err := openURL(authURL); err != nil { + listener.Close() + return nil, fmt.Errorf("open browser: %w", err) + } + } + + return done, nil +} + +// exchangeCode exchanges an authorization code for access and refresh tokens. +func (o *OAuthManager) exchangeCode(code, verifier, redirectURI string) error { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {o.clientID}, + "code_verifier": {verifier}, + } + + resp, err := http.PostForm(o.tokenURL, data) + if err != nil { + return fmt.Errorf("post token request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body)) + } + + var tokens tokenResponse + if err := json.Unmarshal(body, &tokens); err != nil { + return fmt.Errorf("unmarshal token response: %w", err) + } + + return o.storeTokens(tokens) +} + +// storeTokens encrypts and persists OAuth tokens in settings. +func (o *OAuthManager) storeTokens(tokens tokenResponse) error { + o.mu.Lock() + defer o.mu.Unlock() + + if o.vault == nil { + return fmt.Errorf("vault is not unlocked — cannot store tokens") + } + + encAccess, err := o.vault.Encrypt(tokens.AccessToken) + if err != nil { + return fmt.Errorf("encrypt access token: %w", err) + } + if err := o.settings.Set("ai_access_token", encAccess); err != nil { + return fmt.Errorf("store access token: %w", err) + } + + encRefresh, err := o.vault.Encrypt(tokens.RefreshToken) + if err != nil { + return fmt.Errorf("encrypt refresh token: %w", err) + } + if err := o.settings.Set("ai_refresh_token", encRefresh); err != nil { + return fmt.Errorf("store refresh token: %w", err) + } + + expiresAt := time.Now().Add(time.Duration(tokens.ExpiresIn) * time.Second).Unix() + if err := o.settings.Set("ai_token_expires_at", strconv.FormatInt(expiresAt, 10)); err != nil { + return fmt.Errorf("store token expiry: %w", err) + } + + slog.Info("oauth tokens stored successfully") + return nil +} + +// GetAccessToken returns a valid access token, refreshing if expired. +func (o *OAuthManager) GetAccessToken() (string, error) { + o.mu.RLock() + vlt := o.vault + o.mu.RUnlock() + + if vlt == nil { + return "", fmt.Errorf("vault is not unlocked") + } + + // Check expiry + expiresStr, _ := o.settings.Get("ai_token_expires_at") + if expiresStr != "" { + expiresAt, _ := strconv.ParseInt(expiresStr, 10, 64) + if time.Now().Unix() >= expiresAt { + // Token expired, try to refresh + if err := o.refreshToken(); err != nil { + return "", fmt.Errorf("refresh expired token: %w", err) + } + } + } + + encAccess, err := o.settings.Get("ai_access_token") + if err != nil || encAccess == "" { + return "", fmt.Errorf("no access token stored") + } + + accessToken, err := vlt.Decrypt(encAccess) + if err != nil { + return "", fmt.Errorf("decrypt access token: %w", err) + } + + return accessToken, nil +} + +// refreshToken uses the stored refresh token to obtain new tokens. +func (o *OAuthManager) refreshToken() error { + o.mu.RLock() + vlt := o.vault + o.mu.RUnlock() + + if vlt == nil { + return fmt.Errorf("vault is not unlocked") + } + + encRefresh, err := o.settings.Get("ai_refresh_token") + if err != nil || encRefresh == "" { + return fmt.Errorf("no refresh token stored") + } + + refreshTok, err := vlt.Decrypt(encRefresh) + if err != nil { + return fmt.Errorf("decrypt refresh token: %w", err) + } + + data := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshTok}, + "client_id": {o.clientID}, + } + + resp, err := http.PostForm(o.tokenURL, data) + if err != nil { + return fmt.Errorf("post refresh request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body)) + } + + var tokens tokenResponse + if err := json.Unmarshal(body, &tokens); err != nil { + return fmt.Errorf("unmarshal refresh response: %w", err) + } + + // If the refresh endpoint doesn't return a new refresh token, keep the old one + if tokens.RefreshToken == "" { + tokens.RefreshToken = refreshTok + } + + return o.storeTokens(tokens) +} + +// IsAuthenticated checks if we have stored tokens. +func (o *OAuthManager) IsAuthenticated() bool { + encAccess, err := o.settings.Get("ai_access_token") + return err == nil && encAccess != "" +} + +// Logout clears all stored OAuth tokens. +func (o *OAuthManager) Logout() error { + for _, key := range []string{"ai_access_token", "ai_refresh_token", "ai_token_expires_at"} { + if err := o.settings.Delete(key); err != nil { + return fmt.Errorf("delete %s: %w", key, err) + } + } + return nil +} + +// generateCodeVerifier creates a PKCE code verifier (32 random bytes, base64url no padding). +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// generateCodeChallenge creates a PKCE S256 code challenge from a verifier. +func generateCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +// generateState creates a random state parameter (32 random bytes, base64url no padding). +func generateState() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// SetVault updates the vault reference (called after vault unlock). +func (o *OAuthManager) SetVault(v *vault.VaultService) { + o.mu.Lock() + defer o.mu.Unlock() + o.vault = v +} + +// isBase64URL checks if a string contains only base64url characters (no padding). +func isBase64URL(s string) bool { + for _, c := range s { + if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_') { + return false + } + } + return true +} + diff --git a/internal/ai/oauth_test.go b/internal/ai/oauth_test.go new file mode 100644 index 0000000..25aff48 --- /dev/null +++ b/internal/ai/oauth_test.go @@ -0,0 +1,91 @@ +package ai + +import ( + "crypto/sha256" + "encoding/base64" + "path/filepath" + "testing" + + "github.com/vstockwell/wraith/internal/db" + "github.com/vstockwell/wraith/internal/settings" + "github.com/vstockwell/wraith/internal/vault" +) + +func TestGenerateCodeVerifier(t *testing.T) { + v, err := generateCodeVerifier() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 32 bytes -> 43 base64url chars (no padding) + if len(v) != 43 { + t.Errorf("expected verifier length 43, got %d", len(v)) + } + + if !isBase64URL(v) { + t.Errorf("verifier contains non-base64url characters: %s", v) + } + + // Should be different each time + v2, _ := generateCodeVerifier() + if v == v2 { + t.Error("two generated verifiers should not be identical") + } +} + +func TestGenerateCodeChallenge(t *testing.T) { + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + challenge := generateCodeChallenge(verifier) + + // Verify it matches a manually computed S256 hash + h := sha256.Sum256([]byte(verifier)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + + if challenge != expected { + t.Errorf("expected challenge %s, got %s", expected, challenge) + } + + // Same verifier should produce the same challenge (deterministic) + challenge2 := generateCodeChallenge(verifier) + if challenge != challenge2 { + t.Error("challenge should be deterministic for the same verifier") + } +} + +func TestGenerateState(t *testing.T) { + s, err := generateState() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 32 bytes -> 43 base64url chars (no padding) + if len(s) != 43 { + t.Errorf("expected state length 43, got %d", len(s)) + } + + if !isBase64URL(s) { + t.Errorf("state contains non-base64url characters: %s", s) + } +} + +func TestIsAuthenticatedWhenNoTokens(t *testing.T) { + database, err := db.Open(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatalf("open db: %v", err) + } + if err := db.Migrate(database); err != nil { + t.Fatalf("migrate: %v", err) + } + t.Cleanup(func() { database.Close() }) + + settingsSvc := settings.NewSettingsService(database) + key := vault.DeriveKey("test-password", []byte("test-salt-1234567890123456789012")) + vaultSvc := vault.NewVaultService(key) + + oauth := NewOAuthManager(settingsSvc, vaultSvc) + + if oauth.IsAuthenticated() { + t.Error("expected IsAuthenticated to return false when no tokens are stored") + } +} diff --git a/internal/ai/router.go b/internal/ai/router.go new file mode 100644 index 0000000..acac332 --- /dev/null +++ b/internal/ai/router.go @@ -0,0 +1,564 @@ +package ai + +import ( + "encoding/json" + "fmt" +) + +// ToolRouter dispatches tool calls to the appropriate service. +// Services are stored as interface{} to avoid circular imports between packages. +type ToolRouter struct { + ssh interface{} // *ssh.SSHService + sftp interface{} // *sftp.SFTPService + rdp interface{} // *rdp.RDPService + sessions interface{} // *session.Manager + connections interface{} // *connections.ConnectionService + aiService interface{} // *AIService — for terminal buffer access +} + +// NewToolRouter creates an empty ToolRouter. Call SetServices to wire in backends. +func NewToolRouter() *ToolRouter { + return &ToolRouter{} +} + +// SetServices wires the router to actual service implementations. +func (r *ToolRouter) SetServices(ssh, sftp, rdp, sessions, connections interface{}) { + r.ssh = ssh + r.sftp = sftp + r.rdp = rdp + r.sessions = sessions + r.connections = connections +} + +// SetAIService wires the router to the AI service for terminal buffer access. +func (r *ToolRouter) SetAIService(aiService interface{}) { + r.aiService = aiService +} + +// sshWriter is the interface we need from SSHService for terminal_write. +type sshWriter interface { + Write(sessionID string, data string) error +} + +// sshSessionLister is the interface for listing SSH sessions. +type sshSessionLister interface { + ListSessions() interface{} +} + +// sftpLister is the interface for SFTP list operations. +type sftpLister interface { + List(sessionID string, path string) (interface{}, error) +} + +// sftpReader is the interface for SFTP read operations. +type sftpReader interface { + ReadFile(sessionID string, path string) (string, error) +} + +// sftpWriter is the interface for SFTP write operations. +type sftpWriter interface { + WriteFile(sessionID string, path string, content string) error +} + +// rdpFrameGetter is the interface for getting RDP screenshots. +type rdpFrameGetter interface { + GetFrame(sessionID string) ([]byte, error) + GetSessionInfo(sessionID string) (interface{}, error) +} + +// rdpMouseSender is the interface for RDP mouse events. +type rdpMouseSender interface { + SendMouse(sessionID string, x, y int, flags uint32) error +} + +// rdpKeySender is the interface for RDP key events. +type rdpKeySender interface { + SendKey(sessionID string, scancode uint32, pressed bool) error +} + +// rdpClipboardSender is the interface for RDP clipboard. +type rdpClipboardSender interface { + SendClipboard(sessionID string, data string) error +} + +// rdpSessionLister is the interface for listing RDP sessions. +type rdpSessionLister interface { + ListSessions() interface{} +} + +// sessionLister is the interface for listing sessions. +type sessionLister interface { + List() interface{} +} + +// bufferProvider provides terminal output buffers. +type bufferProvider interface { + GetBuffer(sessionId string) *TerminalBuffer +} + +// Dispatch routes a tool call to the appropriate handler. +func (r *ToolRouter) Dispatch(toolName string, input json.RawMessage) (interface{}, error) { + switch toolName { + case "terminal_write": + return r.handleTerminalWrite(input) + case "terminal_read": + return r.handleTerminalRead(input) + case "terminal_cwd": + return r.handleTerminalCwd(input) + case "sftp_list": + return r.handleSFTPList(input) + case "sftp_read": + return r.handleSFTPRead(input) + case "sftp_write": + return r.handleSFTPWrite(input) + case "rdp_screenshot": + return r.handleRDPScreenshot(input) + case "rdp_click": + return r.handleRDPClick(input) + case "rdp_doubleclick": + return r.handleRDPDoubleClick(input) + case "rdp_type": + return r.handleRDPType(input) + case "rdp_keypress": + return r.handleRDPKeypress(input) + case "rdp_scroll": + return r.handleRDPScroll(input) + case "rdp_move": + return r.handleRDPMove(input) + case "list_sessions": + return r.handleListSessions(input) + case "connect_ssh": + return r.handleConnectSSH(input) + case "disconnect": + return r.handleDisconnect(input) + default: + return nil, fmt.Errorf("unknown tool: %s", toolName) + } +} + +func (r *ToolRouter) handleTerminalWrite(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + Text string `json:"text"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + w, ok := r.ssh.(sshWriter) + if !ok || r.ssh == nil { + return nil, fmt.Errorf("SSH service not available") + } + + if err := w.Write(params.SessionID, params.Text); err != nil { + return nil, err + } + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleTerminalRead(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + Lines int `json:"lines"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + if params.Lines <= 0 { + params.Lines = 50 + } + + bp, ok := r.aiService.(bufferProvider) + if !ok || r.aiService == nil { + return nil, fmt.Errorf("terminal buffer not available") + } + + buf := bp.GetBuffer(params.SessionID) + lines := buf.ReadLast(params.Lines) + return map[string]interface{}{ + "lines": lines, + "count": len(lines), + }, nil +} + +func (r *ToolRouter) handleTerminalCwd(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + // terminal_cwd works by writing "pwd" to the terminal and reading output + w, ok := r.ssh.(sshWriter) + if !ok || r.ssh == nil { + return nil, fmt.Errorf("SSH service not available") + } + + if err := w.Write(params.SessionID, "pwd\n"); err != nil { + return nil, fmt.Errorf("send pwd: %w", err) + } + + return map[string]string{ + "status": "pwd command sent — read terminal output for result", + }, nil +} + +func (r *ToolRouter) handleSFTPList(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + l, ok := r.sftp.(sftpLister) + if !ok || r.sftp == nil { + return nil, fmt.Errorf("SFTP service not available") + } + + return l.List(params.SessionID, params.Path) +} + +func (r *ToolRouter) handleSFTPRead(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + reader, ok := r.sftp.(sftpReader) + if !ok || r.sftp == nil { + return nil, fmt.Errorf("SFTP service not available") + } + + content, err := reader.ReadFile(params.SessionID, params.Path) + if err != nil { + return nil, err + } + return map[string]string{"content": content}, nil +} + +func (r *ToolRouter) handleSFTPWrite(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + Content string `json:"content"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + w, ok := r.sftp.(sftpWriter) + if !ok || r.sftp == nil { + return nil, fmt.Errorf("SFTP service not available") + } + + if err := w.WriteFile(params.SessionID, params.Path, params.Content); err != nil { + return nil, err + } + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleRDPScreenshot(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + fg, ok := r.rdp.(rdpFrameGetter) + if !ok || r.rdp == nil { + return nil, fmt.Errorf("RDP service not available") + } + + frame, err := fg.GetFrame(params.SessionID) + if err != nil { + return nil, err + } + + // Get session info for dimensions + info, err := fg.GetSessionInfo(params.SessionID) + if err != nil { + return nil, err + } + + // Try to get dimensions from session config + type configGetter interface { + GetConfig() (int, int) + } + width, height := 1920, 1080 + if cg, ok := info.(configGetter); ok { + width, height = cg.GetConfig() + } + + // Encode as JPEG + jpeg, err := EncodeScreenshot(frame, width, height, 1280, 720, 75) + if err != nil { + return nil, fmt.Errorf("encode screenshot: %w", err) + } + + return map[string]interface{}{ + "image": jpeg, + "width": width, + "height": height, + }, nil +} + +func (r *ToolRouter) handleRDPClick(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + X int `json:"x"` + Y int `json:"y"` + Button string `json:"button"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + ms, ok := r.rdp.(rdpMouseSender) + if !ok || r.rdp == nil { + return nil, fmt.Errorf("RDP service not available") + } + + var buttonFlag uint32 = 0x1000 // left + switch params.Button { + case "right": + buttonFlag = 0x2000 + case "middle": + buttonFlag = 0x4000 + } + + // Press + if err := ms.SendMouse(params.SessionID, params.X, params.Y, buttonFlag|0x8000); err != nil { + return nil, err + } + // Release + if err := ms.SendMouse(params.SessionID, params.X, params.Y, buttonFlag); err != nil { + return nil, err + } + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleRDPDoubleClick(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + X int `json:"x"` + Y int `json:"y"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + ms, ok := r.rdp.(rdpMouseSender) + if !ok || r.rdp == nil { + return nil, fmt.Errorf("RDP service not available") + } + + // Two clicks + for i := 0; i < 2; i++ { + if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x1000|0x8000); err != nil { + return nil, err + } + if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x1000); err != nil { + return nil, err + } + } + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleRDPType(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + Text string `json:"text"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + cs, ok := r.rdp.(rdpClipboardSender) + if !ok || r.rdp == nil { + return nil, fmt.Errorf("RDP service not available") + } + + // Send text via clipboard, then simulate Ctrl+V + if err := cs.SendClipboard(params.SessionID, params.Text); err != nil { + return nil, err + } + + ks, ok := r.rdp.(rdpKeySender) + if !ok { + return nil, fmt.Errorf("RDP key service not available") + } + + // Ctrl down, V down, V up, Ctrl up + if err := ks.SendKey(params.SessionID, 0x001D, true); err != nil { + return nil, err + } + if err := ks.SendKey(params.SessionID, 0x002F, true); err != nil { + return nil, err + } + if err := ks.SendKey(params.SessionID, 0x002F, false); err != nil { + return nil, err + } + if err := ks.SendKey(params.SessionID, 0x001D, false); err != nil { + return nil, err + } + + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleRDPKeypress(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + Key string `json:"key"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + ks, ok := r.rdp.(rdpKeySender) + if !ok || r.rdp == nil { + return nil, fmt.Errorf("RDP service not available") + } + + // Simple key name to scancode mapping for common keys + keyMap := map[string]uint32{ + "Enter": 0x001C, + "Tab": 0x000F, + "Escape": 0x0001, + "Backspace": 0x000E, + "Delete": 0xE053, + "Space": 0x0039, + "Up": 0xE048, + "Down": 0xE050, + "Left": 0xE04B, + "Right": 0xE04D, + } + + scancode, ok := keyMap[params.Key] + if !ok { + return nil, fmt.Errorf("unknown key: %s", params.Key) + } + + if err := ks.SendKey(params.SessionID, scancode, true); err != nil { + return nil, err + } + if err := ks.SendKey(params.SessionID, scancode, false); err != nil { + return nil, err + } + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleRDPScroll(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + X int `json:"x"` + Y int `json:"y"` + Direction string `json:"direction"` + Clicks int `json:"clicks"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + if params.Clicks <= 0 { + params.Clicks = 3 + } + + ms, ok := r.rdp.(rdpMouseSender) + if !ok || r.rdp == nil { + return nil, fmt.Errorf("RDP service not available") + } + + var flags uint32 = 0x0200 // wheel flag + if params.Direction == "down" { + flags |= 0x0100 // negative flag + } + + for i := 0; i < params.Clicks; i++ { + if err := ms.SendMouse(params.SessionID, params.X, params.Y, flags); err != nil { + return nil, err + } + } + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleRDPMove(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + X int `json:"x"` + Y int `json:"y"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + ms, ok := r.rdp.(rdpMouseSender) + if !ok || r.rdp == nil { + return nil, fmt.Errorf("RDP service not available") + } + + if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x0800); err != nil { + return nil, err + } + return map[string]string{"status": "ok"}, nil +} + +func (r *ToolRouter) handleListSessions(_ json.RawMessage) (interface{}, error) { + result := map[string]interface{}{ + "ssh": []interface{}{}, + "rdp": []interface{}{}, + } + + if r.sessions != nil { + if sl, ok := r.sessions.(sessionLister); ok { + result["all"] = sl.List() + } + } + + return result, nil +} + +func (r *ToolRouter) handleConnectSSH(input json.RawMessage) (interface{}, error) { + var params struct { + ConnectionID int64 `json:"connectionId"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + // This will be wired to the app-level connect logic later + return nil, fmt.Errorf("connect_ssh requires app-level wiring — not yet available via tool dispatch") +} + +func (r *ToolRouter) handleDisconnect(input json.RawMessage) (interface{}, error) { + var params struct { + SessionID string `json:"sessionId"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return nil, fmt.Errorf("parse input: %w", err) + } + + // Try SSH first + type disconnecter interface { + Disconnect(sessionID string) error + } + + if d, ok := r.ssh.(disconnecter); ok { + if err := d.Disconnect(params.SessionID); err == nil { + return map[string]string{"status": "disconnected", "protocol": "ssh"}, nil + } + } + + if d, ok := r.rdp.(disconnecter); ok { + if err := d.Disconnect(params.SessionID); err == nil { + return map[string]string{"status": "disconnected", "protocol": "rdp"}, nil + } + } + + return nil, fmt.Errorf("session %s not found in SSH or RDP", params.SessionID) +} diff --git a/internal/ai/router_test.go b/internal/ai/router_test.go new file mode 100644 index 0000000..d28d84b --- /dev/null +++ b/internal/ai/router_test.go @@ -0,0 +1,85 @@ +package ai + +import ( + "encoding/json" + "testing" +) + +func TestDispatchUnknownTool(t *testing.T) { + router := NewToolRouter() + + _, err := router.Dispatch("nonexistent_tool", json.RawMessage(`{}`)) + if err == nil { + t.Error("expected error for unknown tool") + } + if err.Error() != "unknown tool: nonexistent_tool" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestDispatchListSessions(t *testing.T) { + router := NewToolRouter() + + // With nil services, list_sessions should return empty result + result, err := router.Dispatch("list_sessions", json.RawMessage(`{}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + m, ok := result.(map[string]interface{}) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + + ssh, ok := m["ssh"].([]interface{}) + if !ok { + t.Fatal("expected ssh key in result") + } + if len(ssh) != 0 { + t.Errorf("expected empty ssh list, got %d items", len(ssh)) + } + + rdp, ok := m["rdp"].([]interface{}) + if !ok { + t.Fatal("expected rdp key in result") + } + if len(rdp) != 0 { + t.Errorf("expected empty rdp list, got %d items", len(rdp)) + } +} + +func TestDispatchTerminalWriteNoService(t *testing.T) { + router := NewToolRouter() + + _, err := router.Dispatch("terminal_write", json.RawMessage(`{"sessionId":"abc","text":"ls\n"}`)) + if err == nil { + t.Error("expected error when SSH service is nil") + } +} + +func TestDispatchTerminalReadNoService(t *testing.T) { + router := NewToolRouter() + + _, err := router.Dispatch("terminal_read", json.RawMessage(`{"sessionId":"abc"}`)) + if err == nil { + t.Error("expected error when AI service is nil") + } +} + +func TestDispatchSFTPListNoService(t *testing.T) { + router := NewToolRouter() + + _, err := router.Dispatch("sftp_list", json.RawMessage(`{"sessionId":"abc","path":"/"}`)) + if err == nil { + t.Error("expected error when SFTP service is nil") + } +} + +func TestDispatchDisconnectNoService(t *testing.T) { + router := NewToolRouter() + + _, err := router.Dispatch("disconnect", json.RawMessage(`{"sessionId":"abc"}`)) + if err == nil { + t.Error("expected error when no services available") + } +} diff --git a/internal/ai/screenshot.go b/internal/ai/screenshot.go new file mode 100644 index 0000000..e5e84e7 --- /dev/null +++ b/internal/ai/screenshot.go @@ -0,0 +1,79 @@ +package ai + +import ( + "bytes" + "fmt" + "image" + "image/jpeg" +) + +// EncodeScreenshot converts raw RGBA pixel data to a JPEG image. +// If the source dimensions exceed maxWidth x maxHeight, the image is +// downscaled using nearest-neighbor sampling (fast, no external deps). +// Returns the JPEG bytes. +func EncodeScreenshot(rgba []byte, srcWidth, srcHeight, maxWidth, maxHeight, quality int) ([]byte, error) { + expectedLen := srcWidth * srcHeight * 4 + if len(rgba) < expectedLen { + return nil, fmt.Errorf("RGBA buffer too small: got %d bytes, expected %d for %dx%d", len(rgba), expectedLen, srcWidth, srcHeight) + } + + if quality <= 0 || quality > 100 { + quality = 75 + } + + // Create source image from RGBA buffer + src := image.NewRGBA(image.Rect(0, 0, srcWidth, srcHeight)) + copy(src.Pix, rgba[:expectedLen]) + + // Determine output dimensions + dstWidth, dstHeight := srcWidth, srcHeight + if srcWidth > maxWidth || srcHeight > maxHeight { + dstWidth, dstHeight = fitDimensions(srcWidth, srcHeight, maxWidth, maxHeight) + } + + var img image.Image = src + + // Downscale if needed using nearest-neighbor sampling + if dstWidth != srcWidth || dstHeight != srcHeight { + dst := image.NewRGBA(image.Rect(0, 0, dstWidth, dstHeight)) + for y := 0; y < dstHeight; y++ { + srcY := y * srcHeight / dstHeight + for x := 0; x < dstWidth; x++ { + srcX := x * srcWidth / dstWidth + srcIdx := (srcY*srcWidth + srcX) * 4 + dstIdx := (y*dstWidth + x) * 4 + dst.Pix[dstIdx+0] = src.Pix[srcIdx+0] // R + dst.Pix[dstIdx+1] = src.Pix[srcIdx+1] // G + dst.Pix[dstIdx+2] = src.Pix[srcIdx+2] // B + dst.Pix[dstIdx+3] = src.Pix[srcIdx+3] // A + } + } + img = dst + } + + // Encode to JPEG + var buf bytes.Buffer + if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil { + return nil, fmt.Errorf("encode JPEG: %w", err) + } + + return buf.Bytes(), nil +} + +// fitDimensions calculates the largest dimensions that fit within max bounds +// while preserving aspect ratio. +func fitDimensions(srcW, srcH, maxW, maxH int) (int, int) { + ratio := float64(srcW) / float64(srcH) + w, h := maxW, int(float64(maxW)/ratio) + if h > maxH { + h = maxH + w = int(float64(maxH) * ratio) + } + if w <= 0 { + w = 1 + } + if h <= 0 { + h = 1 + } + return w, h +} diff --git a/internal/ai/screenshot_test.go b/internal/ai/screenshot_test.go new file mode 100644 index 0000000..d6e71c9 --- /dev/null +++ b/internal/ai/screenshot_test.go @@ -0,0 +1,118 @@ +package ai + +import ( + "bytes" + "image/jpeg" + "testing" +) + +func TestEncodeScreenshot(t *testing.T) { + width, height := 100, 80 + + // Create a test RGBA buffer (red pixels) + rgba := make([]byte, width*height*4) + for i := 0; i < len(rgba); i += 4 { + rgba[i+0] = 255 // R + rgba[i+1] = 0 // G + rgba[i+2] = 0 // B + rgba[i+3] = 255 // A + } + + jpegData, err := EncodeScreenshot(rgba, width, height, 200, 200, 80) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check JPEG magic bytes (FF D8) + if len(jpegData) < 2 { + t.Fatal("JPEG data too small") + } + if jpegData[0] != 0xFF || jpegData[1] != 0xD8 { + t.Errorf("expected JPEG magic bytes FF D8, got %02X %02X", jpegData[0], jpegData[1]) + } + + // Decode and verify dimensions (no downscale needed in this case) + img, err := jpeg.Decode(bytes.NewReader(jpegData)) + if err != nil { + t.Fatalf("decode JPEG: %v", err) + } + bounds := img.Bounds() + if bounds.Dx() != width || bounds.Dy() != height { + t.Errorf("expected %dx%d, got %dx%d", width, height, bounds.Dx(), bounds.Dy()) + } +} + +func TestEncodeScreenshotDownscale(t *testing.T) { + srcWidth, srcHeight := 1920, 1080 + maxWidth, maxHeight := 1280, 720 + + // Create a test RGBA buffer (gradient) + rgba := make([]byte, srcWidth*srcHeight*4) + for y := 0; y < srcHeight; y++ { + for x := 0; x < srcWidth; x++ { + idx := (y*srcWidth + x) * 4 + rgba[idx+0] = byte(x % 256) // R + rgba[idx+1] = byte(y % 256) // G + rgba[idx+2] = byte((x + y) % 256) // B + rgba[idx+3] = 255 // A + } + } + + jpegData, err := EncodeScreenshot(rgba, srcWidth, srcHeight, maxWidth, maxHeight, 75) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check JPEG magic bytes + if jpegData[0] != 0xFF || jpegData[1] != 0xD8 { + t.Errorf("expected JPEG magic bytes FF D8, got %02X %02X", jpegData[0], jpegData[1]) + } + + // Decode and verify dimensions are within max bounds + img, err := jpeg.Decode(bytes.NewReader(jpegData)) + if err != nil { + t.Fatalf("decode JPEG: %v", err) + } + bounds := img.Bounds() + + if bounds.Dx() > maxWidth { + t.Errorf("output width %d exceeds max %d", bounds.Dx(), maxWidth) + } + if bounds.Dy() > maxHeight { + t.Errorf("output height %d exceeds max %d", bounds.Dy(), maxHeight) + } + + // Should maintain 16:9 ratio approximately + expectedWidth := 1280 + expectedHeight := 720 + if bounds.Dx() != expectedWidth || bounds.Dy() != expectedHeight { + t.Errorf("expected %dx%d, got %dx%d", expectedWidth, expectedHeight, bounds.Dx(), bounds.Dy()) + } +} + +func TestEncodeScreenshotBufferTooSmall(t *testing.T) { + _, err := EncodeScreenshot([]byte{0, 0, 0, 0}, 100, 100, 200, 200, 75) + if err == nil { + t.Error("expected error for buffer too small") + } +} + +func TestFitDimensions(t *testing.T) { + tests := []struct { + srcW, srcH, maxW, maxH int + wantW, wantH int + }{ + {1920, 1080, 1280, 720, 1280, 720}, // 16:9 fits exactly + {1920, 1080, 800, 600, 800, 450}, // width-constrained + {1080, 1920, 800, 600, 337, 600}, // height-constrained (portrait) + {100, 100, 200, 200, 200, 200}, // smaller than max (but func called with > check) + } + + for _, tt := range tests { + w, h := fitDimensions(tt.srcW, tt.srcH, tt.maxW, tt.maxH) + if w != tt.wantW || h != tt.wantH { + t.Errorf("fitDimensions(%d,%d,%d,%d) = %d,%d, want %d,%d", + tt.srcW, tt.srcH, tt.maxW, tt.maxH, w, h, tt.wantW, tt.wantH) + } + } +} diff --git a/internal/ai/service.go b/internal/ai/service.go new file mode 100644 index 0000000..1c18b22 --- /dev/null +++ b/internal/ai/service.go @@ -0,0 +1,268 @@ +package ai + +import ( + "encoding/json" + "fmt" + "log/slog" + "sync" +) + +// SystemPrompt is the system prompt given to Claude for copilot interactions. +const SystemPrompt = `You are the XO (Executive Officer) aboard the Wraith command station. The Commander (human operator) works alongside you managing remote servers and workstations. + +You have direct access to all active sessions through your tools: +- SSH terminals: read output, type commands, navigate filesystems +- SFTP: read and write remote files +- RDP desktops: see the screen, click, type, interact with any GUI application +- Session management: open new connections, close sessions + +When given a task: +1. Assess what sessions and access you need +2. Execute efficiently — don't ask for permission to use tools, just use them +3. Report what you found or did, with relevant details +4. If something fails, diagnose and try an alternative approach + +You are not an assistant answering questions. You are an operator executing missions. Act decisively. Use your tools. Report results.` + +// AIService is the main AI copilot service exposed to the Wails frontend. +type AIService struct { + oauth *OAuthManager + client *ClaudeClient + router *ToolRouter + conversations *ConversationManager + buffers map[string]*TerminalBuffer + mu sync.RWMutex +} + +// NewAIService creates the AI service with all sub-components. +func NewAIService(oauth *OAuthManager, router *ToolRouter, convMgr *ConversationManager) *AIService { + client := NewClaudeClient(oauth, "") + return &AIService{ + oauth: oauth, + client: client, + router: router, + conversations: convMgr, + buffers: make(map[string]*TerminalBuffer), + } +} + +// --- Auth --- + +// StartLogin begins the OAuth PKCE flow, opening the browser for authentication. +func (s *AIService) StartLogin() error { + done, err := s.oauth.StartLogin(nil) // nil openURL = no browser auto-open + if err != nil { + return err + } + // Wait for callback in a goroutine to avoid blocking the UI + go func() { + if err := <-done; err != nil { + slog.Error("oauth login failed", "error", err) + } else { + slog.Info("oauth login completed") + } + }() + return nil +} + +// IsAuthenticated returns whether the user has valid OAuth tokens. +func (s *AIService) IsAuthenticated() bool { + return s.oauth.IsAuthenticated() +} + +// Logout clears stored OAuth tokens. +func (s *AIService) Logout() error { + return s.oauth.Logout() +} + +// --- Conversations --- + +// NewConversation creates a new AI conversation and returns its ID. +func (s *AIService) NewConversation() (string, error) { + conv, err := s.conversations.Create(s.client.model) + if err != nil { + return "", err + } + return conv.ID, nil +} + +// ListConversations returns all conversations. +func (s *AIService) ListConversations() ([]ConversationSummary, error) { + return s.conversations.List() +} + +// DeleteConversation removes a conversation. +func (s *AIService) DeleteConversation(id string) error { + return s.conversations.Delete(id) +} + +// --- Chat --- + +// SendMessage sends a user message in a conversation and processes the AI response. +// Tool calls are automatically dispatched and results fed back to the model. +// This method blocks until the full response (including any tool use loops) is complete. +func (s *AIService) SendMessage(conversationId, text string) error { + // Add user message to conversation + userMsg := Message{ + Role: "user", + Content: []ContentBlock{ + {Type: "text", Text: text}, + }, + } + if err := s.conversations.AddMessage(conversationId, userMsg); err != nil { + return fmt.Errorf("store user message: %w", err) + } + + // Run the message loop (handles tool use) + return s.messageLoop(conversationId) +} + +// messageLoop sends the conversation to Claude and handles tool use loops. +func (s *AIService) messageLoop(conversationId string) error { + for iterations := 0; iterations < 20; iterations++ { // safety limit + messages, err := s.conversations.GetMessages(conversationId) + if err != nil { + return err + } + + ch, err := s.client.SendMessage(messages, CopilotTools, SystemPrompt) + if err != nil { + return fmt.Errorf("send to claude: %w", err) + } + + // Collect the response + var textParts []string + var toolCalls []ContentBlock + var currentToolInput string + + for event := range ch { + switch event.Type { + case "text_delta": + textParts = append(textParts, event.Data) + case "tool_use_start": + currentToolInput = "" + case "tool_use_delta": + currentToolInput += event.Data + case "done": + // Parse usage if available + var delta struct { + Usage Usage `json:"usage"` + } + if json.Unmarshal([]byte(event.Data), &delta) == nil && delta.Usage.OutputTokens > 0 { + s.conversations.UpdateTokenUsage(conversationId, delta.Usage.InputTokens, delta.Usage.OutputTokens) + } + case "error": + return fmt.Errorf("stream error: %s", event.Data) + } + + // When a tool_use block completes, we need to check content_block_stop + // But since we accumulate, we'll finalize after the stream ends + _ = toolCalls // kept for the final assembly below + } + + // Build the assistant message + var assistantContent []ContentBlock + if len(textParts) > 0 { + fullText := "" + for _, p := range textParts { + fullText += p + } + if fullText != "" { + assistantContent = append(assistantContent, ContentBlock{ + Type: "text", + Text: fullText, + }) + } + } + for _, tc := range toolCalls { + assistantContent = append(assistantContent, tc) + } + + if len(assistantContent) == 0 { + return nil // empty response + } + + // Store assistant message + assistantMsg := Message{ + Role: "assistant", + Content: assistantContent, + } + if err := s.conversations.AddMessage(conversationId, assistantMsg); err != nil { + return fmt.Errorf("store assistant message: %w", err) + } + + // If there were tool calls, dispatch them and continue the loop + hasToolUse := false + for _, block := range assistantContent { + if block.Type == "tool_use" { + hasToolUse = true + break + } + } + + if !hasToolUse { + return nil // done, no tool use to process + } + + // Dispatch tool calls and create tool_result message + var toolResults []ContentBlock + for _, block := range assistantContent { + if block.Type != "tool_use" { + continue + } + + result, err := s.router.Dispatch(block.Name, block.Input) + resultBlock := ContentBlock{ + Type: "tool_result", + ToolUseID: block.ID, + } + + if err != nil { + resultBlock.IsError = true + resultBlock.Content = []ContentBlock{ + {Type: "text", Text: err.Error()}, + } + } else { + resultJSON, _ := json.Marshal(result) + resultBlock.Content = []ContentBlock{ + {Type: "text", Text: string(resultJSON)}, + } + } + toolResults = append(toolResults, resultBlock) + } + + toolResultMsg := Message{ + Role: "user", + Content: toolResults, + } + if err := s.conversations.AddMessage(conversationId, toolResultMsg); err != nil { + return fmt.Errorf("store tool results: %w", err) + } + + // Continue the loop to let Claude process tool results + } + + return fmt.Errorf("exceeded maximum tool use iterations") +} + +// --- Terminal Buffer Management --- + +// GetBuffer returns the terminal output buffer for a session, creating it if needed. +func (s *AIService) GetBuffer(sessionId string) *TerminalBuffer { + s.mu.Lock() + defer s.mu.Unlock() + + buf, ok := s.buffers[sessionId] + if !ok { + buf = NewTerminalBuffer(200) + s.buffers[sessionId] = buf + } + return buf +} + +// RemoveBuffer removes the terminal buffer for a session. +func (s *AIService) RemoveBuffer(sessionId string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.buffers, sessionId) +} diff --git a/internal/ai/terminal_buffer.go b/internal/ai/terminal_buffer.go new file mode 100644 index 0000000..a315054 --- /dev/null +++ b/internal/ai/terminal_buffer.go @@ -0,0 +1,101 @@ +package ai + +import ( + "strings" + "sync" +) + +// TerminalBuffer is a thread-safe ring buffer that captures terminal output lines. +// It is written to by SSH read loops and read by the AI tool dispatch for terminal_read. +type TerminalBuffer struct { + lines []string + mu sync.RWMutex + max int + partial string // accumulates data that doesn't end with \n +} + +// NewTerminalBuffer creates a buffer that retains at most maxLines lines. +func NewTerminalBuffer(maxLines int) *TerminalBuffer { + if maxLines <= 0 { + maxLines = 200 + } + return &TerminalBuffer{ + lines: make([]string, 0, maxLines), + max: maxLines, + } +} + +// Write ingests raw terminal output, splitting on newlines and appending complete lines. +// Partial lines (data without a trailing newline) are accumulated until the next Write. +func (b *TerminalBuffer) Write(data []byte) { + b.mu.Lock() + defer b.mu.Unlock() + + text := b.partial + string(data) + b.partial = "" + + parts := strings.Split(text, "\n") + + // The last element of Split is always either: + // - empty string if text ends with \n (discard it) + // - a partial line if text doesn't end with \n (save as partial) + last := parts[len(parts)-1] + parts = parts[:len(parts)-1] + if last != "" { + b.partial = last + } + + for _, line := range parts { + b.lines = append(b.lines, line) + } + + // Trim to max + if len(b.lines) > b.max { + excess := len(b.lines) - b.max + b.lines = b.lines[excess:] + } +} + +// ReadLast returns the last n lines from the buffer. +// If fewer than n lines are available, all lines are returned. +func (b *TerminalBuffer) ReadLast(n int) []string { + b.mu.RLock() + defer b.mu.RUnlock() + + total := len(b.lines) + if n > total { + n = total + } + if n <= 0 { + return []string{} + } + + result := make([]string, n) + copy(result, b.lines[total-n:]) + return result +} + +// ReadAll returns all lines currently in the buffer. +func (b *TerminalBuffer) ReadAll() []string { + b.mu.RLock() + defer b.mu.RUnlock() + + result := make([]string, len(b.lines)) + copy(result, b.lines) + return result +} + +// Clear removes all lines from the buffer. +func (b *TerminalBuffer) Clear() { + b.mu.Lock() + defer b.mu.Unlock() + b.lines = b.lines[:0] + b.partial = "" +} + +// Len returns the number of complete lines in the buffer. +func (b *TerminalBuffer) Len() int { + b.mu.RLock() + defer b.mu.RUnlock() + return len(b.lines) +} diff --git a/internal/ai/terminal_buffer_test.go b/internal/ai/terminal_buffer_test.go new file mode 100644 index 0000000..19a76a8 --- /dev/null +++ b/internal/ai/terminal_buffer_test.go @@ -0,0 +1,175 @@ +package ai + +import ( + "fmt" + "sync" + "testing" +) + +func TestWriteAndRead(t *testing.T) { + buf := NewTerminalBuffer(100) + + buf.Write([]byte("line1\nline2\nline3\n")) + + all := buf.ReadAll() + if len(all) != 3 { + t.Fatalf("expected 3 lines, got %d", len(all)) + } + if all[0] != "line1" || all[1] != "line2" || all[2] != "line3" { + t.Errorf("unexpected lines: %v", all) + } +} + +func TestRingBufferOverflow(t *testing.T) { + buf := NewTerminalBuffer(200) + + // Write 300 lines + for i := 0; i < 300; i++ { + buf.Write([]byte(fmt.Sprintf("line %d\n", i))) + } + + all := buf.ReadAll() + if len(all) != 200 { + t.Fatalf("expected 200 lines (ring buffer), got %d", len(all)) + } + + // First line should be "line 100" (oldest retained) + if all[0] != "line 100" { + t.Errorf("expected first line 'line 100', got %q", all[0]) + } + + // Last line should be "line 299" + if all[199] != "line 299" { + t.Errorf("expected last line 'line 299', got %q", all[199]) + } +} + +func TestReadLastSubset(t *testing.T) { + buf := NewTerminalBuffer(100) + + buf.Write([]byte("a\nb\nc\nd\ne\n")) + + last3 := buf.ReadLast(3) + if len(last3) != 3 { + t.Fatalf("expected 3 lines, got %d", len(last3)) + } + if last3[0] != "c" || last3[1] != "d" || last3[2] != "e" { + t.Errorf("unexpected lines: %v", last3) + } + + // Request more than available + last10 := buf.ReadLast(10) + if len(last10) != 5 { + t.Fatalf("expected 5 lines (all available), got %d", len(last10)) + } +} + +func TestConcurrentAccess(t *testing.T) { + buf := NewTerminalBuffer(1000) + + var wg sync.WaitGroup + + // Spawn 10 writers + for w := 0; w < 10; w++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < 100; i++ { + buf.Write([]byte(fmt.Sprintf("writer %d line %d\n", id, i))) + } + }(w) + } + + // Spawn 5 readers + for r := 0; r < 5; r++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + _ = buf.ReadLast(10) + _ = buf.Len() + } + }() + } + + wg.Wait() + + total := buf.Len() + if total != 1000 { + t.Errorf("expected 1000 lines (10 writers * 100), got %d", total) + } +} + +func TestWritePartialLines(t *testing.T) { + buf := NewTerminalBuffer(100) + + // Write data without trailing newline + buf.Write([]byte("partial")) + + // Should have no complete lines yet + if buf.Len() != 0 { + t.Errorf("expected 0 lines for partial write, got %d", buf.Len()) + } + + // Complete the line + buf.Write([]byte(" line\n")) + + if buf.Len() != 1 { + t.Fatalf("expected 1 line after completing partial, got %d", buf.Len()) + } + + all := buf.ReadAll() + if all[0] != "partial line" { + t.Errorf("expected 'partial line', got %q", all[0]) + } +} + +func TestWriteMultiplePartials(t *testing.T) { + buf := NewTerminalBuffer(100) + + buf.Write([]byte("hello ")) + buf.Write([]byte("world")) + buf.Write([]byte("!\nfoo\n")) + + all := buf.ReadAll() + if len(all) != 2 { + t.Fatalf("expected 2 lines, got %d: %v", len(all), all) + } + if all[0] != "hello world!" { + t.Errorf("expected 'hello world!', got %q", all[0]) + } + if all[1] != "foo" { + t.Errorf("expected 'foo', got %q", all[1]) + } +} + +func TestClear(t *testing.T) { + buf := NewTerminalBuffer(100) + + buf.Write([]byte("line1\nline2\n")) + buf.Clear() + + if buf.Len() != 0 { + t.Errorf("expected 0 lines after clear, got %d", buf.Len()) + } +} + +func TestReadLastZero(t *testing.T) { + buf := NewTerminalBuffer(100) + buf.Write([]byte("line\n")) + + result := buf.ReadLast(0) + if len(result) != 0 { + t.Errorf("expected empty result for ReadLast(0), got %d lines", len(result)) + } +} + +func TestReadLastNegative(t *testing.T) { + buf := NewTerminalBuffer(100) + buf.Write([]byte("line\n")) + + result := buf.ReadLast(-1) + if len(result) != 0 { + t.Errorf("expected empty result for ReadLast(-1), got %d lines", len(result)) + } +} diff --git a/internal/ai/tools.go b/internal/ai/tools.go new file mode 100644 index 0000000..f10f7c5 --- /dev/null +++ b/internal/ai/tools.go @@ -0,0 +1,207 @@ +package ai + +import "encoding/json" + +// CopilotTools defines all tools available to the AI copilot. +var CopilotTools = []Tool{ + // ── SSH Terminal Tools ──────────────────────────────────────── + { + Name: "terminal_write", + Description: "Type text into an active SSH terminal session. Use this to execute commands by including a trailing newline character.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The SSH session ID"}, + "text": {"type": "string", "description": "Text to type into the terminal (include \\n for Enter)"} + }, + "required": ["sessionId", "text"] + }`), + }, + { + Name: "terminal_read", + Description: "Read recent output from an SSH terminal session. Returns the last N lines from the terminal buffer.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The SSH session ID"}, + "lines": {"type": "integer", "description": "Number of recent lines to return (default 50)", "default": 50} + }, + "required": ["sessionId"] + }`), + }, + { + Name: "terminal_cwd", + Description: "Get the current working directory of an SSH terminal session by executing pwd.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The SSH session ID"} + }, + "required": ["sessionId"] + }`), + }, + + // ── SFTP Tools ─────────────────────────────────────────────── + { + Name: "sftp_list", + Description: "List files and directories at a remote path via SFTP.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The SSH/SFTP session ID"}, + "path": {"type": "string", "description": "Remote directory path to list"} + }, + "required": ["sessionId", "path"] + }`), + }, + { + Name: "sftp_read", + Description: "Read the contents of a remote file via SFTP. Limited to files under 5MB.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The SSH/SFTP session ID"}, + "path": {"type": "string", "description": "Remote file path to read"} + }, + "required": ["sessionId", "path"] + }`), + }, + { + Name: "sftp_write", + Description: "Write content to a remote file via SFTP. Creates or overwrites the file.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The SSH/SFTP session ID"}, + "path": {"type": "string", "description": "Remote file path to write"}, + "content": {"type": "string", "description": "File content to write"} + }, + "required": ["sessionId", "path", "content"] + }`), + }, + + // ── RDP Tools ──────────────────────────────────────────────── + { + Name: "rdp_screenshot", + Description: "Capture a screenshot of the current RDP desktop. Returns a base64 JPEG image.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The RDP session ID"} + }, + "required": ["sessionId"] + }`), + }, + { + Name: "rdp_click", + Description: "Click at a specific position on the RDP desktop.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The RDP session ID"}, + "x": {"type": "integer", "description": "X coordinate in pixels"}, + "y": {"type": "integer", "description": "Y coordinate in pixels"}, + "button": {"type": "string", "enum": ["left", "right", "middle"], "default": "left", "description": "Mouse button to click"} + }, + "required": ["sessionId", "x", "y"] + }`), + }, + { + Name: "rdp_doubleclick", + Description: "Double-click at a specific position on the RDP desktop.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The RDP session ID"}, + "x": {"type": "integer", "description": "X coordinate in pixels"}, + "y": {"type": "integer", "description": "Y coordinate in pixels"} + }, + "required": ["sessionId", "x", "y"] + }`), + }, + { + Name: "rdp_type", + Description: "Type text into the focused element on the RDP desktop. Uses clipboard paste for reliability.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The RDP session ID"}, + "text": {"type": "string", "description": "Text to type"} + }, + "required": ["sessionId", "text"] + }`), + }, + { + Name: "rdp_keypress", + Description: "Press a key or key combination on the RDP desktop (e.g., 'Enter', 'Tab', 'Ctrl+C').", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The RDP session ID"}, + "key": {"type": "string", "description": "Key name or combination (e.g., 'Enter', 'Ctrl+C', 'Alt+F4')"} + }, + "required": ["sessionId", "key"] + }`), + }, + { + Name: "rdp_scroll", + Description: "Scroll the mouse wheel at a position on the RDP desktop.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The RDP session ID"}, + "x": {"type": "integer", "description": "X coordinate in pixels"}, + "y": {"type": "integer", "description": "Y coordinate in pixels"}, + "direction": {"type": "string", "enum": ["up", "down"], "description": "Scroll direction"}, + "clicks": {"type": "integer", "description": "Number of scroll clicks (default 3)", "default": 3} + }, + "required": ["sessionId", "x", "y", "direction"] + }`), + }, + { + Name: "rdp_move", + Description: "Move the mouse cursor to a position on the RDP desktop without clicking.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The RDP session ID"}, + "x": {"type": "integer", "description": "X coordinate in pixels"}, + "y": {"type": "integer", "description": "Y coordinate in pixels"} + }, + "required": ["sessionId", "x", "y"] + }`), + }, + + // ── Session Management Tools ───────────────────────────────── + { + Name: "list_sessions", + Description: "List all active SSH and RDP sessions with their IDs, connection info, and state.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": {}, + "required": [] + }`), + }, + { + Name: "connect_ssh", + Description: "Open a new SSH connection to a saved connection by its ID.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "connectionId": {"type": "integer", "description": "The saved connection ID from the connection manager"} + }, + "required": ["connectionId"] + }`), + }, + { + Name: "disconnect", + Description: "Disconnect an active session.", + InputSchema: json.RawMessage(`{ + "type": "object", + "properties": { + "sessionId": {"type": "string", "description": "The session ID to disconnect"} + }, + "required": ["sessionId"] + }`), + }, +} diff --git a/internal/ai/types.go b/internal/ai/types.go new file mode 100644 index 0000000..28847ee --- /dev/null +++ b/internal/ai/types.go @@ -0,0 +1,75 @@ +package ai + +import ( + "encoding/json" + "time" +) + +// Message represents a single message in a conversation with Claude. +type Message struct { + Role string `json:"role"` // "user" or "assistant" + Content []ContentBlock `json:"content"` // one or more content blocks +} + +// ContentBlock is a polymorphic block within a Message. +// Only one of the content fields will be populated depending on Type. +type ContentBlock struct { + Type string `json:"type"` // "text", "image", "tool_use", "tool_result" + + // text + Text string `json:"text,omitempty"` + + // image (base64 source) + Source *ImageSource `json:"source,omitempty"` + + // tool_use + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + + // tool_result + ToolUseID string `json:"tool_use_id,omitempty"` + Content []ContentBlock `json:"content,omitempty"` + IsError bool `json:"is_error,omitempty"` +} + +// ImageSource holds a base64-encoded image for vision requests. +type ImageSource struct { + Type string `json:"type"` // "base64" + MediaType string `json:"media_type"` // "image/jpeg", "image/png", etc. + Data string `json:"data"` // base64-encoded image data +} + +// Tool describes a tool available to the model. +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema json.RawMessage `json:"input_schema"` +} + +// StreamEvent represents a single event from the SSE stream. +type StreamEvent struct { + Type string `json:"type"` // "text_delta", "tool_use_start", "tool_use_delta", "tool_result", "done", "error" + Data string `json:"data"` // event payload + + // Populated for tool_use_start events + ToolName string `json:"tool_name,omitempty"` + ToolID string `json:"tool_id,omitempty"` + ToolInput string `json:"tool_input,omitempty"` +} + +// Usage tracks token consumption for a request. +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// ConversationSummary is a lightweight view of a conversation for listing. +type ConversationSummary struct { + ID string `json:"id"` + Title string `json:"title"` + Model string `json:"model"` + CreatedAt time.Time `json:"createdAt"` + TokensIn int `json:"tokensIn"` + TokensOut int `json:"tokensOut"` +} diff --git a/internal/app/app.go b/internal/app/app.go index 21716d0..329c8ef 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" + "github.com/vstockwell/wraith/internal/ai" "github.com/vstockwell/wraith/internal/connections" "github.com/vstockwell/wraith/internal/credentials" "github.com/vstockwell/wraith/internal/db" @@ -35,6 +36,8 @@ type WraithApp struct { SFTP *sftp.SFTPService RDP *rdp.RDPService Credentials *credentials.CredentialService + AI *ai.AIService + oauthMgr *ai.OAuthManager unlocked bool } @@ -78,6 +81,15 @@ func New() (*WraithApp, error) { // CredentialService requires the vault to be unlocked, so it starts nil. // It is created lazily after the vault is unlocked via initCredentials(). + // AI copilot services — OAuthManager starts without a vault reference; + // it will be wired once the vault is unlocked. + oauthMgr := ai.NewOAuthManager(settingsSvc, nil) + toolRouter := ai.NewToolRouter() + toolRouter.SetServices(sshSvc, sftpSvc, rdpSvc, sessionMgr, connSvc) + convMgr := ai.NewConversationManager(database) + aiSvc := ai.NewAIService(oauthMgr, toolRouter, convMgr) + toolRouter.SetAIService(aiSvc) + // Seed built-in themes on every startup (INSERT OR IGNORE keeps it idempotent) if err := themeSvc.SeedBuiltins(); err != nil { slog.Warn("failed to seed themes", "error", err) @@ -93,6 +105,8 @@ func New() (*WraithApp, error) { SSH: sshSvc, SFTP: sftpSvc, RDP: rdpSvc, + AI: aiSvc, + oauthMgr: oauthMgr, }, nil } @@ -197,9 +211,13 @@ func (a *WraithApp) IsUnlocked() bool { return a.unlocked } -// initCredentials creates the CredentialService after the vault is unlocked. +// initCredentials creates the CredentialService after the vault is unlocked +// and wires the vault to services that need it (OAuth, etc.). func (a *WraithApp) initCredentials() { if a.Vault != nil { a.Credentials = credentials.NewCredentialService(a.db, a.Vault) + if a.oauthMgr != nil { + a.oauthMgr.SetVault(a.Vault) + } } } diff --git a/internal/db/migrations/002_ai_copilot.sql b/internal/db/migrations/002_ai_copilot.sql new file mode 100644 index 0000000..affb48a --- /dev/null +++ b/internal/db/migrations/002_ai_copilot.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + title TEXT, + model TEXT NOT NULL, + messages TEXT NOT NULL DEFAULT '[]', + tokens_in INTEGER DEFAULT 0, + tokens_out INTEGER DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); diff --git a/main.go b/main.go index 1940ae1..603a826 100644 --- a/main.go +++ b/main.go @@ -34,6 +34,7 @@ func main() { application.NewService(wraith.SSH), application.NewService(wraith.SFTP), application.NewService(wraith.RDP), + application.NewService(wraith.AI), }, Assets: application.AssetOptions{ Handler: application.BundledAssetFileServer(assets),