feat: AI copilot backend — OAuth PKCE, Claude API streaming, 16 tools, conversations
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Has been cancelled

- OAuth PKCE flow for Max subscription auth (no API key needed)
- Claude API client with SSE streaming (Messages API v1)
- 16 tool definitions: terminal, SFTP, RDP, session management
- Tool dispatch router mapping to existing Wraith services
- Conversation manager with SQLite persistence
- Terminal output ring buffer for AI context
- RDP screenshot encoder (RGBA → JPEG with downscaling)
- Wired into Wails app as AIService

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Vantz Stockwell 2026-03-17 09:09:23 -04:00
parent be868e8172
commit 7ee5321d69
18 changed files with 2932 additions and 1 deletions

256
internal/ai/client.go Normal file
View File

@ -0,0 +1,256 @@
package ai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
)
const (
defaultModel = "claude-sonnet-4-20250514"
anthropicVersion = "2023-06-01"
apiBaseURL = "https://api.anthropic.com/v1/messages"
)
// ClaudeClient is the HTTP client for the Anthropic Messages API.
type ClaudeClient struct {
oauth *OAuthManager
apiKey string // fallback for non-OAuth users
model string
httpClient *http.Client
}
// NewClaudeClient creates a client that authenticates via OAuth (primary) or API key (fallback).
func NewClaudeClient(oauth *OAuthManager, model string) *ClaudeClient {
if model == "" {
model = defaultModel
}
return &ClaudeClient{
oauth: oauth,
model: model,
httpClient: &http.Client{},
}
}
// SetAPIKey sets a fallback API key for users without Max subscriptions.
func (c *ClaudeClient) SetAPIKey(key string) {
c.apiKey = key
}
// apiRequest is the JSON body sent to the Messages API.
type apiRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
System string `json:"system,omitempty"`
Tools []Tool `json:"tools,omitempty"`
}
// SendMessage sends a request to the Messages API with streaming enabled.
// It returns a channel of StreamEvents that the caller reads until the channel is closed.
func (c *ClaudeClient) SendMessage(messages []Message, tools []Tool, systemPrompt string) (<-chan StreamEvent, error) {
reqBody := apiRequest{
Model: c.model,
Messages: messages,
MaxTokens: 8192,
Stream: true,
System: systemPrompt,
Tools: tools,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
req, err := http.NewRequest("POST", apiBaseURL, bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", anthropicVersion)
if err := c.setAuthHeader(req); err != nil {
return nil, fmt.Errorf("set auth header: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("API returned %d: %s", resp.StatusCode, string(body))
}
ch := make(chan StreamEvent, 64)
go c.parseSSEStream(resp.Body, ch)
return ch, nil
}
// setAuthHeader sets the appropriate authorization header on the request.
func (c *ClaudeClient) setAuthHeader(req *http.Request) error {
// Try OAuth first
if c.oauth != nil && c.oauth.IsAuthenticated() {
token, err := c.oauth.GetAccessToken()
if err == nil {
req.Header.Set("Authorization", "Bearer "+token)
return nil
}
slog.Warn("oauth token failed, falling back to api key", "error", err)
}
// Fallback to API key
if c.apiKey != "" {
req.Header.Set("x-api-key", c.apiKey)
return nil
}
return fmt.Errorf("no authentication method available — log in via OAuth or set an API key")
}
// parseSSEStream reads the SSE response body and sends StreamEvents on the channel.
func (c *ClaudeClient) parseSSEStream(body io.ReadCloser, ch chan<- StreamEvent) {
defer body.Close()
defer close(ch)
scanner := bufio.NewScanner(body)
var currentEventType string
var currentToolID string
var currentToolName string
for scanner.Scan() {
line := scanner.Text()
if line == "" {
continue
}
if strings.HasPrefix(line, "event: ") {
currentEventType = strings.TrimPrefix(line, "event: ")
continue
}
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
switch currentEventType {
case "content_block_start":
var block struct {
Index int `json:"index"`
ContentBlock struct {
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Text string `json:"text"`
Input string `json:"input"`
} `json:"content_block"`
}
if err := json.Unmarshal([]byte(data), &block); err != nil {
slog.Warn("failed to parse content_block_start", "error", err)
continue
}
if block.ContentBlock.Type == "tool_use" {
currentToolID = block.ContentBlock.ID
currentToolName = block.ContentBlock.Name
ch <- StreamEvent{
Type: "tool_use_start",
ToolID: currentToolID,
ToolName: currentToolName,
}
} else if block.ContentBlock.Type == "text" && block.ContentBlock.Text != "" {
ch <- StreamEvent{Type: "text_delta", Data: block.ContentBlock.Text}
}
case "content_block_delta":
var delta struct {
Delta struct {
Type string `json:"type"`
Text string `json:"text"`
PartialJSON string `json:"partial_json"`
} `json:"delta"`
}
if err := json.Unmarshal([]byte(data), &delta); err != nil {
slog.Warn("failed to parse content_block_delta", "error", err)
continue
}
if delta.Delta.Type == "text_delta" {
ch <- StreamEvent{Type: "text_delta", Data: delta.Delta.Text}
} else if delta.Delta.Type == "input_json_delta" {
ch <- StreamEvent{
Type: "tool_use_delta",
Data: delta.Delta.PartialJSON,
ToolID: currentToolID,
ToolName: currentToolName,
}
}
case "content_block_stop":
// nothing to do; tool input is accumulated by the caller
case "message_delta":
// contains stop_reason and usage; emit usage info
ch <- StreamEvent{Type: "done", Data: data}
case "message_stop":
// end of message stream
return
case "message_start":
// contains message metadata; could extract usage but we handle it at message_delta
case "ping":
// heartbeat; ignore
case "error":
ch <- StreamEvent{Type: "error", Data: data}
return
}
}
if err := scanner.Err(); err != nil {
ch <- StreamEvent{Type: "error", Data: err.Error()}
}
}
// BuildRequestBody creates the JSON request body for testing purposes.
func BuildRequestBody(messages []Message, tools []Tool, systemPrompt, model string) ([]byte, error) {
if model == "" {
model = defaultModel
}
req := apiRequest{
Model: model,
Messages: messages,
MaxTokens: 8192,
Stream: true,
System: systemPrompt,
Tools: tools,
}
return json.Marshal(req)
}
// ParseSSELine parses a single SSE data line. Returns the event type and data payload.
func ParseSSELine(line string) (eventType, data string) {
if strings.HasPrefix(line, "event: ") {
return "event", strings.TrimPrefix(line, "event: ")
}
if strings.HasPrefix(line, "data: ") {
return "data", strings.TrimPrefix(line, "data: ")
}
return "", ""
}

123
internal/ai/client_test.go Normal file
View File

@ -0,0 +1,123 @@
package ai
import (
"encoding/json"
"net/http"
"testing"
)
func TestBuildRequestBody(t *testing.T) {
messages := []Message{
{
Role: "user",
Content: []ContentBlock{
{Type: "text", Text: "Hello"},
},
},
}
tools := []Tool{
{
Name: "test_tool",
Description: "A test tool",
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
},
}
body, err := BuildRequestBody(messages, tools, "You are helpful.", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var parsed map[string]interface{}
if err := json.Unmarshal(body, &parsed); err != nil {
t.Fatalf("invalid JSON: %v", err)
}
// Check model
if m, ok := parsed["model"].(string); !ok || m != defaultModel {
t.Errorf("expected model %s, got %v", defaultModel, parsed["model"])
}
// Check stream
if s, ok := parsed["stream"].(bool); !ok || !s {
t.Errorf("expected stream true, got %v", parsed["stream"])
}
// Check system prompt
if s, ok := parsed["system"].(string); !ok || s != "You are helpful." {
t.Errorf("expected system prompt, got %v", parsed["system"])
}
// Check max_tokens
if mt, ok := parsed["max_tokens"].(float64); !ok || int(mt) != 8192 {
t.Errorf("expected max_tokens 8192, got %v", parsed["max_tokens"])
}
// Check messages array exists
msgs, ok := parsed["messages"].([]interface{})
if !ok || len(msgs) != 1 {
t.Errorf("expected 1 message, got %v", parsed["messages"])
}
// Check tools array exists
tls, ok := parsed["tools"].([]interface{})
if !ok || len(tls) != 1 {
t.Errorf("expected 1 tool, got %v", parsed["tools"])
}
}
func TestParseSSELine(t *testing.T) {
tests := []struct {
input string
wantType string
wantData string
}{
{"event: content_block_delta", "event", "content_block_delta"},
{"data: {\"delta\":{\"text\":\"hello\"}}", "data", "{\"delta\":{\"text\":\"hello\"}}"},
{"", "", ""},
{"random line", "", ""},
{"event: message_stop", "event", "message_stop"},
{"data: [DONE]", "data", "[DONE]"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
gotType, gotData := ParseSSELine(tt.input)
if gotType != tt.wantType {
t.Errorf("ParseSSELine(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
}
if gotData != tt.wantData {
t.Errorf("ParseSSELine(%q) data = %q, want %q", tt.input, gotData, tt.wantData)
}
})
}
}
func TestAuthHeader(t *testing.T) {
// Test with API key only (no OAuth)
client := NewClaudeClient(nil, "")
client.SetAPIKey("sk-test-12345")
req, _ := http.NewRequest("POST", "https://example.com", nil)
if err := client.setAuthHeader(req); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := req.Header.Get("x-api-key"); got != "sk-test-12345" {
t.Errorf("expected x-api-key sk-test-12345, got %q", got)
}
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("expected no Authorization header, got %q", got)
}
}
func TestAuthHeaderNoAuth(t *testing.T) {
// No OAuth, no API key — should return error
client := NewClaudeClient(nil, "")
req, _ := http.NewRequest("POST", "https://example.com", nil)
err := client.setAuthHeader(req)
if err == nil {
t.Error("expected error when no auth method available")
}
}

169
internal/ai/conversation.go Normal file
View File

@ -0,0 +1,169 @@
package ai
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
)
// ConversationManager handles CRUD operations for AI conversations stored in SQLite.
type ConversationManager struct {
db *sql.DB
}
// NewConversationManager creates a manager backed by the given database.
func NewConversationManager(db *sql.DB) *ConversationManager {
return &ConversationManager{db: db}
}
// Create starts a new conversation and returns its summary.
func (m *ConversationManager) Create(model string) (*ConversationSummary, error) {
id := uuid.NewString()
now := time.Now()
_, err := m.db.Exec(
`INSERT INTO conversations (id, title, model, messages, tokens_in, tokens_out, created_at, updated_at)
VALUES (?, ?, ?, '[]', 0, 0, ?, ?)`,
id, "New conversation", model, now, now,
)
if err != nil {
return nil, fmt.Errorf("create conversation: %w", err)
}
return &ConversationSummary{
ID: id,
Title: "New conversation",
Model: model,
CreatedAt: now,
TokensIn: 0,
TokensOut: 0,
}, nil
}
// AddMessage appends a message to the conversation's message list.
func (m *ConversationManager) AddMessage(convId string, msg Message) error {
// Get existing messages
messages, err := m.GetMessages(convId)
if err != nil {
return err
}
messages = append(messages, msg)
data, err := json.Marshal(messages)
if err != nil {
return fmt.Errorf("marshal messages: %w", err)
}
_, err = m.db.Exec(
"UPDATE conversations SET messages = ?, updated_at = ? WHERE id = ?",
string(data), time.Now(), convId,
)
if err != nil {
return fmt.Errorf("update messages: %w", err)
}
// Auto-title from first user message
if len(messages) == 1 && msg.Role == "user" {
title := extractTitle(msg)
if title != "" {
m.db.Exec("UPDATE conversations SET title = ? WHERE id = ?", title, convId)
}
}
return nil
}
// GetMessages returns all messages in a conversation.
func (m *ConversationManager) GetMessages(convId string) ([]Message, error) {
var messagesJSON string
err := m.db.QueryRow("SELECT messages FROM conversations WHERE id = ?", convId).Scan(&messagesJSON)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("conversation %s not found", convId)
}
if err != nil {
return nil, fmt.Errorf("get messages: %w", err)
}
var messages []Message
if err := json.Unmarshal([]byte(messagesJSON), &messages); err != nil {
return nil, fmt.Errorf("unmarshal messages: %w", err)
}
return messages, nil
}
// List returns all conversations ordered by most recent.
func (m *ConversationManager) List() ([]ConversationSummary, error) {
rows, err := m.db.Query(
`SELECT id, title, model, tokens_in, tokens_out, created_at
FROM conversations ORDER BY updated_at DESC`,
)
if err != nil {
return nil, fmt.Errorf("list conversations: %w", err)
}
defer rows.Close()
var summaries []ConversationSummary
for rows.Next() {
var s ConversationSummary
var title sql.NullString
if err := rows.Scan(&s.ID, &title, &s.Model, &s.TokensIn, &s.TokensOut, &s.CreatedAt); err != nil {
return nil, fmt.Errorf("scan conversation: %w", err)
}
if title.Valid {
s.Title = title.String
}
summaries = append(summaries, s)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate conversations: %w", err)
}
if summaries == nil {
summaries = []ConversationSummary{}
}
return summaries, nil
}
// Delete removes a conversation and all its messages.
func (m *ConversationManager) Delete(convId string) error {
result, err := m.db.Exec("DELETE FROM conversations WHERE id = ?", convId)
if err != nil {
return fmt.Errorf("delete conversation: %w", err)
}
affected, _ := result.RowsAffected()
if affected == 0 {
return fmt.Errorf("conversation %s not found", convId)
}
return nil
}
// UpdateTokenUsage adds to the token counters for a conversation.
func (m *ConversationManager) UpdateTokenUsage(convId string, tokensIn, tokensOut int) error {
_, err := m.db.Exec(
`UPDATE conversations
SET tokens_in = tokens_in + ?, tokens_out = tokens_out + ?, updated_at = ?
WHERE id = ?`,
tokensIn, tokensOut, time.Now(), convId,
)
if err != nil {
return fmt.Errorf("update token usage: %w", err)
}
return nil
}
// extractTitle generates a title from the first user message (truncated to 80 chars).
func extractTitle(msg Message) string {
for _, block := range msg.Content {
if block.Type == "text" && block.Text != "" {
title := block.Text
if len(title) > 80 {
title = title[:77] + "..."
}
return title
}
}
return ""
}

View File

@ -0,0 +1,220 @@
package ai
import (
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
)
func setupConversationManager(t *testing.T) *ConversationManager {
t.Helper()
database, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.Migrate(database); err != nil {
t.Fatalf("migrate: %v", err)
}
// Create the conversations table (002 migration)
_, err = database.Exec(`CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY, title TEXT, model TEXT NOT NULL,
messages TEXT NOT NULL DEFAULT '[]',
tokens_in INTEGER DEFAULT 0, tokens_out INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP)`)
if err != nil {
t.Fatalf("create conversations table: %v", err)
}
t.Cleanup(func() { database.Close() })
return NewConversationManager(database)
}
func TestCreateConversation(t *testing.T) {
mgr := setupConversationManager(t)
conv, err := mgr.Create("claude-sonnet-4-20250514")
if err != nil {
t.Fatalf("create: %v", err)
}
if conv.ID == "" {
t.Error("expected non-empty ID")
}
if conv.Model != "claude-sonnet-4-20250514" {
t.Errorf("expected model claude-sonnet-4-20250514, got %s", conv.Model)
}
if conv.Title != "New conversation" {
t.Errorf("expected title 'New conversation', got %s", conv.Title)
}
if conv.TokensIn != 0 || conv.TokensOut != 0 {
t.Errorf("expected zero tokens, got in=%d out=%d", conv.TokensIn, conv.TokensOut)
}
}
func TestAddAndGetMessages(t *testing.T) {
mgr := setupConversationManager(t)
conv, err := mgr.Create("test-model")
if err != nil {
t.Fatalf("create: %v", err)
}
// Add a user message
userMsg := Message{
Role: "user",
Content: []ContentBlock{
{Type: "text", Text: "What is running on port 8080?"},
},
}
if err := mgr.AddMessage(conv.ID, userMsg); err != nil {
t.Fatalf("add user message: %v", err)
}
// Add an assistant message
assistantMsg := Message{
Role: "assistant",
Content: []ContentBlock{
{Type: "text", Text: "Let me check that for you."},
},
}
if err := mgr.AddMessage(conv.ID, assistantMsg); err != nil {
t.Fatalf("add assistant message: %v", err)
}
// Get messages
messages, err := mgr.GetMessages(conv.ID)
if err != nil {
t.Fatalf("get messages: %v", err)
}
if len(messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(messages))
}
if messages[0].Role != "user" {
t.Errorf("expected first message role 'user', got %s", messages[0].Role)
}
if messages[1].Role != "assistant" {
t.Errorf("expected second message role 'assistant', got %s", messages[1].Role)
}
if messages[0].Content[0].Text != "What is running on port 8080?" {
t.Errorf("unexpected message text: %s", messages[0].Content[0].Text)
}
}
func TestListConversations(t *testing.T) {
mgr := setupConversationManager(t)
// Create multiple conversations
_, err := mgr.Create("model-a")
if err != nil {
t.Fatalf("create 1: %v", err)
}
_, err = mgr.Create("model-b")
if err != nil {
t.Fatalf("create 2: %v", err)
}
list, err := mgr.List()
if err != nil {
t.Fatalf("list: %v", err)
}
if len(list) != 2 {
t.Errorf("expected 2 conversations, got %d", len(list))
}
}
func TestDeleteConversation(t *testing.T) {
mgr := setupConversationManager(t)
conv, err := mgr.Create("test-model")
if err != nil {
t.Fatalf("create: %v", err)
}
if err := mgr.Delete(conv.ID); err != nil {
t.Fatalf("delete: %v", err)
}
// Verify it's gone
list, err := mgr.List()
if err != nil {
t.Fatalf("list: %v", err)
}
if len(list) != 0 {
t.Errorf("expected 0 conversations after delete, got %d", len(list))
}
// Delete non-existent should error
if err := mgr.Delete("nonexistent"); err == nil {
t.Error("expected error deleting non-existent conversation")
}
}
func TestTokenUsageTracking(t *testing.T) {
mgr := setupConversationManager(t)
conv, err := mgr.Create("test-model")
if err != nil {
t.Fatalf("create: %v", err)
}
// Update token usage multiple times
if err := mgr.UpdateTokenUsage(conv.ID, 100, 50); err != nil {
t.Fatalf("update tokens 1: %v", err)
}
if err := mgr.UpdateTokenUsage(conv.ID, 200, 100); err != nil {
t.Fatalf("update tokens 2: %v", err)
}
// Verify totals
list, err := mgr.List()
if err != nil {
t.Fatalf("list: %v", err)
}
if len(list) != 1 {
t.Fatalf("expected 1 conversation, got %d", len(list))
}
if list[0].TokensIn != 300 {
t.Errorf("expected 300 tokens in, got %d", list[0].TokensIn)
}
if list[0].TokensOut != 150 {
t.Errorf("expected 150 tokens out, got %d", list[0].TokensOut)
}
}
func TestGetMessagesNonExistent(t *testing.T) {
mgr := setupConversationManager(t)
_, err := mgr.GetMessages("nonexistent-id")
if err == nil {
t.Error("expected error for non-existent conversation")
}
}
func TestAutoTitle(t *testing.T) {
mgr := setupConversationManager(t)
conv, err := mgr.Create("test-model")
if err != nil {
t.Fatalf("create: %v", err)
}
msg := Message{
Role: "user",
Content: []ContentBlock{
{Type: "text", Text: "Check disk usage on server-01"},
},
}
if err := mgr.AddMessage(conv.ID, msg); err != nil {
t.Fatalf("add message: %v", err)
}
// Verify the title was auto-set
list, err := mgr.List()
if err != nil {
t.Fatalf("list: %v", err)
}
if list[0].Title != "Check disk usage on server-01" {
t.Errorf("expected auto-title, got %q", list[0].Title)
}
}

371
internal/ai/oauth.go Normal file
View File

@ -0,0 +1,371 @@
package ai
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/vstockwell/wraith/internal/settings"
"github.com/vstockwell/wraith/internal/vault"
)
const (
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
oauthAuthorizeURL = "https://claude.ai/oauth/authorize"
oauthTokenURL = "https://platform.claude.com/v1/oauth/token"
)
// OAuthManager handles OAuth PKCE authentication for Claude Max subscriptions.
type OAuthManager struct {
settings *settings.SettingsService
vault *vault.VaultService
clientID string
authorizeURL string
tokenURL string
mu sync.RWMutex
}
// tokenResponse is the JSON body returned by the token endpoint.
type tokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
// NewOAuthManager creates an OAuthManager wired to settings and vault services.
func NewOAuthManager(s *settings.SettingsService, v *vault.VaultService) *OAuthManager {
return &OAuthManager{
settings: s,
vault: v,
clientID: oauthClientID,
authorizeURL: oauthAuthorizeURL,
tokenURL: oauthTokenURL,
}
}
// StartLogin begins the OAuth PKCE flow. It starts a local HTTP server,
// opens the authorization URL in the user's browser, and returns a channel
// that receives nil on success or an error on failure.
func (o *OAuthManager) StartLogin(openURL func(string) error) (<-chan error, error) {
verifier, err := generateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("generate code verifier: %w", err)
}
challenge := generateCodeChallenge(verifier)
state, err := generateState()
if err != nil {
return nil, fmt.Errorf("generate state: %w", err)
}
// Start a local HTTP server on a random port for the callback
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("listen for callback: %w", err)
}
port := listener.Addr().(*net.TCPAddr).Port
redirectURI := fmt.Sprintf("http://127.0.0.1:%d/callback", port)
done := make(chan error, 1)
mux := http.NewServeMux()
server := &http.Server{Handler: mux}
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
defer func() {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
server.Shutdown(ctx)
}()
}()
// Verify state
if r.URL.Query().Get("state") != state {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, "State mismatch — please try logging in again.")
done <- fmt.Errorf("state mismatch")
return
}
if errParam := r.URL.Query().Get("error"); errParam != "" {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Authorization error: %s", errParam)
done <- fmt.Errorf("authorization error: %s", errParam)
return
}
code := r.URL.Query().Get("code")
if code == "" {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, "No authorization code received.")
done <- fmt.Errorf("no authorization code")
return
}
// Exchange code for tokens
if err := o.exchangeCode(code, verifier, redirectURI); err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, "Failed to exchange authorization code.")
done <- fmt.Errorf("exchange code: %w", err)
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, `<html><body><h2>Authenticated!</h2><p>You can close this window and return to Wraith.</p></body></html>`)
done <- nil
})
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
slog.Error("oauth callback server error", "error", err)
done <- fmt.Errorf("callback server: %w", err)
}
}()
// Build the authorize URL
params := url.Values{
"response_type": {"code"},
"client_id": {o.clientID},
"redirect_uri": {redirectURI},
"scope": {"user:inference"},
"state": {state},
"code_challenge": {challenge},
"code_challenge_method": {"S256"},
}
authURL := o.authorizeURL + "?" + params.Encode()
if openURL != nil {
if err := openURL(authURL); err != nil {
listener.Close()
return nil, fmt.Errorf("open browser: %w", err)
}
}
return done, nil
}
// exchangeCode exchanges an authorization code for access and refresh tokens.
func (o *OAuthManager) exchangeCode(code, verifier, redirectURI string) error {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectURI},
"client_id": {o.clientID},
"code_verifier": {verifier},
}
resp, err := http.PostForm(o.tokenURL, data)
if err != nil {
return fmt.Errorf("post token request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read token response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body))
}
var tokens tokenResponse
if err := json.Unmarshal(body, &tokens); err != nil {
return fmt.Errorf("unmarshal token response: %w", err)
}
return o.storeTokens(tokens)
}
// storeTokens encrypts and persists OAuth tokens in settings.
func (o *OAuthManager) storeTokens(tokens tokenResponse) error {
o.mu.Lock()
defer o.mu.Unlock()
if o.vault == nil {
return fmt.Errorf("vault is not unlocked — cannot store tokens")
}
encAccess, err := o.vault.Encrypt(tokens.AccessToken)
if err != nil {
return fmt.Errorf("encrypt access token: %w", err)
}
if err := o.settings.Set("ai_access_token", encAccess); err != nil {
return fmt.Errorf("store access token: %w", err)
}
encRefresh, err := o.vault.Encrypt(tokens.RefreshToken)
if err != nil {
return fmt.Errorf("encrypt refresh token: %w", err)
}
if err := o.settings.Set("ai_refresh_token", encRefresh); err != nil {
return fmt.Errorf("store refresh token: %w", err)
}
expiresAt := time.Now().Add(time.Duration(tokens.ExpiresIn) * time.Second).Unix()
if err := o.settings.Set("ai_token_expires_at", strconv.FormatInt(expiresAt, 10)); err != nil {
return fmt.Errorf("store token expiry: %w", err)
}
slog.Info("oauth tokens stored successfully")
return nil
}
// GetAccessToken returns a valid access token, refreshing if expired.
func (o *OAuthManager) GetAccessToken() (string, error) {
o.mu.RLock()
vlt := o.vault
o.mu.RUnlock()
if vlt == nil {
return "", fmt.Errorf("vault is not unlocked")
}
// Check expiry
expiresStr, _ := o.settings.Get("ai_token_expires_at")
if expiresStr != "" {
expiresAt, _ := strconv.ParseInt(expiresStr, 10, 64)
if time.Now().Unix() >= expiresAt {
// Token expired, try to refresh
if err := o.refreshToken(); err != nil {
return "", fmt.Errorf("refresh expired token: %w", err)
}
}
}
encAccess, err := o.settings.Get("ai_access_token")
if err != nil || encAccess == "" {
return "", fmt.Errorf("no access token stored")
}
accessToken, err := vlt.Decrypt(encAccess)
if err != nil {
return "", fmt.Errorf("decrypt access token: %w", err)
}
return accessToken, nil
}
// refreshToken uses the stored refresh token to obtain new tokens.
func (o *OAuthManager) refreshToken() error {
o.mu.RLock()
vlt := o.vault
o.mu.RUnlock()
if vlt == nil {
return fmt.Errorf("vault is not unlocked")
}
encRefresh, err := o.settings.Get("ai_refresh_token")
if err != nil || encRefresh == "" {
return fmt.Errorf("no refresh token stored")
}
refreshTok, err := vlt.Decrypt(encRefresh)
if err != nil {
return fmt.Errorf("decrypt refresh token: %w", err)
}
data := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {refreshTok},
"client_id": {o.clientID},
}
resp, err := http.PostForm(o.tokenURL, data)
if err != nil {
return fmt.Errorf("post refresh request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read refresh response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body))
}
var tokens tokenResponse
if err := json.Unmarshal(body, &tokens); err != nil {
return fmt.Errorf("unmarshal refresh response: %w", err)
}
// If the refresh endpoint doesn't return a new refresh token, keep the old one
if tokens.RefreshToken == "" {
tokens.RefreshToken = refreshTok
}
return o.storeTokens(tokens)
}
// IsAuthenticated checks if we have stored tokens.
func (o *OAuthManager) IsAuthenticated() bool {
encAccess, err := o.settings.Get("ai_access_token")
return err == nil && encAccess != ""
}
// Logout clears all stored OAuth tokens.
func (o *OAuthManager) Logout() error {
for _, key := range []string{"ai_access_token", "ai_refresh_token", "ai_token_expires_at"} {
if err := o.settings.Delete(key); err != nil {
return fmt.Errorf("delete %s: %w", key, err)
}
}
return nil
}
// generateCodeVerifier creates a PKCE code verifier (32 random bytes, base64url no padding).
func generateCodeVerifier() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// generateCodeChallenge creates a PKCE S256 code challenge from a verifier.
func generateCodeChallenge(verifier string) string {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:])
}
// generateState creates a random state parameter (32 random bytes, base64url no padding).
func generateState() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// SetVault updates the vault reference (called after vault unlock).
func (o *OAuthManager) SetVault(v *vault.VaultService) {
o.mu.Lock()
defer o.mu.Unlock()
o.vault = v
}
// isBase64URL checks if a string contains only base64url characters (no padding).
func isBase64URL(s string) bool {
for _, c := range s {
if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_') {
return false
}
}
return true
}

91
internal/ai/oauth_test.go Normal file
View File

@ -0,0 +1,91 @@
package ai
import (
"crypto/sha256"
"encoding/base64"
"path/filepath"
"testing"
"github.com/vstockwell/wraith/internal/db"
"github.com/vstockwell/wraith/internal/settings"
"github.com/vstockwell/wraith/internal/vault"
)
func TestGenerateCodeVerifier(t *testing.T) {
v, err := generateCodeVerifier()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// 32 bytes -> 43 base64url chars (no padding)
if len(v) != 43 {
t.Errorf("expected verifier length 43, got %d", len(v))
}
if !isBase64URL(v) {
t.Errorf("verifier contains non-base64url characters: %s", v)
}
// Should be different each time
v2, _ := generateCodeVerifier()
if v == v2 {
t.Error("two generated verifiers should not be identical")
}
}
func TestGenerateCodeChallenge(t *testing.T) {
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
challenge := generateCodeChallenge(verifier)
// Verify it matches a manually computed S256 hash
h := sha256.Sum256([]byte(verifier))
expected := base64.RawURLEncoding.EncodeToString(h[:])
if challenge != expected {
t.Errorf("expected challenge %s, got %s", expected, challenge)
}
// Same verifier should produce the same challenge (deterministic)
challenge2 := generateCodeChallenge(verifier)
if challenge != challenge2 {
t.Error("challenge should be deterministic for the same verifier")
}
}
func TestGenerateState(t *testing.T) {
s, err := generateState()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// 32 bytes -> 43 base64url chars (no padding)
if len(s) != 43 {
t.Errorf("expected state length 43, got %d", len(s))
}
if !isBase64URL(s) {
t.Errorf("state contains non-base64url characters: %s", s)
}
}
func TestIsAuthenticatedWhenNoTokens(t *testing.T) {
database, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.Migrate(database); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { database.Close() })
settingsSvc := settings.NewSettingsService(database)
key := vault.DeriveKey("test-password", []byte("test-salt-1234567890123456789012"))
vaultSvc := vault.NewVaultService(key)
oauth := NewOAuthManager(settingsSvc, vaultSvc)
if oauth.IsAuthenticated() {
t.Error("expected IsAuthenticated to return false when no tokens are stored")
}
}

564
internal/ai/router.go Normal file
View File

@ -0,0 +1,564 @@
package ai
import (
"encoding/json"
"fmt"
)
// ToolRouter dispatches tool calls to the appropriate service.
// Services are stored as interface{} to avoid circular imports between packages.
type ToolRouter struct {
ssh interface{} // *ssh.SSHService
sftp interface{} // *sftp.SFTPService
rdp interface{} // *rdp.RDPService
sessions interface{} // *session.Manager
connections interface{} // *connections.ConnectionService
aiService interface{} // *AIService — for terminal buffer access
}
// NewToolRouter creates an empty ToolRouter. Call SetServices to wire in backends.
func NewToolRouter() *ToolRouter {
return &ToolRouter{}
}
// SetServices wires the router to actual service implementations.
func (r *ToolRouter) SetServices(ssh, sftp, rdp, sessions, connections interface{}) {
r.ssh = ssh
r.sftp = sftp
r.rdp = rdp
r.sessions = sessions
r.connections = connections
}
// SetAIService wires the router to the AI service for terminal buffer access.
func (r *ToolRouter) SetAIService(aiService interface{}) {
r.aiService = aiService
}
// sshWriter is the interface we need from SSHService for terminal_write.
type sshWriter interface {
Write(sessionID string, data string) error
}
// sshSessionLister is the interface for listing SSH sessions.
type sshSessionLister interface {
ListSessions() interface{}
}
// sftpLister is the interface for SFTP list operations.
type sftpLister interface {
List(sessionID string, path string) (interface{}, error)
}
// sftpReader is the interface for SFTP read operations.
type sftpReader interface {
ReadFile(sessionID string, path string) (string, error)
}
// sftpWriter is the interface for SFTP write operations.
type sftpWriter interface {
WriteFile(sessionID string, path string, content string) error
}
// rdpFrameGetter is the interface for getting RDP screenshots.
type rdpFrameGetter interface {
GetFrame(sessionID string) ([]byte, error)
GetSessionInfo(sessionID string) (interface{}, error)
}
// rdpMouseSender is the interface for RDP mouse events.
type rdpMouseSender interface {
SendMouse(sessionID string, x, y int, flags uint32) error
}
// rdpKeySender is the interface for RDP key events.
type rdpKeySender interface {
SendKey(sessionID string, scancode uint32, pressed bool) error
}
// rdpClipboardSender is the interface for RDP clipboard.
type rdpClipboardSender interface {
SendClipboard(sessionID string, data string) error
}
// rdpSessionLister is the interface for listing RDP sessions.
type rdpSessionLister interface {
ListSessions() interface{}
}
// sessionLister is the interface for listing sessions.
type sessionLister interface {
List() interface{}
}
// bufferProvider provides terminal output buffers.
type bufferProvider interface {
GetBuffer(sessionId string) *TerminalBuffer
}
// Dispatch routes a tool call to the appropriate handler.
func (r *ToolRouter) Dispatch(toolName string, input json.RawMessage) (interface{}, error) {
switch toolName {
case "terminal_write":
return r.handleTerminalWrite(input)
case "terminal_read":
return r.handleTerminalRead(input)
case "terminal_cwd":
return r.handleTerminalCwd(input)
case "sftp_list":
return r.handleSFTPList(input)
case "sftp_read":
return r.handleSFTPRead(input)
case "sftp_write":
return r.handleSFTPWrite(input)
case "rdp_screenshot":
return r.handleRDPScreenshot(input)
case "rdp_click":
return r.handleRDPClick(input)
case "rdp_doubleclick":
return r.handleRDPDoubleClick(input)
case "rdp_type":
return r.handleRDPType(input)
case "rdp_keypress":
return r.handleRDPKeypress(input)
case "rdp_scroll":
return r.handleRDPScroll(input)
case "rdp_move":
return r.handleRDPMove(input)
case "list_sessions":
return r.handleListSessions(input)
case "connect_ssh":
return r.handleConnectSSH(input)
case "disconnect":
return r.handleDisconnect(input)
default:
return nil, fmt.Errorf("unknown tool: %s", toolName)
}
}
func (r *ToolRouter) handleTerminalWrite(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
Text string `json:"text"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
w, ok := r.ssh.(sshWriter)
if !ok || r.ssh == nil {
return nil, fmt.Errorf("SSH service not available")
}
if err := w.Write(params.SessionID, params.Text); err != nil {
return nil, err
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleTerminalRead(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
Lines int `json:"lines"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
if params.Lines <= 0 {
params.Lines = 50
}
bp, ok := r.aiService.(bufferProvider)
if !ok || r.aiService == nil {
return nil, fmt.Errorf("terminal buffer not available")
}
buf := bp.GetBuffer(params.SessionID)
lines := buf.ReadLast(params.Lines)
return map[string]interface{}{
"lines": lines,
"count": len(lines),
}, nil
}
func (r *ToolRouter) handleTerminalCwd(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
// terminal_cwd works by writing "pwd" to the terminal and reading output
w, ok := r.ssh.(sshWriter)
if !ok || r.ssh == nil {
return nil, fmt.Errorf("SSH service not available")
}
if err := w.Write(params.SessionID, "pwd\n"); err != nil {
return nil, fmt.Errorf("send pwd: %w", err)
}
return map[string]string{
"status": "pwd command sent — read terminal output for result",
}, nil
}
func (r *ToolRouter) handleSFTPList(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
Path string `json:"path"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
l, ok := r.sftp.(sftpLister)
if !ok || r.sftp == nil {
return nil, fmt.Errorf("SFTP service not available")
}
return l.List(params.SessionID, params.Path)
}
func (r *ToolRouter) handleSFTPRead(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
Path string `json:"path"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
reader, ok := r.sftp.(sftpReader)
if !ok || r.sftp == nil {
return nil, fmt.Errorf("SFTP service not available")
}
content, err := reader.ReadFile(params.SessionID, params.Path)
if err != nil {
return nil, err
}
return map[string]string{"content": content}, nil
}
func (r *ToolRouter) handleSFTPWrite(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
Path string `json:"path"`
Content string `json:"content"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
w, ok := r.sftp.(sftpWriter)
if !ok || r.sftp == nil {
return nil, fmt.Errorf("SFTP service not available")
}
if err := w.WriteFile(params.SessionID, params.Path, params.Content); err != nil {
return nil, err
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleRDPScreenshot(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
fg, ok := r.rdp.(rdpFrameGetter)
if !ok || r.rdp == nil {
return nil, fmt.Errorf("RDP service not available")
}
frame, err := fg.GetFrame(params.SessionID)
if err != nil {
return nil, err
}
// Get session info for dimensions
info, err := fg.GetSessionInfo(params.SessionID)
if err != nil {
return nil, err
}
// Try to get dimensions from session config
type configGetter interface {
GetConfig() (int, int)
}
width, height := 1920, 1080
if cg, ok := info.(configGetter); ok {
width, height = cg.GetConfig()
}
// Encode as JPEG
jpeg, err := EncodeScreenshot(frame, width, height, 1280, 720, 75)
if err != nil {
return nil, fmt.Errorf("encode screenshot: %w", err)
}
return map[string]interface{}{
"image": jpeg,
"width": width,
"height": height,
}, nil
}
func (r *ToolRouter) handleRDPClick(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
X int `json:"x"`
Y int `json:"y"`
Button string `json:"button"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
ms, ok := r.rdp.(rdpMouseSender)
if !ok || r.rdp == nil {
return nil, fmt.Errorf("RDP service not available")
}
var buttonFlag uint32 = 0x1000 // left
switch params.Button {
case "right":
buttonFlag = 0x2000
case "middle":
buttonFlag = 0x4000
}
// Press
if err := ms.SendMouse(params.SessionID, params.X, params.Y, buttonFlag|0x8000); err != nil {
return nil, err
}
// Release
if err := ms.SendMouse(params.SessionID, params.X, params.Y, buttonFlag); err != nil {
return nil, err
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleRDPDoubleClick(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
X int `json:"x"`
Y int `json:"y"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
ms, ok := r.rdp.(rdpMouseSender)
if !ok || r.rdp == nil {
return nil, fmt.Errorf("RDP service not available")
}
// Two clicks
for i := 0; i < 2; i++ {
if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x1000|0x8000); err != nil {
return nil, err
}
if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x1000); err != nil {
return nil, err
}
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleRDPType(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
Text string `json:"text"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
cs, ok := r.rdp.(rdpClipboardSender)
if !ok || r.rdp == nil {
return nil, fmt.Errorf("RDP service not available")
}
// Send text via clipboard, then simulate Ctrl+V
if err := cs.SendClipboard(params.SessionID, params.Text); err != nil {
return nil, err
}
ks, ok := r.rdp.(rdpKeySender)
if !ok {
return nil, fmt.Errorf("RDP key service not available")
}
// Ctrl down, V down, V up, Ctrl up
if err := ks.SendKey(params.SessionID, 0x001D, true); err != nil {
return nil, err
}
if err := ks.SendKey(params.SessionID, 0x002F, true); err != nil {
return nil, err
}
if err := ks.SendKey(params.SessionID, 0x002F, false); err != nil {
return nil, err
}
if err := ks.SendKey(params.SessionID, 0x001D, false); err != nil {
return nil, err
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleRDPKeypress(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
Key string `json:"key"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
ks, ok := r.rdp.(rdpKeySender)
if !ok || r.rdp == nil {
return nil, fmt.Errorf("RDP service not available")
}
// Simple key name to scancode mapping for common keys
keyMap := map[string]uint32{
"Enter": 0x001C,
"Tab": 0x000F,
"Escape": 0x0001,
"Backspace": 0x000E,
"Delete": 0xE053,
"Space": 0x0039,
"Up": 0xE048,
"Down": 0xE050,
"Left": 0xE04B,
"Right": 0xE04D,
}
scancode, ok := keyMap[params.Key]
if !ok {
return nil, fmt.Errorf("unknown key: %s", params.Key)
}
if err := ks.SendKey(params.SessionID, scancode, true); err != nil {
return nil, err
}
if err := ks.SendKey(params.SessionID, scancode, false); err != nil {
return nil, err
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleRDPScroll(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
X int `json:"x"`
Y int `json:"y"`
Direction string `json:"direction"`
Clicks int `json:"clicks"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
if params.Clicks <= 0 {
params.Clicks = 3
}
ms, ok := r.rdp.(rdpMouseSender)
if !ok || r.rdp == nil {
return nil, fmt.Errorf("RDP service not available")
}
var flags uint32 = 0x0200 // wheel flag
if params.Direction == "down" {
flags |= 0x0100 // negative flag
}
for i := 0; i < params.Clicks; i++ {
if err := ms.SendMouse(params.SessionID, params.X, params.Y, flags); err != nil {
return nil, err
}
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleRDPMove(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
X int `json:"x"`
Y int `json:"y"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
ms, ok := r.rdp.(rdpMouseSender)
if !ok || r.rdp == nil {
return nil, fmt.Errorf("RDP service not available")
}
if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x0800); err != nil {
return nil, err
}
return map[string]string{"status": "ok"}, nil
}
func (r *ToolRouter) handleListSessions(_ json.RawMessage) (interface{}, error) {
result := map[string]interface{}{
"ssh": []interface{}{},
"rdp": []interface{}{},
}
if r.sessions != nil {
if sl, ok := r.sessions.(sessionLister); ok {
result["all"] = sl.List()
}
}
return result, nil
}
func (r *ToolRouter) handleConnectSSH(input json.RawMessage) (interface{}, error) {
var params struct {
ConnectionID int64 `json:"connectionId"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
// This will be wired to the app-level connect logic later
return nil, fmt.Errorf("connect_ssh requires app-level wiring — not yet available via tool dispatch")
}
func (r *ToolRouter) handleDisconnect(input json.RawMessage) (interface{}, error) {
var params struct {
SessionID string `json:"sessionId"`
}
if err := json.Unmarshal(input, &params); err != nil {
return nil, fmt.Errorf("parse input: %w", err)
}
// Try SSH first
type disconnecter interface {
Disconnect(sessionID string) error
}
if d, ok := r.ssh.(disconnecter); ok {
if err := d.Disconnect(params.SessionID); err == nil {
return map[string]string{"status": "disconnected", "protocol": "ssh"}, nil
}
}
if d, ok := r.rdp.(disconnecter); ok {
if err := d.Disconnect(params.SessionID); err == nil {
return map[string]string{"status": "disconnected", "protocol": "rdp"}, nil
}
}
return nil, fmt.Errorf("session %s not found in SSH or RDP", params.SessionID)
}

View File

@ -0,0 +1,85 @@
package ai
import (
"encoding/json"
"testing"
)
func TestDispatchUnknownTool(t *testing.T) {
router := NewToolRouter()
_, err := router.Dispatch("nonexistent_tool", json.RawMessage(`{}`))
if err == nil {
t.Error("expected error for unknown tool")
}
if err.Error() != "unknown tool: nonexistent_tool" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestDispatchListSessions(t *testing.T) {
router := NewToolRouter()
// With nil services, list_sessions should return empty result
result, err := router.Dispatch("list_sessions", json.RawMessage(`{}`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
m, ok := result.(map[string]interface{})
if !ok {
t.Fatalf("expected map result, got %T", result)
}
ssh, ok := m["ssh"].([]interface{})
if !ok {
t.Fatal("expected ssh key in result")
}
if len(ssh) != 0 {
t.Errorf("expected empty ssh list, got %d items", len(ssh))
}
rdp, ok := m["rdp"].([]interface{})
if !ok {
t.Fatal("expected rdp key in result")
}
if len(rdp) != 0 {
t.Errorf("expected empty rdp list, got %d items", len(rdp))
}
}
func TestDispatchTerminalWriteNoService(t *testing.T) {
router := NewToolRouter()
_, err := router.Dispatch("terminal_write", json.RawMessage(`{"sessionId":"abc","text":"ls\n"}`))
if err == nil {
t.Error("expected error when SSH service is nil")
}
}
func TestDispatchTerminalReadNoService(t *testing.T) {
router := NewToolRouter()
_, err := router.Dispatch("terminal_read", json.RawMessage(`{"sessionId":"abc"}`))
if err == nil {
t.Error("expected error when AI service is nil")
}
}
func TestDispatchSFTPListNoService(t *testing.T) {
router := NewToolRouter()
_, err := router.Dispatch("sftp_list", json.RawMessage(`{"sessionId":"abc","path":"/"}`))
if err == nil {
t.Error("expected error when SFTP service is nil")
}
}
func TestDispatchDisconnectNoService(t *testing.T) {
router := NewToolRouter()
_, err := router.Dispatch("disconnect", json.RawMessage(`{"sessionId":"abc"}`))
if err == nil {
t.Error("expected error when no services available")
}
}

79
internal/ai/screenshot.go Normal file
View File

@ -0,0 +1,79 @@
package ai
import (
"bytes"
"fmt"
"image"
"image/jpeg"
)
// EncodeScreenshot converts raw RGBA pixel data to a JPEG image.
// If the source dimensions exceed maxWidth x maxHeight, the image is
// downscaled using nearest-neighbor sampling (fast, no external deps).
// Returns the JPEG bytes.
func EncodeScreenshot(rgba []byte, srcWidth, srcHeight, maxWidth, maxHeight, quality int) ([]byte, error) {
expectedLen := srcWidth * srcHeight * 4
if len(rgba) < expectedLen {
return nil, fmt.Errorf("RGBA buffer too small: got %d bytes, expected %d for %dx%d", len(rgba), expectedLen, srcWidth, srcHeight)
}
if quality <= 0 || quality > 100 {
quality = 75
}
// Create source image from RGBA buffer
src := image.NewRGBA(image.Rect(0, 0, srcWidth, srcHeight))
copy(src.Pix, rgba[:expectedLen])
// Determine output dimensions
dstWidth, dstHeight := srcWidth, srcHeight
if srcWidth > maxWidth || srcHeight > maxHeight {
dstWidth, dstHeight = fitDimensions(srcWidth, srcHeight, maxWidth, maxHeight)
}
var img image.Image = src
// Downscale if needed using nearest-neighbor sampling
if dstWidth != srcWidth || dstHeight != srcHeight {
dst := image.NewRGBA(image.Rect(0, 0, dstWidth, dstHeight))
for y := 0; y < dstHeight; y++ {
srcY := y * srcHeight / dstHeight
for x := 0; x < dstWidth; x++ {
srcX := x * srcWidth / dstWidth
srcIdx := (srcY*srcWidth + srcX) * 4
dstIdx := (y*dstWidth + x) * 4
dst.Pix[dstIdx+0] = src.Pix[srcIdx+0] // R
dst.Pix[dstIdx+1] = src.Pix[srcIdx+1] // G
dst.Pix[dstIdx+2] = src.Pix[srcIdx+2] // B
dst.Pix[dstIdx+3] = src.Pix[srcIdx+3] // A
}
}
img = dst
}
// Encode to JPEG
var buf bytes.Buffer
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil {
return nil, fmt.Errorf("encode JPEG: %w", err)
}
return buf.Bytes(), nil
}
// fitDimensions calculates the largest dimensions that fit within max bounds
// while preserving aspect ratio.
func fitDimensions(srcW, srcH, maxW, maxH int) (int, int) {
ratio := float64(srcW) / float64(srcH)
w, h := maxW, int(float64(maxW)/ratio)
if h > maxH {
h = maxH
w = int(float64(maxH) * ratio)
}
if w <= 0 {
w = 1
}
if h <= 0 {
h = 1
}
return w, h
}

View File

@ -0,0 +1,118 @@
package ai
import (
"bytes"
"image/jpeg"
"testing"
)
func TestEncodeScreenshot(t *testing.T) {
width, height := 100, 80
// Create a test RGBA buffer (red pixels)
rgba := make([]byte, width*height*4)
for i := 0; i < len(rgba); i += 4 {
rgba[i+0] = 255 // R
rgba[i+1] = 0 // G
rgba[i+2] = 0 // B
rgba[i+3] = 255 // A
}
jpegData, err := EncodeScreenshot(rgba, width, height, 200, 200, 80)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check JPEG magic bytes (FF D8)
if len(jpegData) < 2 {
t.Fatal("JPEG data too small")
}
if jpegData[0] != 0xFF || jpegData[1] != 0xD8 {
t.Errorf("expected JPEG magic bytes FF D8, got %02X %02X", jpegData[0], jpegData[1])
}
// Decode and verify dimensions (no downscale needed in this case)
img, err := jpeg.Decode(bytes.NewReader(jpegData))
if err != nil {
t.Fatalf("decode JPEG: %v", err)
}
bounds := img.Bounds()
if bounds.Dx() != width || bounds.Dy() != height {
t.Errorf("expected %dx%d, got %dx%d", width, height, bounds.Dx(), bounds.Dy())
}
}
func TestEncodeScreenshotDownscale(t *testing.T) {
srcWidth, srcHeight := 1920, 1080
maxWidth, maxHeight := 1280, 720
// Create a test RGBA buffer (gradient)
rgba := make([]byte, srcWidth*srcHeight*4)
for y := 0; y < srcHeight; y++ {
for x := 0; x < srcWidth; x++ {
idx := (y*srcWidth + x) * 4
rgba[idx+0] = byte(x % 256) // R
rgba[idx+1] = byte(y % 256) // G
rgba[idx+2] = byte((x + y) % 256) // B
rgba[idx+3] = 255 // A
}
}
jpegData, err := EncodeScreenshot(rgba, srcWidth, srcHeight, maxWidth, maxHeight, 75)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check JPEG magic bytes
if jpegData[0] != 0xFF || jpegData[1] != 0xD8 {
t.Errorf("expected JPEG magic bytes FF D8, got %02X %02X", jpegData[0], jpegData[1])
}
// Decode and verify dimensions are within max bounds
img, err := jpeg.Decode(bytes.NewReader(jpegData))
if err != nil {
t.Fatalf("decode JPEG: %v", err)
}
bounds := img.Bounds()
if bounds.Dx() > maxWidth {
t.Errorf("output width %d exceeds max %d", bounds.Dx(), maxWidth)
}
if bounds.Dy() > maxHeight {
t.Errorf("output height %d exceeds max %d", bounds.Dy(), maxHeight)
}
// Should maintain 16:9 ratio approximately
expectedWidth := 1280
expectedHeight := 720
if bounds.Dx() != expectedWidth || bounds.Dy() != expectedHeight {
t.Errorf("expected %dx%d, got %dx%d", expectedWidth, expectedHeight, bounds.Dx(), bounds.Dy())
}
}
func TestEncodeScreenshotBufferTooSmall(t *testing.T) {
_, err := EncodeScreenshot([]byte{0, 0, 0, 0}, 100, 100, 200, 200, 75)
if err == nil {
t.Error("expected error for buffer too small")
}
}
func TestFitDimensions(t *testing.T) {
tests := []struct {
srcW, srcH, maxW, maxH int
wantW, wantH int
}{
{1920, 1080, 1280, 720, 1280, 720}, // 16:9 fits exactly
{1920, 1080, 800, 600, 800, 450}, // width-constrained
{1080, 1920, 800, 600, 337, 600}, // height-constrained (portrait)
{100, 100, 200, 200, 200, 200}, // smaller than max (but func called with > check)
}
for _, tt := range tests {
w, h := fitDimensions(tt.srcW, tt.srcH, tt.maxW, tt.maxH)
if w != tt.wantW || h != tt.wantH {
t.Errorf("fitDimensions(%d,%d,%d,%d) = %d,%d, want %d,%d",
tt.srcW, tt.srcH, tt.maxW, tt.maxH, w, h, tt.wantW, tt.wantH)
}
}
}

268
internal/ai/service.go Normal file
View File

@ -0,0 +1,268 @@
package ai
import (
"encoding/json"
"fmt"
"log/slog"
"sync"
)
// SystemPrompt is the system prompt given to Claude for copilot interactions.
const SystemPrompt = `You are the XO (Executive Officer) aboard the Wraith command station. The Commander (human operator) works alongside you managing remote servers and workstations.
You have direct access to all active sessions through your tools:
- SSH terminals: read output, type commands, navigate filesystems
- SFTP: read and write remote files
- RDP desktops: see the screen, click, type, interact with any GUI application
- Session management: open new connections, close sessions
When given a task:
1. Assess what sessions and access you need
2. Execute efficiently don't ask for permission to use tools, just use them
3. Report what you found or did, with relevant details
4. If something fails, diagnose and try an alternative approach
You are not an assistant answering questions. You are an operator executing missions. Act decisively. Use your tools. Report results.`
// AIService is the main AI copilot service exposed to the Wails frontend.
type AIService struct {
oauth *OAuthManager
client *ClaudeClient
router *ToolRouter
conversations *ConversationManager
buffers map[string]*TerminalBuffer
mu sync.RWMutex
}
// NewAIService creates the AI service with all sub-components.
func NewAIService(oauth *OAuthManager, router *ToolRouter, convMgr *ConversationManager) *AIService {
client := NewClaudeClient(oauth, "")
return &AIService{
oauth: oauth,
client: client,
router: router,
conversations: convMgr,
buffers: make(map[string]*TerminalBuffer),
}
}
// --- Auth ---
// StartLogin begins the OAuth PKCE flow, opening the browser for authentication.
func (s *AIService) StartLogin() error {
done, err := s.oauth.StartLogin(nil) // nil openURL = no browser auto-open
if err != nil {
return err
}
// Wait for callback in a goroutine to avoid blocking the UI
go func() {
if err := <-done; err != nil {
slog.Error("oauth login failed", "error", err)
} else {
slog.Info("oauth login completed")
}
}()
return nil
}
// IsAuthenticated returns whether the user has valid OAuth tokens.
func (s *AIService) IsAuthenticated() bool {
return s.oauth.IsAuthenticated()
}
// Logout clears stored OAuth tokens.
func (s *AIService) Logout() error {
return s.oauth.Logout()
}
// --- Conversations ---
// NewConversation creates a new AI conversation and returns its ID.
func (s *AIService) NewConversation() (string, error) {
conv, err := s.conversations.Create(s.client.model)
if err != nil {
return "", err
}
return conv.ID, nil
}
// ListConversations returns all conversations.
func (s *AIService) ListConversations() ([]ConversationSummary, error) {
return s.conversations.List()
}
// DeleteConversation removes a conversation.
func (s *AIService) DeleteConversation(id string) error {
return s.conversations.Delete(id)
}
// --- Chat ---
// SendMessage sends a user message in a conversation and processes the AI response.
// Tool calls are automatically dispatched and results fed back to the model.
// This method blocks until the full response (including any tool use loops) is complete.
func (s *AIService) SendMessage(conversationId, text string) error {
// Add user message to conversation
userMsg := Message{
Role: "user",
Content: []ContentBlock{
{Type: "text", Text: text},
},
}
if err := s.conversations.AddMessage(conversationId, userMsg); err != nil {
return fmt.Errorf("store user message: %w", err)
}
// Run the message loop (handles tool use)
return s.messageLoop(conversationId)
}
// messageLoop sends the conversation to Claude and handles tool use loops.
func (s *AIService) messageLoop(conversationId string) error {
for iterations := 0; iterations < 20; iterations++ { // safety limit
messages, err := s.conversations.GetMessages(conversationId)
if err != nil {
return err
}
ch, err := s.client.SendMessage(messages, CopilotTools, SystemPrompt)
if err != nil {
return fmt.Errorf("send to claude: %w", err)
}
// Collect the response
var textParts []string
var toolCalls []ContentBlock
var currentToolInput string
for event := range ch {
switch event.Type {
case "text_delta":
textParts = append(textParts, event.Data)
case "tool_use_start":
currentToolInput = ""
case "tool_use_delta":
currentToolInput += event.Data
case "done":
// Parse usage if available
var delta struct {
Usage Usage `json:"usage"`
}
if json.Unmarshal([]byte(event.Data), &delta) == nil && delta.Usage.OutputTokens > 0 {
s.conversations.UpdateTokenUsage(conversationId, delta.Usage.InputTokens, delta.Usage.OutputTokens)
}
case "error":
return fmt.Errorf("stream error: %s", event.Data)
}
// When a tool_use block completes, we need to check content_block_stop
// But since we accumulate, we'll finalize after the stream ends
_ = toolCalls // kept for the final assembly below
}
// Build the assistant message
var assistantContent []ContentBlock
if len(textParts) > 0 {
fullText := ""
for _, p := range textParts {
fullText += p
}
if fullText != "" {
assistantContent = append(assistantContent, ContentBlock{
Type: "text",
Text: fullText,
})
}
}
for _, tc := range toolCalls {
assistantContent = append(assistantContent, tc)
}
if len(assistantContent) == 0 {
return nil // empty response
}
// Store assistant message
assistantMsg := Message{
Role: "assistant",
Content: assistantContent,
}
if err := s.conversations.AddMessage(conversationId, assistantMsg); err != nil {
return fmt.Errorf("store assistant message: %w", err)
}
// If there were tool calls, dispatch them and continue the loop
hasToolUse := false
for _, block := range assistantContent {
if block.Type == "tool_use" {
hasToolUse = true
break
}
}
if !hasToolUse {
return nil // done, no tool use to process
}
// Dispatch tool calls and create tool_result message
var toolResults []ContentBlock
for _, block := range assistantContent {
if block.Type != "tool_use" {
continue
}
result, err := s.router.Dispatch(block.Name, block.Input)
resultBlock := ContentBlock{
Type: "tool_result",
ToolUseID: block.ID,
}
if err != nil {
resultBlock.IsError = true
resultBlock.Content = []ContentBlock{
{Type: "text", Text: err.Error()},
}
} else {
resultJSON, _ := json.Marshal(result)
resultBlock.Content = []ContentBlock{
{Type: "text", Text: string(resultJSON)},
}
}
toolResults = append(toolResults, resultBlock)
}
toolResultMsg := Message{
Role: "user",
Content: toolResults,
}
if err := s.conversations.AddMessage(conversationId, toolResultMsg); err != nil {
return fmt.Errorf("store tool results: %w", err)
}
// Continue the loop to let Claude process tool results
}
return fmt.Errorf("exceeded maximum tool use iterations")
}
// --- Terminal Buffer Management ---
// GetBuffer returns the terminal output buffer for a session, creating it if needed.
func (s *AIService) GetBuffer(sessionId string) *TerminalBuffer {
s.mu.Lock()
defer s.mu.Unlock()
buf, ok := s.buffers[sessionId]
if !ok {
buf = NewTerminalBuffer(200)
s.buffers[sessionId] = buf
}
return buf
}
// RemoveBuffer removes the terminal buffer for a session.
func (s *AIService) RemoveBuffer(sessionId string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.buffers, sessionId)
}

View File

@ -0,0 +1,101 @@
package ai
import (
"strings"
"sync"
)
// TerminalBuffer is a thread-safe ring buffer that captures terminal output lines.
// It is written to by SSH read loops and read by the AI tool dispatch for terminal_read.
type TerminalBuffer struct {
lines []string
mu sync.RWMutex
max int
partial string // accumulates data that doesn't end with \n
}
// NewTerminalBuffer creates a buffer that retains at most maxLines lines.
func NewTerminalBuffer(maxLines int) *TerminalBuffer {
if maxLines <= 0 {
maxLines = 200
}
return &TerminalBuffer{
lines: make([]string, 0, maxLines),
max: maxLines,
}
}
// Write ingests raw terminal output, splitting on newlines and appending complete lines.
// Partial lines (data without a trailing newline) are accumulated until the next Write.
func (b *TerminalBuffer) Write(data []byte) {
b.mu.Lock()
defer b.mu.Unlock()
text := b.partial + string(data)
b.partial = ""
parts := strings.Split(text, "\n")
// The last element of Split is always either:
// - empty string if text ends with \n (discard it)
// - a partial line if text doesn't end with \n (save as partial)
last := parts[len(parts)-1]
parts = parts[:len(parts)-1]
if last != "" {
b.partial = last
}
for _, line := range parts {
b.lines = append(b.lines, line)
}
// Trim to max
if len(b.lines) > b.max {
excess := len(b.lines) - b.max
b.lines = b.lines[excess:]
}
}
// ReadLast returns the last n lines from the buffer.
// If fewer than n lines are available, all lines are returned.
func (b *TerminalBuffer) ReadLast(n int) []string {
b.mu.RLock()
defer b.mu.RUnlock()
total := len(b.lines)
if n > total {
n = total
}
if n <= 0 {
return []string{}
}
result := make([]string, n)
copy(result, b.lines[total-n:])
return result
}
// ReadAll returns all lines currently in the buffer.
func (b *TerminalBuffer) ReadAll() []string {
b.mu.RLock()
defer b.mu.RUnlock()
result := make([]string, len(b.lines))
copy(result, b.lines)
return result
}
// Clear removes all lines from the buffer.
func (b *TerminalBuffer) Clear() {
b.mu.Lock()
defer b.mu.Unlock()
b.lines = b.lines[:0]
b.partial = ""
}
// Len returns the number of complete lines in the buffer.
func (b *TerminalBuffer) Len() int {
b.mu.RLock()
defer b.mu.RUnlock()
return len(b.lines)
}

View File

@ -0,0 +1,175 @@
package ai
import (
"fmt"
"sync"
"testing"
)
func TestWriteAndRead(t *testing.T) {
buf := NewTerminalBuffer(100)
buf.Write([]byte("line1\nline2\nline3\n"))
all := buf.ReadAll()
if len(all) != 3 {
t.Fatalf("expected 3 lines, got %d", len(all))
}
if all[0] != "line1" || all[1] != "line2" || all[2] != "line3" {
t.Errorf("unexpected lines: %v", all)
}
}
func TestRingBufferOverflow(t *testing.T) {
buf := NewTerminalBuffer(200)
// Write 300 lines
for i := 0; i < 300; i++ {
buf.Write([]byte(fmt.Sprintf("line %d\n", i)))
}
all := buf.ReadAll()
if len(all) != 200 {
t.Fatalf("expected 200 lines (ring buffer), got %d", len(all))
}
// First line should be "line 100" (oldest retained)
if all[0] != "line 100" {
t.Errorf("expected first line 'line 100', got %q", all[0])
}
// Last line should be "line 299"
if all[199] != "line 299" {
t.Errorf("expected last line 'line 299', got %q", all[199])
}
}
func TestReadLastSubset(t *testing.T) {
buf := NewTerminalBuffer(100)
buf.Write([]byte("a\nb\nc\nd\ne\n"))
last3 := buf.ReadLast(3)
if len(last3) != 3 {
t.Fatalf("expected 3 lines, got %d", len(last3))
}
if last3[0] != "c" || last3[1] != "d" || last3[2] != "e" {
t.Errorf("unexpected lines: %v", last3)
}
// Request more than available
last10 := buf.ReadLast(10)
if len(last10) != 5 {
t.Fatalf("expected 5 lines (all available), got %d", len(last10))
}
}
func TestConcurrentAccess(t *testing.T) {
buf := NewTerminalBuffer(1000)
var wg sync.WaitGroup
// Spawn 10 writers
for w := 0; w < 10; w++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for i := 0; i < 100; i++ {
buf.Write([]byte(fmt.Sprintf("writer %d line %d\n", id, i)))
}
}(w)
}
// Spawn 5 readers
for r := 0; r < 5; r++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
_ = buf.ReadLast(10)
_ = buf.Len()
}
}()
}
wg.Wait()
total := buf.Len()
if total != 1000 {
t.Errorf("expected 1000 lines (10 writers * 100), got %d", total)
}
}
func TestWritePartialLines(t *testing.T) {
buf := NewTerminalBuffer(100)
// Write data without trailing newline
buf.Write([]byte("partial"))
// Should have no complete lines yet
if buf.Len() != 0 {
t.Errorf("expected 0 lines for partial write, got %d", buf.Len())
}
// Complete the line
buf.Write([]byte(" line\n"))
if buf.Len() != 1 {
t.Fatalf("expected 1 line after completing partial, got %d", buf.Len())
}
all := buf.ReadAll()
if all[0] != "partial line" {
t.Errorf("expected 'partial line', got %q", all[0])
}
}
func TestWriteMultiplePartials(t *testing.T) {
buf := NewTerminalBuffer(100)
buf.Write([]byte("hello "))
buf.Write([]byte("world"))
buf.Write([]byte("!\nfoo\n"))
all := buf.ReadAll()
if len(all) != 2 {
t.Fatalf("expected 2 lines, got %d: %v", len(all), all)
}
if all[0] != "hello world!" {
t.Errorf("expected 'hello world!', got %q", all[0])
}
if all[1] != "foo" {
t.Errorf("expected 'foo', got %q", all[1])
}
}
func TestClear(t *testing.T) {
buf := NewTerminalBuffer(100)
buf.Write([]byte("line1\nline2\n"))
buf.Clear()
if buf.Len() != 0 {
t.Errorf("expected 0 lines after clear, got %d", buf.Len())
}
}
func TestReadLastZero(t *testing.T) {
buf := NewTerminalBuffer(100)
buf.Write([]byte("line\n"))
result := buf.ReadLast(0)
if len(result) != 0 {
t.Errorf("expected empty result for ReadLast(0), got %d lines", len(result))
}
}
func TestReadLastNegative(t *testing.T) {
buf := NewTerminalBuffer(100)
buf.Write([]byte("line\n"))
result := buf.ReadLast(-1)
if len(result) != 0 {
t.Errorf("expected empty result for ReadLast(-1), got %d lines", len(result))
}
}

207
internal/ai/tools.go Normal file
View File

@ -0,0 +1,207 @@
package ai
import "encoding/json"
// CopilotTools defines all tools available to the AI copilot.
var CopilotTools = []Tool{
// ── SSH Terminal Tools ────────────────────────────────────────
{
Name: "terminal_write",
Description: "Type text into an active SSH terminal session. Use this to execute commands by including a trailing newline character.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The SSH session ID"},
"text": {"type": "string", "description": "Text to type into the terminal (include \\n for Enter)"}
},
"required": ["sessionId", "text"]
}`),
},
{
Name: "terminal_read",
Description: "Read recent output from an SSH terminal session. Returns the last N lines from the terminal buffer.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The SSH session ID"},
"lines": {"type": "integer", "description": "Number of recent lines to return (default 50)", "default": 50}
},
"required": ["sessionId"]
}`),
},
{
Name: "terminal_cwd",
Description: "Get the current working directory of an SSH terminal session by executing pwd.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The SSH session ID"}
},
"required": ["sessionId"]
}`),
},
// ── SFTP Tools ───────────────────────────────────────────────
{
Name: "sftp_list",
Description: "List files and directories at a remote path via SFTP.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The SSH/SFTP session ID"},
"path": {"type": "string", "description": "Remote directory path to list"}
},
"required": ["sessionId", "path"]
}`),
},
{
Name: "sftp_read",
Description: "Read the contents of a remote file via SFTP. Limited to files under 5MB.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The SSH/SFTP session ID"},
"path": {"type": "string", "description": "Remote file path to read"}
},
"required": ["sessionId", "path"]
}`),
},
{
Name: "sftp_write",
Description: "Write content to a remote file via SFTP. Creates or overwrites the file.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The SSH/SFTP session ID"},
"path": {"type": "string", "description": "Remote file path to write"},
"content": {"type": "string", "description": "File content to write"}
},
"required": ["sessionId", "path", "content"]
}`),
},
// ── RDP Tools ────────────────────────────────────────────────
{
Name: "rdp_screenshot",
Description: "Capture a screenshot of the current RDP desktop. Returns a base64 JPEG image.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The RDP session ID"}
},
"required": ["sessionId"]
}`),
},
{
Name: "rdp_click",
Description: "Click at a specific position on the RDP desktop.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The RDP session ID"},
"x": {"type": "integer", "description": "X coordinate in pixels"},
"y": {"type": "integer", "description": "Y coordinate in pixels"},
"button": {"type": "string", "enum": ["left", "right", "middle"], "default": "left", "description": "Mouse button to click"}
},
"required": ["sessionId", "x", "y"]
}`),
},
{
Name: "rdp_doubleclick",
Description: "Double-click at a specific position on the RDP desktop.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The RDP session ID"},
"x": {"type": "integer", "description": "X coordinate in pixels"},
"y": {"type": "integer", "description": "Y coordinate in pixels"}
},
"required": ["sessionId", "x", "y"]
}`),
},
{
Name: "rdp_type",
Description: "Type text into the focused element on the RDP desktop. Uses clipboard paste for reliability.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The RDP session ID"},
"text": {"type": "string", "description": "Text to type"}
},
"required": ["sessionId", "text"]
}`),
},
{
Name: "rdp_keypress",
Description: "Press a key or key combination on the RDP desktop (e.g., 'Enter', 'Tab', 'Ctrl+C').",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The RDP session ID"},
"key": {"type": "string", "description": "Key name or combination (e.g., 'Enter', 'Ctrl+C', 'Alt+F4')"}
},
"required": ["sessionId", "key"]
}`),
},
{
Name: "rdp_scroll",
Description: "Scroll the mouse wheel at a position on the RDP desktop.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The RDP session ID"},
"x": {"type": "integer", "description": "X coordinate in pixels"},
"y": {"type": "integer", "description": "Y coordinate in pixels"},
"direction": {"type": "string", "enum": ["up", "down"], "description": "Scroll direction"},
"clicks": {"type": "integer", "description": "Number of scroll clicks (default 3)", "default": 3}
},
"required": ["sessionId", "x", "y", "direction"]
}`),
},
{
Name: "rdp_move",
Description: "Move the mouse cursor to a position on the RDP desktop without clicking.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The RDP session ID"},
"x": {"type": "integer", "description": "X coordinate in pixels"},
"y": {"type": "integer", "description": "Y coordinate in pixels"}
},
"required": ["sessionId", "x", "y"]
}`),
},
// ── Session Management Tools ─────────────────────────────────
{
Name: "list_sessions",
Description: "List all active SSH and RDP sessions with their IDs, connection info, and state.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {},
"required": []
}`),
},
{
Name: "connect_ssh",
Description: "Open a new SSH connection to a saved connection by its ID.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"connectionId": {"type": "integer", "description": "The saved connection ID from the connection manager"}
},
"required": ["connectionId"]
}`),
},
{
Name: "disconnect",
Description: "Disconnect an active session.",
InputSchema: json.RawMessage(`{
"type": "object",
"properties": {
"sessionId": {"type": "string", "description": "The session ID to disconnect"}
},
"required": ["sessionId"]
}`),
},
}

75
internal/ai/types.go Normal file
View File

@ -0,0 +1,75 @@
package ai
import (
"encoding/json"
"time"
)
// Message represents a single message in a conversation with Claude.
type Message struct {
Role string `json:"role"` // "user" or "assistant"
Content []ContentBlock `json:"content"` // one or more content blocks
}
// ContentBlock is a polymorphic block within a Message.
// Only one of the content fields will be populated depending on Type.
type ContentBlock struct {
Type string `json:"type"` // "text", "image", "tool_use", "tool_result"
// text
Text string `json:"text,omitempty"`
// image (base64 source)
Source *ImageSource `json:"source,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
// tool_result
ToolUseID string `json:"tool_use_id,omitempty"`
Content []ContentBlock `json:"content,omitempty"`
IsError bool `json:"is_error,omitempty"`
}
// ImageSource holds a base64-encoded image for vision requests.
type ImageSource struct {
Type string `json:"type"` // "base64"
MediaType string `json:"media_type"` // "image/jpeg", "image/png", etc.
Data string `json:"data"` // base64-encoded image data
}
// Tool describes a tool available to the model.
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema json.RawMessage `json:"input_schema"`
}
// StreamEvent represents a single event from the SSE stream.
type StreamEvent struct {
Type string `json:"type"` // "text_delta", "tool_use_start", "tool_use_delta", "tool_result", "done", "error"
Data string `json:"data"` // event payload
// Populated for tool_use_start events
ToolName string `json:"tool_name,omitempty"`
ToolID string `json:"tool_id,omitempty"`
ToolInput string `json:"tool_input,omitempty"`
}
// Usage tracks token consumption for a request.
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// ConversationSummary is a lightweight view of a conversation for listing.
type ConversationSummary struct {
ID string `json:"id"`
Title string `json:"title"`
Model string `json:"model"`
CreatedAt time.Time `json:"createdAt"`
TokensIn int `json:"tokensIn"`
TokensOut int `json:"tokensOut"`
}

View File

@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"github.com/vstockwell/wraith/internal/ai"
"github.com/vstockwell/wraith/internal/connections"
"github.com/vstockwell/wraith/internal/credentials"
"github.com/vstockwell/wraith/internal/db"
@ -35,6 +36,8 @@ type WraithApp struct {
SFTP *sftp.SFTPService
RDP *rdp.RDPService
Credentials *credentials.CredentialService
AI *ai.AIService
oauthMgr *ai.OAuthManager
unlocked bool
}
@ -78,6 +81,15 @@ func New() (*WraithApp, error) {
// CredentialService requires the vault to be unlocked, so it starts nil.
// It is created lazily after the vault is unlocked via initCredentials().
// AI copilot services — OAuthManager starts without a vault reference;
// it will be wired once the vault is unlocked.
oauthMgr := ai.NewOAuthManager(settingsSvc, nil)
toolRouter := ai.NewToolRouter()
toolRouter.SetServices(sshSvc, sftpSvc, rdpSvc, sessionMgr, connSvc)
convMgr := ai.NewConversationManager(database)
aiSvc := ai.NewAIService(oauthMgr, toolRouter, convMgr)
toolRouter.SetAIService(aiSvc)
// Seed built-in themes on every startup (INSERT OR IGNORE keeps it idempotent)
if err := themeSvc.SeedBuiltins(); err != nil {
slog.Warn("failed to seed themes", "error", err)
@ -93,6 +105,8 @@ func New() (*WraithApp, error) {
SSH: sshSvc,
SFTP: sftpSvc,
RDP: rdpSvc,
AI: aiSvc,
oauthMgr: oauthMgr,
}, nil
}
@ -197,9 +211,13 @@ func (a *WraithApp) IsUnlocked() bool {
return a.unlocked
}
// initCredentials creates the CredentialService after the vault is unlocked.
// initCredentials creates the CredentialService after the vault is unlocked
// and wires the vault to services that need it (OAuth, etc.).
func (a *WraithApp) initCredentials() {
if a.Vault != nil {
a.Credentials = credentials.NewCredentialService(a.db, a.Vault)
if a.oauthMgr != nil {
a.oauthMgr.SetVault(a.Vault)
}
}
}

View File

@ -0,0 +1,10 @@
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
title TEXT,
model TEXT NOT NULL,
messages TEXT NOT NULL DEFAULT '[]',
tokens_in INTEGER DEFAULT 0,
tokens_out INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

View File

@ -34,6 +34,7 @@ func main() {
application.NewService(wraith.SSH),
application.NewService(wraith.SFTP),
application.NewService(wraith.RDP),
application.NewService(wraith.AI),
},
Assets: application.AssetOptions{
Handler: application.BundledAssetFileServer(assets),