feat: AI copilot backend — OAuth PKCE, Claude API streaming, 16 tools, conversations
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Has been cancelled
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Has been cancelled
- OAuth PKCE flow for Max subscription auth (no API key needed) - Claude API client with SSE streaming (Messages API v1) - 16 tool definitions: terminal, SFTP, RDP, session management - Tool dispatch router mapping to existing Wraith services - Conversation manager with SQLite persistence - Terminal output ring buffer for AI context - RDP screenshot encoder (RGBA → JPEG with downscaling) - Wired into Wails app as AIService Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
be868e8172
commit
7ee5321d69
256
internal/ai/client.go
Normal file
256
internal/ai/client.go
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultModel = "claude-sonnet-4-20250514"
|
||||||
|
anthropicVersion = "2023-06-01"
|
||||||
|
apiBaseURL = "https://api.anthropic.com/v1/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClaudeClient is the HTTP client for the Anthropic Messages API.
|
||||||
|
type ClaudeClient struct {
|
||||||
|
oauth *OAuthManager
|
||||||
|
apiKey string // fallback for non-OAuth users
|
||||||
|
model string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClaudeClient creates a client that authenticates via OAuth (primary) or API key (fallback).
|
||||||
|
func NewClaudeClient(oauth *OAuthManager, model string) *ClaudeClient {
|
||||||
|
if model == "" {
|
||||||
|
model = defaultModel
|
||||||
|
}
|
||||||
|
return &ClaudeClient{
|
||||||
|
oauth: oauth,
|
||||||
|
model: model,
|
||||||
|
httpClient: &http.Client{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAPIKey sets a fallback API key for users without Max subscriptions.
|
||||||
|
func (c *ClaudeClient) SetAPIKey(key string) {
|
||||||
|
c.apiKey = key
|
||||||
|
}
|
||||||
|
|
||||||
|
// apiRequest is the JSON body sent to the Messages API.
|
||||||
|
type apiRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage sends a request to the Messages API with streaming enabled.
|
||||||
|
// It returns a channel of StreamEvents that the caller reads until the channel is closed.
|
||||||
|
func (c *ClaudeClient) SendMessage(messages []Message, tools []Tool, systemPrompt string) (<-chan StreamEvent, error) {
|
||||||
|
reqBody := apiRequest{
|
||||||
|
Model: c.model,
|
||||||
|
Messages: messages,
|
||||||
|
MaxTokens: 8192,
|
||||||
|
Stream: true,
|
||||||
|
System: systemPrompt,
|
||||||
|
Tools: tools,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", apiBaseURL, bytes.NewReader(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("anthropic-version", anthropicVersion)
|
||||||
|
|
||||||
|
if err := c.setAuthHeader(req); err != nil {
|
||||||
|
return nil, fmt.Errorf("set auth header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
return nil, fmt.Errorf("API returned %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan StreamEvent, 64)
|
||||||
|
go c.parseSSEStream(resp.Body, ch)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setAuthHeader sets the appropriate authorization header on the request.
|
||||||
|
func (c *ClaudeClient) setAuthHeader(req *http.Request) error {
|
||||||
|
// Try OAuth first
|
||||||
|
if c.oauth != nil && c.oauth.IsAuthenticated() {
|
||||||
|
token, err := c.oauth.GetAccessToken()
|
||||||
|
if err == nil {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
slog.Warn("oauth token failed, falling back to api key", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to API key
|
||||||
|
if c.apiKey != "" {
|
||||||
|
req.Header.Set("x-api-key", c.apiKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("no authentication method available — log in via OAuth or set an API key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSSEStream reads the SSE response body and sends StreamEvents on the channel.
|
||||||
|
func (c *ClaudeClient) parseSSEStream(body io.ReadCloser, ch chan<- StreamEvent) {
|
||||||
|
defer body.Close()
|
||||||
|
defer close(ch)
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(body)
|
||||||
|
|
||||||
|
var currentEventType string
|
||||||
|
var currentToolID string
|
||||||
|
var currentToolName string
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "event: ") {
|
||||||
|
currentEventType = strings.TrimPrefix(line, "event: ")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data := strings.TrimPrefix(line, "data: ")
|
||||||
|
|
||||||
|
switch currentEventType {
|
||||||
|
case "content_block_start":
|
||||||
|
var block struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
ContentBlock struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
Input string `json:"input"`
|
||||||
|
} `json:"content_block"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(data), &block); err != nil {
|
||||||
|
slog.Warn("failed to parse content_block_start", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if block.ContentBlock.Type == "tool_use" {
|
||||||
|
currentToolID = block.ContentBlock.ID
|
||||||
|
currentToolName = block.ContentBlock.Name
|
||||||
|
ch <- StreamEvent{
|
||||||
|
Type: "tool_use_start",
|
||||||
|
ToolID: currentToolID,
|
||||||
|
ToolName: currentToolName,
|
||||||
|
}
|
||||||
|
} else if block.ContentBlock.Type == "text" && block.ContentBlock.Text != "" {
|
||||||
|
ch <- StreamEvent{Type: "text_delta", Data: block.ContentBlock.Text}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "content_block_delta":
|
||||||
|
var delta struct {
|
||||||
|
Delta struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
PartialJSON string `json:"partial_json"`
|
||||||
|
} `json:"delta"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(data), &delta); err != nil {
|
||||||
|
slog.Warn("failed to parse content_block_delta", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if delta.Delta.Type == "text_delta" {
|
||||||
|
ch <- StreamEvent{Type: "text_delta", Data: delta.Delta.Text}
|
||||||
|
} else if delta.Delta.Type == "input_json_delta" {
|
||||||
|
ch <- StreamEvent{
|
||||||
|
Type: "tool_use_delta",
|
||||||
|
Data: delta.Delta.PartialJSON,
|
||||||
|
ToolID: currentToolID,
|
||||||
|
ToolName: currentToolName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "content_block_stop":
|
||||||
|
// nothing to do; tool input is accumulated by the caller
|
||||||
|
|
||||||
|
case "message_delta":
|
||||||
|
// contains stop_reason and usage; emit usage info
|
||||||
|
ch <- StreamEvent{Type: "done", Data: data}
|
||||||
|
|
||||||
|
case "message_stop":
|
||||||
|
// end of message stream
|
||||||
|
return
|
||||||
|
|
||||||
|
case "message_start":
|
||||||
|
// contains message metadata; could extract usage but we handle it at message_delta
|
||||||
|
|
||||||
|
case "ping":
|
||||||
|
// heartbeat; ignore
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
ch <- StreamEvent{Type: "error", Data: data}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
ch <- StreamEvent{Type: "error", Data: err.Error()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestBody creates the JSON request body for testing purposes.
|
||||||
|
func BuildRequestBody(messages []Message, tools []Tool, systemPrompt, model string) ([]byte, error) {
|
||||||
|
if model == "" {
|
||||||
|
model = defaultModel
|
||||||
|
}
|
||||||
|
req := apiRequest{
|
||||||
|
Model: model,
|
||||||
|
Messages: messages,
|
||||||
|
MaxTokens: 8192,
|
||||||
|
Stream: true,
|
||||||
|
System: systemPrompt,
|
||||||
|
Tools: tools,
|
||||||
|
}
|
||||||
|
return json.Marshal(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseSSELine parses a single SSE data line. Returns the event type and data payload.
|
||||||
|
func ParseSSELine(line string) (eventType, data string) {
|
||||||
|
if strings.HasPrefix(line, "event: ") {
|
||||||
|
return "event", strings.TrimPrefix(line, "event: ")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(line, "data: ") {
|
||||||
|
return "data", strings.TrimPrefix(line, "data: ")
|
||||||
|
}
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
123
internal/ai/client_test.go
Normal file
123
internal/ai/client_test.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildRequestBody(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{Type: "text", Text: "Hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tools := []Tool{
|
||||||
|
{
|
||||||
|
Name: "test_tool",
|
||||||
|
Description: "A test tool",
|
||||||
|
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := BuildRequestBody(messages, tools, "You are helpful.", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||||
|
t.Fatalf("invalid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check model
|
||||||
|
if m, ok := parsed["model"].(string); !ok || m != defaultModel {
|
||||||
|
t.Errorf("expected model %s, got %v", defaultModel, parsed["model"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check stream
|
||||||
|
if s, ok := parsed["stream"].(bool); !ok || !s {
|
||||||
|
t.Errorf("expected stream true, got %v", parsed["stream"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check system prompt
|
||||||
|
if s, ok := parsed["system"].(string); !ok || s != "You are helpful." {
|
||||||
|
t.Errorf("expected system prompt, got %v", parsed["system"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check max_tokens
|
||||||
|
if mt, ok := parsed["max_tokens"].(float64); !ok || int(mt) != 8192 {
|
||||||
|
t.Errorf("expected max_tokens 8192, got %v", parsed["max_tokens"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check messages array exists
|
||||||
|
msgs, ok := parsed["messages"].([]interface{})
|
||||||
|
if !ok || len(msgs) != 1 {
|
||||||
|
t.Errorf("expected 1 message, got %v", parsed["messages"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check tools array exists
|
||||||
|
tls, ok := parsed["tools"].([]interface{})
|
||||||
|
if !ok || len(tls) != 1 {
|
||||||
|
t.Errorf("expected 1 tool, got %v", parsed["tools"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSSELine(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
wantType string
|
||||||
|
wantData string
|
||||||
|
}{
|
||||||
|
{"event: content_block_delta", "event", "content_block_delta"},
|
||||||
|
{"data: {\"delta\":{\"text\":\"hello\"}}", "data", "{\"delta\":{\"text\":\"hello\"}}"},
|
||||||
|
{"", "", ""},
|
||||||
|
{"random line", "", ""},
|
||||||
|
{"event: message_stop", "event", "message_stop"},
|
||||||
|
{"data: [DONE]", "data", "[DONE]"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
gotType, gotData := ParseSSELine(tt.input)
|
||||||
|
if gotType != tt.wantType {
|
||||||
|
t.Errorf("ParseSSELine(%q) type = %q, want %q", tt.input, gotType, tt.wantType)
|
||||||
|
}
|
||||||
|
if gotData != tt.wantData {
|
||||||
|
t.Errorf("ParseSSELine(%q) data = %q, want %q", tt.input, gotData, tt.wantData)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHeader(t *testing.T) {
|
||||||
|
// Test with API key only (no OAuth)
|
||||||
|
client := NewClaudeClient(nil, "")
|
||||||
|
client.SetAPIKey("sk-test-12345")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||||
|
if err := client.setAuthHeader(req); err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := req.Header.Get("x-api-key"); got != "sk-test-12345" {
|
||||||
|
t.Errorf("expected x-api-key sk-test-12345, got %q", got)
|
||||||
|
}
|
||||||
|
if got := req.Header.Get("Authorization"); got != "" {
|
||||||
|
t.Errorf("expected no Authorization header, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHeaderNoAuth(t *testing.T) {
|
||||||
|
// No OAuth, no API key — should return error
|
||||||
|
client := NewClaudeClient(nil, "")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", "https://example.com", nil)
|
||||||
|
err := client.setAuthHeader(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when no auth method available")
|
||||||
|
}
|
||||||
|
}
|
||||||
169
internal/ai/conversation.go
Normal file
169
internal/ai/conversation.go
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationManager handles CRUD operations for AI conversations stored in SQLite.
|
||||||
|
type ConversationManager struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConversationManager creates a manager backed by the given database.
|
||||||
|
func NewConversationManager(db *sql.DB) *ConversationManager {
|
||||||
|
return &ConversationManager{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create starts a new conversation and returns its summary.
|
||||||
|
func (m *ConversationManager) Create(model string) (*ConversationSummary, error) {
|
||||||
|
id := uuid.NewString()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
_, err := m.db.Exec(
|
||||||
|
`INSERT INTO conversations (id, title, model, messages, tokens_in, tokens_out, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, '[]', 0, 0, ?, ?)`,
|
||||||
|
id, "New conversation", model, now, now,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ConversationSummary{
|
||||||
|
ID: id,
|
||||||
|
Title: "New conversation",
|
||||||
|
Model: model,
|
||||||
|
CreatedAt: now,
|
||||||
|
TokensIn: 0,
|
||||||
|
TokensOut: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMessage appends a message to the conversation's message list.
|
||||||
|
func (m *ConversationManager) AddMessage(convId string, msg Message) error {
|
||||||
|
// Get existing messages
|
||||||
|
messages, err := m.GetMessages(convId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, msg)
|
||||||
|
data, err := json.Marshal(messages)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = m.db.Exec(
|
||||||
|
"UPDATE conversations SET messages = ?, updated_at = ? WHERE id = ?",
|
||||||
|
string(data), time.Now(), convId,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-title from first user message
|
||||||
|
if len(messages) == 1 && msg.Role == "user" {
|
||||||
|
title := extractTitle(msg)
|
||||||
|
if title != "" {
|
||||||
|
m.db.Exec("UPDATE conversations SET title = ? WHERE id = ?", title, convId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages returns all messages in a conversation.
|
||||||
|
func (m *ConversationManager) GetMessages(convId string) ([]Message, error) {
|
||||||
|
var messagesJSON string
|
||||||
|
err := m.db.QueryRow("SELECT messages FROM conversations WHERE id = ?", convId).Scan(&messagesJSON)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("conversation %s not found", convId)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages []Message
|
||||||
|
if err := json.Unmarshal([]byte(messagesJSON), &messages); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal messages: %w", err)
|
||||||
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all conversations ordered by most recent.
|
||||||
|
func (m *ConversationManager) List() ([]ConversationSummary, error) {
|
||||||
|
rows, err := m.db.Query(
|
||||||
|
`SELECT id, title, model, tokens_in, tokens_out, created_at
|
||||||
|
FROM conversations ORDER BY updated_at DESC`,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list conversations: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var summaries []ConversationSummary
|
||||||
|
for rows.Next() {
|
||||||
|
var s ConversationSummary
|
||||||
|
var title sql.NullString
|
||||||
|
if err := rows.Scan(&s.ID, &title, &s.Model, &s.TokensIn, &s.TokensOut, &s.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan conversation: %w", err)
|
||||||
|
}
|
||||||
|
if title.Valid {
|
||||||
|
s.Title = title.String
|
||||||
|
}
|
||||||
|
summaries = append(summaries, s)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate conversations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if summaries == nil {
|
||||||
|
summaries = []ConversationSummary{}
|
||||||
|
}
|
||||||
|
return summaries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a conversation and all its messages.
|
||||||
|
func (m *ConversationManager) Delete(convId string) error {
|
||||||
|
result, err := m.db.Exec("DELETE FROM conversations WHERE id = ?", convId)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete conversation: %w", err)
|
||||||
|
}
|
||||||
|
affected, _ := result.RowsAffected()
|
||||||
|
if affected == 0 {
|
||||||
|
return fmt.Errorf("conversation %s not found", convId)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTokenUsage adds to the token counters for a conversation.
|
||||||
|
func (m *ConversationManager) UpdateTokenUsage(convId string, tokensIn, tokensOut int) error {
|
||||||
|
_, err := m.db.Exec(
|
||||||
|
`UPDATE conversations
|
||||||
|
SET tokens_in = tokens_in + ?, tokens_out = tokens_out + ?, updated_at = ?
|
||||||
|
WHERE id = ?`,
|
||||||
|
tokensIn, tokensOut, time.Now(), convId,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update token usage: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTitle generates a title from the first user message (truncated to 80 chars).
|
||||||
|
func extractTitle(msg Message) string {
|
||||||
|
for _, block := range msg.Content {
|
||||||
|
if block.Type == "text" && block.Text != "" {
|
||||||
|
title := block.Text
|
||||||
|
if len(title) > 80 {
|
||||||
|
title = title[:77] + "..."
|
||||||
|
}
|
||||||
|
return title
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
220
internal/ai/conversation_test.go
Normal file
220
internal/ai/conversation_test.go
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupConversationManager(t *testing.T) *ConversationManager {
|
||||||
|
t.Helper()
|
||||||
|
database, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open db: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(database); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
// Create the conversations table (002 migration)
|
||||||
|
_, err = database.Exec(`CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
id TEXT PRIMARY KEY, title TEXT, model TEXT NOT NULL,
|
||||||
|
messages TEXT NOT NULL DEFAULT '[]',
|
||||||
|
tokens_in INTEGER DEFAULT 0, tokens_out INTEGER DEFAULT 0,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP)`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create conversations table: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { database.Close() })
|
||||||
|
return NewConversationManager(database)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateConversation(t *testing.T) {
|
||||||
|
mgr := setupConversationManager(t)
|
||||||
|
|
||||||
|
conv, err := mgr.Create("claude-sonnet-4-20250514")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conv.ID == "" {
|
||||||
|
t.Error("expected non-empty ID")
|
||||||
|
}
|
||||||
|
if conv.Model != "claude-sonnet-4-20250514" {
|
||||||
|
t.Errorf("expected model claude-sonnet-4-20250514, got %s", conv.Model)
|
||||||
|
}
|
||||||
|
if conv.Title != "New conversation" {
|
||||||
|
t.Errorf("expected title 'New conversation', got %s", conv.Title)
|
||||||
|
}
|
||||||
|
if conv.TokensIn != 0 || conv.TokensOut != 0 {
|
||||||
|
t.Errorf("expected zero tokens, got in=%d out=%d", conv.TokensIn, conv.TokensOut)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddAndGetMessages(t *testing.T) {
|
||||||
|
mgr := setupConversationManager(t)
|
||||||
|
|
||||||
|
conv, err := mgr.Create("test-model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a user message
|
||||||
|
userMsg := Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{Type: "text", Text: "What is running on port 8080?"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := mgr.AddMessage(conv.ID, userMsg); err != nil {
|
||||||
|
t.Fatalf("add user message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add an assistant message
|
||||||
|
assistantMsg := Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{Type: "text", Text: "Let me check that for you."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := mgr.AddMessage(conv.ID, assistantMsg); err != nil {
|
||||||
|
t.Fatalf("add assistant message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get messages
|
||||||
|
messages, err := mgr.GetMessages(conv.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get messages: %v", err)
|
||||||
|
}
|
||||||
|
if len(messages) != 2 {
|
||||||
|
t.Fatalf("expected 2 messages, got %d", len(messages))
|
||||||
|
}
|
||||||
|
if messages[0].Role != "user" {
|
||||||
|
t.Errorf("expected first message role 'user', got %s", messages[0].Role)
|
||||||
|
}
|
||||||
|
if messages[1].Role != "assistant" {
|
||||||
|
t.Errorf("expected second message role 'assistant', got %s", messages[1].Role)
|
||||||
|
}
|
||||||
|
if messages[0].Content[0].Text != "What is running on port 8080?" {
|
||||||
|
t.Errorf("unexpected message text: %s", messages[0].Content[0].Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListConversations(t *testing.T) {
|
||||||
|
mgr := setupConversationManager(t)
|
||||||
|
|
||||||
|
// Create multiple conversations
|
||||||
|
_, err := mgr.Create("model-a")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create 1: %v", err)
|
||||||
|
}
|
||||||
|
_, err = mgr.Create("model-b")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create 2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := mgr.List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 2 {
|
||||||
|
t.Errorf("expected 2 conversations, got %d", len(list))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteConversation(t *testing.T) {
|
||||||
|
mgr := setupConversationManager(t)
|
||||||
|
|
||||||
|
conv, err := mgr.Create("test-model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Delete(conv.ID); err != nil {
|
||||||
|
t.Fatalf("delete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's gone
|
||||||
|
list, err := mgr.List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 0 {
|
||||||
|
t.Errorf("expected 0 conversations after delete, got %d", len(list))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete non-existent should error
|
||||||
|
if err := mgr.Delete("nonexistent"); err == nil {
|
||||||
|
t.Error("expected error deleting non-existent conversation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUsageTracking(t *testing.T) {
|
||||||
|
mgr := setupConversationManager(t)
|
||||||
|
|
||||||
|
conv, err := mgr.Create("test-model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update token usage multiple times
|
||||||
|
if err := mgr.UpdateTokenUsage(conv.ID, 100, 50); err != nil {
|
||||||
|
t.Fatalf("update tokens 1: %v", err)
|
||||||
|
}
|
||||||
|
if err := mgr.UpdateTokenUsage(conv.ID, 200, 100); err != nil {
|
||||||
|
t.Fatalf("update tokens 2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify totals
|
||||||
|
list, err := mgr.List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 1 {
|
||||||
|
t.Fatalf("expected 1 conversation, got %d", len(list))
|
||||||
|
}
|
||||||
|
if list[0].TokensIn != 300 {
|
||||||
|
t.Errorf("expected 300 tokens in, got %d", list[0].TokensIn)
|
||||||
|
}
|
||||||
|
if list[0].TokensOut != 150 {
|
||||||
|
t.Errorf("expected 150 tokens out, got %d", list[0].TokensOut)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMessagesNonExistent(t *testing.T) {
|
||||||
|
mgr := setupConversationManager(t)
|
||||||
|
|
||||||
|
_, err := mgr.GetMessages("nonexistent-id")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for non-existent conversation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoTitle(t *testing.T) {
|
||||||
|
mgr := setupConversationManager(t)
|
||||||
|
|
||||||
|
conv, err := mgr.Create("test-model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{Type: "text", Text: "Check disk usage on server-01"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := mgr.AddMessage(conv.ID, msg); err != nil {
|
||||||
|
t.Fatalf("add message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the title was auto-set
|
||||||
|
list, err := mgr.List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list: %v", err)
|
||||||
|
}
|
||||||
|
if list[0].Title != "Check disk usage on server-01" {
|
||||||
|
t.Errorf("expected auto-title, got %q", list[0].Title)
|
||||||
|
}
|
||||||
|
}
|
||||||
371
internal/ai/oauth.go
Normal file
371
internal/ai/oauth.go
Normal file
@ -0,0 +1,371 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/settings"
|
||||||
|
"github.com/vstockwell/wraith/internal/vault"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
|
oauthAuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||||
|
oauthTokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthManager handles OAuth PKCE authentication for Claude Max subscriptions.
|
||||||
|
type OAuthManager struct {
|
||||||
|
settings *settings.SettingsService
|
||||||
|
vault *vault.VaultService
|
||||||
|
clientID string
|
||||||
|
authorizeURL string
|
||||||
|
tokenURL string
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenResponse is the JSON body returned by the token endpoint.
|
||||||
|
type tokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOAuthManager creates an OAuthManager wired to settings and vault services.
|
||||||
|
func NewOAuthManager(s *settings.SettingsService, v *vault.VaultService) *OAuthManager {
|
||||||
|
return &OAuthManager{
|
||||||
|
settings: s,
|
||||||
|
vault: v,
|
||||||
|
clientID: oauthClientID,
|
||||||
|
authorizeURL: oauthAuthorizeURL,
|
||||||
|
tokenURL: oauthTokenURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartLogin begins the OAuth PKCE flow. It starts a local HTTP server,
|
||||||
|
// opens the authorization URL in the user's browser, and returns a channel
|
||||||
|
// that receives nil on success or an error on failure.
|
||||||
|
func (o *OAuthManager) StartLogin(openURL func(string) error) (<-chan error, error) {
|
||||||
|
verifier, err := generateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate code verifier: %w", err)
|
||||||
|
}
|
||||||
|
challenge := generateCodeChallenge(verifier)
|
||||||
|
state, err := generateState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a local HTTP server on a random port for the callback
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen for callback: %w", err)
|
||||||
|
}
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
redirectURI := fmt.Sprintf("http://127.0.0.1:%d/callback", port)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
server := &http.Server{Handler: mux}
|
||||||
|
|
||||||
|
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
server.Shutdown(ctx)
|
||||||
|
}()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Verify state
|
||||||
|
if r.URL.Query().Get("state") != state {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprint(w, "State mismatch — please try logging in again.")
|
||||||
|
done <- fmt.Errorf("state mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if errParam := r.URL.Query().Get("error"); errParam != "" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, "Authorization error: %s", errParam)
|
||||||
|
done <- fmt.Errorf("authorization error: %s", errParam)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
if code == "" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprint(w, "No authorization code received.")
|
||||||
|
done <- fmt.Errorf("no authorization code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange code for tokens
|
||||||
|
if err := o.exchangeCode(code, verifier, redirectURI); err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprint(w, "Failed to exchange authorization code.")
|
||||||
|
done <- fmt.Errorf("exchange code: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
fmt.Fprint(w, `<html><body><h2>Authenticated!</h2><p>You can close this window and return to Wraith.</p></body></html>`)
|
||||||
|
done <- nil
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
slog.Error("oauth callback server error", "error", err)
|
||||||
|
done <- fmt.Errorf("callback server: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Build the authorize URL
|
||||||
|
params := url.Values{
|
||||||
|
"response_type": {"code"},
|
||||||
|
"client_id": {o.clientID},
|
||||||
|
"redirect_uri": {redirectURI},
|
||||||
|
"scope": {"user:inference"},
|
||||||
|
"state": {state},
|
||||||
|
"code_challenge": {challenge},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
}
|
||||||
|
authURL := o.authorizeURL + "?" + params.Encode()
|
||||||
|
|
||||||
|
if openURL != nil {
|
||||||
|
if err := openURL(authURL); err != nil {
|
||||||
|
listener.Close()
|
||||||
|
return nil, fmt.Errorf("open browser: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return done, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeCode exchanges an authorization code for access and refresh tokens.
|
||||||
|
func (o *OAuthManager) exchangeCode(code, verifier, redirectURI string) error {
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {redirectURI},
|
||||||
|
"client_id": {o.clientID},
|
||||||
|
"code_verifier": {verifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.PostForm(o.tokenURL, data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("post token request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens tokenResponse
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return o.storeTokens(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeTokens encrypts and persists OAuth tokens in settings.
|
||||||
|
func (o *OAuthManager) storeTokens(tokens tokenResponse) error {
|
||||||
|
o.mu.Lock()
|
||||||
|
defer o.mu.Unlock()
|
||||||
|
|
||||||
|
if o.vault == nil {
|
||||||
|
return fmt.Errorf("vault is not unlocked — cannot store tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
encAccess, err := o.vault.Encrypt(tokens.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encrypt access token: %w", err)
|
||||||
|
}
|
||||||
|
if err := o.settings.Set("ai_access_token", encAccess); err != nil {
|
||||||
|
return fmt.Errorf("store access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encRefresh, err := o.vault.Encrypt(tokens.RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encrypt refresh token: %w", err)
|
||||||
|
}
|
||||||
|
if err := o.settings.Set("ai_refresh_token", encRefresh); err != nil {
|
||||||
|
return fmt.Errorf("store refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokens.ExpiresIn) * time.Second).Unix()
|
||||||
|
if err := o.settings.Set("ai_token_expires_at", strconv.FormatInt(expiresAt, 10)); err != nil {
|
||||||
|
return fmt.Errorf("store token expiry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("oauth tokens stored successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessToken returns a valid access token, refreshing if expired.
|
||||||
|
func (o *OAuthManager) GetAccessToken() (string, error) {
|
||||||
|
o.mu.RLock()
|
||||||
|
vlt := o.vault
|
||||||
|
o.mu.RUnlock()
|
||||||
|
|
||||||
|
if vlt == nil {
|
||||||
|
return "", fmt.Errorf("vault is not unlocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiry
|
||||||
|
expiresStr, _ := o.settings.Get("ai_token_expires_at")
|
||||||
|
if expiresStr != "" {
|
||||||
|
expiresAt, _ := strconv.ParseInt(expiresStr, 10, 64)
|
||||||
|
if time.Now().Unix() >= expiresAt {
|
||||||
|
// Token expired, try to refresh
|
||||||
|
if err := o.refreshToken(); err != nil {
|
||||||
|
return "", fmt.Errorf("refresh expired token: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
encAccess, err := o.settings.Get("ai_access_token")
|
||||||
|
if err != nil || encAccess == "" {
|
||||||
|
return "", fmt.Errorf("no access token stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := vlt.Decrypt(encAccess)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decrypt access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshToken uses the stored refresh token to obtain new tokens.
|
||||||
|
func (o *OAuthManager) refreshToken() error {
|
||||||
|
o.mu.RLock()
|
||||||
|
vlt := o.vault
|
||||||
|
o.mu.RUnlock()
|
||||||
|
|
||||||
|
if vlt == nil {
|
||||||
|
return fmt.Errorf("vault is not unlocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
encRefresh, err := o.settings.Get("ai_refresh_token")
|
||||||
|
if err != nil || encRefresh == "" {
|
||||||
|
return fmt.Errorf("no refresh token stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshTok, err := vlt.Decrypt(encRefresh)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decrypt refresh token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {refreshTok},
|
||||||
|
"client_id": {o.clientID},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.PostForm(o.tokenURL, data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("post refresh request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens tokenResponse
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the refresh endpoint doesn't return a new refresh token, keep the old one
|
||||||
|
if tokens.RefreshToken == "" {
|
||||||
|
tokens.RefreshToken = refreshTok
|
||||||
|
}
|
||||||
|
|
||||||
|
return o.storeTokens(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticated checks if we have stored tokens.
|
||||||
|
func (o *OAuthManager) IsAuthenticated() bool {
|
||||||
|
encAccess, err := o.settings.Get("ai_access_token")
|
||||||
|
return err == nil && encAccess != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout clears all stored OAuth tokens.
|
||||||
|
func (o *OAuthManager) Logout() error {
|
||||||
|
for _, key := range []string{"ai_access_token", "ai_refresh_token", "ai_token_expires_at"} {
|
||||||
|
if err := o.settings.Delete(key); err != nil {
|
||||||
|
return fmt.Errorf("delete %s: %w", key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCodeVerifier creates a PKCE code verifier (32 random bytes, base64url no padding).
|
||||||
|
func generateCodeVerifier() (string, error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCodeChallenge creates a PKCE S256 code challenge from a verifier.
|
||||||
|
func generateCodeChallenge(verifier string) string {
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateState creates a random state parameter (32 random bytes, base64url no padding).
|
||||||
|
func generateState() (string, error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVault updates the vault reference (called after vault unlock).
|
||||||
|
func (o *OAuthManager) SetVault(v *vault.VaultService) {
|
||||||
|
o.mu.Lock()
|
||||||
|
defer o.mu.Unlock()
|
||||||
|
o.vault = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBase64URL checks if a string contains only base64url characters (no padding).
|
||||||
|
func isBase64URL(s string) bool {
|
||||||
|
for _, c := range s {
|
||||||
|
if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
91
internal/ai/oauth_test.go
Normal file
91
internal/ai/oauth_test.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
|
"github.com/vstockwell/wraith/internal/settings"
|
||||||
|
"github.com/vstockwell/wraith/internal/vault"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateCodeVerifier(t *testing.T) {
|
||||||
|
v, err := generateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 32 bytes -> 43 base64url chars (no padding)
|
||||||
|
if len(v) != 43 {
|
||||||
|
t.Errorf("expected verifier length 43, got %d", len(v))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isBase64URL(v) {
|
||||||
|
t.Errorf("verifier contains non-base64url characters: %s", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be different each time
|
||||||
|
v2, _ := generateCodeVerifier()
|
||||||
|
if v == v2 {
|
||||||
|
t.Error("two generated verifiers should not be identical")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCodeChallenge(t *testing.T) {
|
||||||
|
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
|
||||||
|
challenge := generateCodeChallenge(verifier)
|
||||||
|
|
||||||
|
// Verify it matches a manually computed S256 hash
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
|
||||||
|
if challenge != expected {
|
||||||
|
t.Errorf("expected challenge %s, got %s", expected, challenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same verifier should produce the same challenge (deterministic)
|
||||||
|
challenge2 := generateCodeChallenge(verifier)
|
||||||
|
if challenge != challenge2 {
|
||||||
|
t.Error("challenge should be deterministic for the same verifier")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateState(t *testing.T) {
|
||||||
|
s, err := generateState()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 32 bytes -> 43 base64url chars (no padding)
|
||||||
|
if len(s) != 43 {
|
||||||
|
t.Errorf("expected state length 43, got %d", len(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isBase64URL(s) {
|
||||||
|
t.Errorf("state contains non-base64url characters: %s", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAuthenticatedWhenNoTokens(t *testing.T) {
|
||||||
|
database, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open db: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(database); err != nil {
|
||||||
|
t.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { database.Close() })
|
||||||
|
|
||||||
|
settingsSvc := settings.NewSettingsService(database)
|
||||||
|
key := vault.DeriveKey("test-password", []byte("test-salt-1234567890123456789012"))
|
||||||
|
vaultSvc := vault.NewVaultService(key)
|
||||||
|
|
||||||
|
oauth := NewOAuthManager(settingsSvc, vaultSvc)
|
||||||
|
|
||||||
|
if oauth.IsAuthenticated() {
|
||||||
|
t.Error("expected IsAuthenticated to return false when no tokens are stored")
|
||||||
|
}
|
||||||
|
}
|
||||||
564
internal/ai/router.go
Normal file
564
internal/ai/router.go
Normal file
@ -0,0 +1,564 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToolRouter dispatches tool calls to the appropriate service.
|
||||||
|
// Services are stored as interface{} to avoid circular imports between packages.
|
||||||
|
type ToolRouter struct {
|
||||||
|
ssh interface{} // *ssh.SSHService
|
||||||
|
sftp interface{} // *sftp.SFTPService
|
||||||
|
rdp interface{} // *rdp.RDPService
|
||||||
|
sessions interface{} // *session.Manager
|
||||||
|
connections interface{} // *connections.ConnectionService
|
||||||
|
aiService interface{} // *AIService — for terminal buffer access
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToolRouter creates an empty ToolRouter. Call SetServices to wire in backends.
|
||||||
|
func NewToolRouter() *ToolRouter {
|
||||||
|
return &ToolRouter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetServices wires the router to actual service implementations.
|
||||||
|
func (r *ToolRouter) SetServices(ssh, sftp, rdp, sessions, connections interface{}) {
|
||||||
|
r.ssh = ssh
|
||||||
|
r.sftp = sftp
|
||||||
|
r.rdp = rdp
|
||||||
|
r.sessions = sessions
|
||||||
|
r.connections = connections
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAIService wires the router to the AI service for terminal buffer access.
|
||||||
|
func (r *ToolRouter) SetAIService(aiService interface{}) {
|
||||||
|
r.aiService = aiService
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshWriter is the interface we need from SSHService for terminal_write.
|
||||||
|
type sshWriter interface {
|
||||||
|
Write(sessionID string, data string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// sshSessionLister is the interface for listing SSH sessions.
|
||||||
|
type sshSessionLister interface {
|
||||||
|
ListSessions() interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sftpLister is the interface for SFTP list operations.
|
||||||
|
type sftpLister interface {
|
||||||
|
List(sessionID string, path string) (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sftpReader is the interface for SFTP read operations.
|
||||||
|
type sftpReader interface {
|
||||||
|
ReadFile(sessionID string, path string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sftpWriter is the interface for SFTP write operations.
|
||||||
|
type sftpWriter interface {
|
||||||
|
WriteFile(sessionID string, path string, content string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// rdpFrameGetter is the interface for getting RDP screenshots.
|
||||||
|
type rdpFrameGetter interface {
|
||||||
|
GetFrame(sessionID string) ([]byte, error)
|
||||||
|
GetSessionInfo(sessionID string) (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// rdpMouseSender is the interface for RDP mouse events.
|
||||||
|
type rdpMouseSender interface {
|
||||||
|
SendMouse(sessionID string, x, y int, flags uint32) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// rdpKeySender is the interface for RDP key events.
|
||||||
|
type rdpKeySender interface {
|
||||||
|
SendKey(sessionID string, scancode uint32, pressed bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// rdpClipboardSender is the interface for RDP clipboard.
|
||||||
|
type rdpClipboardSender interface {
|
||||||
|
SendClipboard(sessionID string, data string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// rdpSessionLister is the interface for listing RDP sessions.
|
||||||
|
type rdpSessionLister interface {
|
||||||
|
ListSessions() interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionLister is the interface for listing sessions.
|
||||||
|
type sessionLister interface {
|
||||||
|
List() interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bufferProvider provides terminal output buffers.
|
||||||
|
type bufferProvider interface {
|
||||||
|
GetBuffer(sessionId string) *TerminalBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatch routes a tool call to the appropriate handler.
|
||||||
|
func (r *ToolRouter) Dispatch(toolName string, input json.RawMessage) (interface{}, error) {
|
||||||
|
switch toolName {
|
||||||
|
case "terminal_write":
|
||||||
|
return r.handleTerminalWrite(input)
|
||||||
|
case "terminal_read":
|
||||||
|
return r.handleTerminalRead(input)
|
||||||
|
case "terminal_cwd":
|
||||||
|
return r.handleTerminalCwd(input)
|
||||||
|
case "sftp_list":
|
||||||
|
return r.handleSFTPList(input)
|
||||||
|
case "sftp_read":
|
||||||
|
return r.handleSFTPRead(input)
|
||||||
|
case "sftp_write":
|
||||||
|
return r.handleSFTPWrite(input)
|
||||||
|
case "rdp_screenshot":
|
||||||
|
return r.handleRDPScreenshot(input)
|
||||||
|
case "rdp_click":
|
||||||
|
return r.handleRDPClick(input)
|
||||||
|
case "rdp_doubleclick":
|
||||||
|
return r.handleRDPDoubleClick(input)
|
||||||
|
case "rdp_type":
|
||||||
|
return r.handleRDPType(input)
|
||||||
|
case "rdp_keypress":
|
||||||
|
return r.handleRDPKeypress(input)
|
||||||
|
case "rdp_scroll":
|
||||||
|
return r.handleRDPScroll(input)
|
||||||
|
case "rdp_move":
|
||||||
|
return r.handleRDPMove(input)
|
||||||
|
case "list_sessions":
|
||||||
|
return r.handleListSessions(input)
|
||||||
|
case "connect_ssh":
|
||||||
|
return r.handleConnectSSH(input)
|
||||||
|
case "disconnect":
|
||||||
|
return r.handleDisconnect(input)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown tool: %s", toolName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleTerminalWrite(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w, ok := r.ssh.(sshWriter)
|
||||||
|
if !ok || r.ssh == nil {
|
||||||
|
return nil, fmt.Errorf("SSH service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.Write(params.SessionID, params.Text); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleTerminalRead(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Lines int `json:"lines"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
if params.Lines <= 0 {
|
||||||
|
params.Lines = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
bp, ok := r.aiService.(bufferProvider)
|
||||||
|
if !ok || r.aiService == nil {
|
||||||
|
return nil, fmt.Errorf("terminal buffer not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := bp.GetBuffer(params.SessionID)
|
||||||
|
lines := buf.ReadLast(params.Lines)
|
||||||
|
return map[string]interface{}{
|
||||||
|
"lines": lines,
|
||||||
|
"count": len(lines),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleTerminalCwd(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// terminal_cwd works by writing "pwd" to the terminal and reading output
|
||||||
|
w, ok := r.ssh.(sshWriter)
|
||||||
|
if !ok || r.ssh == nil {
|
||||||
|
return nil, fmt.Errorf("SSH service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.Write(params.SessionID, "pwd\n"); err != nil {
|
||||||
|
return nil, fmt.Errorf("send pwd: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]string{
|
||||||
|
"status": "pwd command sent — read terminal output for result",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleSFTPList(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l, ok := r.sftp.(sftpLister)
|
||||||
|
if !ok || r.sftp == nil {
|
||||||
|
return nil, fmt.Errorf("SFTP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return l.List(params.SessionID, params.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleSFTPRead(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, ok := r.sftp.(sftpReader)
|
||||||
|
if !ok || r.sftp == nil {
|
||||||
|
return nil, fmt.Errorf("SFTP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := reader.ReadFile(params.SessionID, params.Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return map[string]string{"content": content}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleSFTPWrite(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w, ok := r.sftp.(sftpWriter)
|
||||||
|
if !ok || r.sftp == nil {
|
||||||
|
return nil, fmt.Errorf("SFTP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteFile(params.SessionID, params.Path, params.Content); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleRDPScreenshot(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fg, ok := r.rdp.(rdpFrameGetter)
|
||||||
|
if !ok || r.rdp == nil {
|
||||||
|
return nil, fmt.Errorf("RDP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
frame, err := fg.GetFrame(params.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get session info for dimensions
|
||||||
|
info, err := fg.GetSessionInfo(params.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get dimensions from session config
|
||||||
|
type configGetter interface {
|
||||||
|
GetConfig() (int, int)
|
||||||
|
}
|
||||||
|
width, height := 1920, 1080
|
||||||
|
if cg, ok := info.(configGetter); ok {
|
||||||
|
width, height = cg.GetConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode as JPEG
|
||||||
|
jpeg, err := EncodeScreenshot(frame, width, height, 1280, 720, 75)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encode screenshot: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"image": jpeg,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleRDPClick(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
X int `json:"x"`
|
||||||
|
Y int `json:"y"`
|
||||||
|
Button string `json:"button"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, ok := r.rdp.(rdpMouseSender)
|
||||||
|
if !ok || r.rdp == nil {
|
||||||
|
return nil, fmt.Errorf("RDP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
var buttonFlag uint32 = 0x1000 // left
|
||||||
|
switch params.Button {
|
||||||
|
case "right":
|
||||||
|
buttonFlag = 0x2000
|
||||||
|
case "middle":
|
||||||
|
buttonFlag = 0x4000
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press
|
||||||
|
if err := ms.SendMouse(params.SessionID, params.X, params.Y, buttonFlag|0x8000); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Release
|
||||||
|
if err := ms.SendMouse(params.SessionID, params.X, params.Y, buttonFlag); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleRDPDoubleClick(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
X int `json:"x"`
|
||||||
|
Y int `json:"y"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, ok := r.rdp.(rdpMouseSender)
|
||||||
|
if !ok || r.rdp == nil {
|
||||||
|
return nil, fmt.Errorf("RDP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two clicks
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x1000|0x8000); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x1000); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleRDPType(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cs, ok := r.rdp.(rdpClipboardSender)
|
||||||
|
if !ok || r.rdp == nil {
|
||||||
|
return nil, fmt.Errorf("RDP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send text via clipboard, then simulate Ctrl+V
|
||||||
|
if err := cs.SendClipboard(params.SessionID, params.Text); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ks, ok := r.rdp.(rdpKeySender)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("RDP key service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ctrl down, V down, V up, Ctrl up
|
||||||
|
if err := ks.SendKey(params.SessionID, 0x001D, true); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ks.SendKey(params.SessionID, 0x002F, true); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ks.SendKey(params.SessionID, 0x002F, false); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ks.SendKey(params.SessionID, 0x001D, false); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleRDPKeypress(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
Key string `json:"key"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ks, ok := r.rdp.(rdpKeySender)
|
||||||
|
if !ok || r.rdp == nil {
|
||||||
|
return nil, fmt.Errorf("RDP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple key name to scancode mapping for common keys
|
||||||
|
keyMap := map[string]uint32{
|
||||||
|
"Enter": 0x001C,
|
||||||
|
"Tab": 0x000F,
|
||||||
|
"Escape": 0x0001,
|
||||||
|
"Backspace": 0x000E,
|
||||||
|
"Delete": 0xE053,
|
||||||
|
"Space": 0x0039,
|
||||||
|
"Up": 0xE048,
|
||||||
|
"Down": 0xE050,
|
||||||
|
"Left": 0xE04B,
|
||||||
|
"Right": 0xE04D,
|
||||||
|
}
|
||||||
|
|
||||||
|
scancode, ok := keyMap[params.Key]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown key: %s", params.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ks.SendKey(params.SessionID, scancode, true); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ks.SendKey(params.SessionID, scancode, false); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleRDPScroll(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
X int `json:"x"`
|
||||||
|
Y int `json:"y"`
|
||||||
|
Direction string `json:"direction"`
|
||||||
|
Clicks int `json:"clicks"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
if params.Clicks <= 0 {
|
||||||
|
params.Clicks = 3
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, ok := r.rdp.(rdpMouseSender)
|
||||||
|
if !ok || r.rdp == nil {
|
||||||
|
return nil, fmt.Errorf("RDP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags uint32 = 0x0200 // wheel flag
|
||||||
|
if params.Direction == "down" {
|
||||||
|
flags |= 0x0100 // negative flag
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < params.Clicks; i++ {
|
||||||
|
if err := ms.SendMouse(params.SessionID, params.X, params.Y, flags); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleRDPMove(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
X int `json:"x"`
|
||||||
|
Y int `json:"y"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, ok := r.rdp.(rdpMouseSender)
|
||||||
|
if !ok || r.rdp == nil {
|
||||||
|
return nil, fmt.Errorf("RDP service not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ms.SendMouse(params.SessionID, params.X, params.Y, 0x0800); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return map[string]string{"status": "ok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleListSessions(_ json.RawMessage) (interface{}, error) {
|
||||||
|
result := map[string]interface{}{
|
||||||
|
"ssh": []interface{}{},
|
||||||
|
"rdp": []interface{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.sessions != nil {
|
||||||
|
if sl, ok := r.sessions.(sessionLister); ok {
|
||||||
|
result["all"] = sl.List()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleConnectSSH(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
ConnectionID int64 `json:"connectionId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This will be wired to the app-level connect logic later
|
||||||
|
return nil, fmt.Errorf("connect_ssh requires app-level wiring — not yet available via tool dispatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRouter) handleDisconnect(input json.RawMessage) (interface{}, error) {
|
||||||
|
var params struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(input, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try SSH first
|
||||||
|
type disconnecter interface {
|
||||||
|
Disconnect(sessionID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
if d, ok := r.ssh.(disconnecter); ok {
|
||||||
|
if err := d.Disconnect(params.SessionID); err == nil {
|
||||||
|
return map[string]string{"status": "disconnected", "protocol": "ssh"}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if d, ok := r.rdp.(disconnecter); ok {
|
||||||
|
if err := d.Disconnect(params.SessionID); err == nil {
|
||||||
|
return map[string]string{"status": "disconnected", "protocol": "rdp"}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("session %s not found in SSH or RDP", params.SessionID)
|
||||||
|
}
|
||||||
85
internal/ai/router_test.go
Normal file
85
internal/ai/router_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDispatchUnknownTool(t *testing.T) {
|
||||||
|
router := NewToolRouter()
|
||||||
|
|
||||||
|
_, err := router.Dispatch("nonexistent_tool", json.RawMessage(`{}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for unknown tool")
|
||||||
|
}
|
||||||
|
if err.Error() != "unknown tool: nonexistent_tool" {
|
||||||
|
t.Errorf("unexpected error message: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatchListSessions(t *testing.T) {
|
||||||
|
router := NewToolRouter()
|
||||||
|
|
||||||
|
// With nil services, list_sessions should return empty result
|
||||||
|
result, err := router.Dispatch("list_sessions", json.RawMessage(`{}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, ok := result.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected map result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
ssh, ok := m["ssh"].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected ssh key in result")
|
||||||
|
}
|
||||||
|
if len(ssh) != 0 {
|
||||||
|
t.Errorf("expected empty ssh list, got %d items", len(ssh))
|
||||||
|
}
|
||||||
|
|
||||||
|
rdp, ok := m["rdp"].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected rdp key in result")
|
||||||
|
}
|
||||||
|
if len(rdp) != 0 {
|
||||||
|
t.Errorf("expected empty rdp list, got %d items", len(rdp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatchTerminalWriteNoService(t *testing.T) {
|
||||||
|
router := NewToolRouter()
|
||||||
|
|
||||||
|
_, err := router.Dispatch("terminal_write", json.RawMessage(`{"sessionId":"abc","text":"ls\n"}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when SSH service is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatchTerminalReadNoService(t *testing.T) {
|
||||||
|
router := NewToolRouter()
|
||||||
|
|
||||||
|
_, err := router.Dispatch("terminal_read", json.RawMessage(`{"sessionId":"abc"}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when AI service is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatchSFTPListNoService(t *testing.T) {
|
||||||
|
router := NewToolRouter()
|
||||||
|
|
||||||
|
_, err := router.Dispatch("sftp_list", json.RawMessage(`{"sessionId":"abc","path":"/"}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when SFTP service is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatchDisconnectNoService(t *testing.T) {
|
||||||
|
router := NewToolRouter()
|
||||||
|
|
||||||
|
_, err := router.Dispatch("disconnect", json.RawMessage(`{"sessionId":"abc"}`))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when no services available")
|
||||||
|
}
|
||||||
|
}
|
||||||
79
internal/ai/screenshot.go
Normal file
79
internal/ai/screenshot.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/jpeg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncodeScreenshot converts raw RGBA pixel data to a JPEG image.
|
||||||
|
// If the source dimensions exceed maxWidth x maxHeight, the image is
|
||||||
|
// downscaled using nearest-neighbor sampling (fast, no external deps).
|
||||||
|
// Returns the JPEG bytes.
|
||||||
|
func EncodeScreenshot(rgba []byte, srcWidth, srcHeight, maxWidth, maxHeight, quality int) ([]byte, error) {
|
||||||
|
expectedLen := srcWidth * srcHeight * 4
|
||||||
|
if len(rgba) < expectedLen {
|
||||||
|
return nil, fmt.Errorf("RGBA buffer too small: got %d bytes, expected %d for %dx%d", len(rgba), expectedLen, srcWidth, srcHeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
if quality <= 0 || quality > 100 {
|
||||||
|
quality = 75
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create source image from RGBA buffer
|
||||||
|
src := image.NewRGBA(image.Rect(0, 0, srcWidth, srcHeight))
|
||||||
|
copy(src.Pix, rgba[:expectedLen])
|
||||||
|
|
||||||
|
// Determine output dimensions
|
||||||
|
dstWidth, dstHeight := srcWidth, srcHeight
|
||||||
|
if srcWidth > maxWidth || srcHeight > maxHeight {
|
||||||
|
dstWidth, dstHeight = fitDimensions(srcWidth, srcHeight, maxWidth, maxHeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
var img image.Image = src
|
||||||
|
|
||||||
|
// Downscale if needed using nearest-neighbor sampling
|
||||||
|
if dstWidth != srcWidth || dstHeight != srcHeight {
|
||||||
|
dst := image.NewRGBA(image.Rect(0, 0, dstWidth, dstHeight))
|
||||||
|
for y := 0; y < dstHeight; y++ {
|
||||||
|
srcY := y * srcHeight / dstHeight
|
||||||
|
for x := 0; x < dstWidth; x++ {
|
||||||
|
srcX := x * srcWidth / dstWidth
|
||||||
|
srcIdx := (srcY*srcWidth + srcX) * 4
|
||||||
|
dstIdx := (y*dstWidth + x) * 4
|
||||||
|
dst.Pix[dstIdx+0] = src.Pix[srcIdx+0] // R
|
||||||
|
dst.Pix[dstIdx+1] = src.Pix[srcIdx+1] // G
|
||||||
|
dst.Pix[dstIdx+2] = src.Pix[srcIdx+2] // B
|
||||||
|
dst.Pix[dstIdx+3] = src.Pix[srcIdx+3] // A
|
||||||
|
}
|
||||||
|
}
|
||||||
|
img = dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to JPEG
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil {
|
||||||
|
return nil, fmt.Errorf("encode JPEG: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fitDimensions calculates the largest dimensions that fit within max bounds
|
||||||
|
// while preserving aspect ratio.
|
||||||
|
func fitDimensions(srcW, srcH, maxW, maxH int) (int, int) {
|
||||||
|
ratio := float64(srcW) / float64(srcH)
|
||||||
|
w, h := maxW, int(float64(maxW)/ratio)
|
||||||
|
if h > maxH {
|
||||||
|
h = maxH
|
||||||
|
w = int(float64(maxH) * ratio)
|
||||||
|
}
|
||||||
|
if w <= 0 {
|
||||||
|
w = 1
|
||||||
|
}
|
||||||
|
if h <= 0 {
|
||||||
|
h = 1
|
||||||
|
}
|
||||||
|
return w, h
|
||||||
|
}
|
||||||
118
internal/ai/screenshot_test.go
Normal file
118
internal/ai/screenshot_test.go
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"image/jpeg"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodeScreenshot(t *testing.T) {
|
||||||
|
width, height := 100, 80
|
||||||
|
|
||||||
|
// Create a test RGBA buffer (red pixels)
|
||||||
|
rgba := make([]byte, width*height*4)
|
||||||
|
for i := 0; i < len(rgba); i += 4 {
|
||||||
|
rgba[i+0] = 255 // R
|
||||||
|
rgba[i+1] = 0 // G
|
||||||
|
rgba[i+2] = 0 // B
|
||||||
|
rgba[i+3] = 255 // A
|
||||||
|
}
|
||||||
|
|
||||||
|
jpegData, err := EncodeScreenshot(rgba, width, height, 200, 200, 80)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check JPEG magic bytes (FF D8)
|
||||||
|
if len(jpegData) < 2 {
|
||||||
|
t.Fatal("JPEG data too small")
|
||||||
|
}
|
||||||
|
if jpegData[0] != 0xFF || jpegData[1] != 0xD8 {
|
||||||
|
t.Errorf("expected JPEG magic bytes FF D8, got %02X %02X", jpegData[0], jpegData[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode and verify dimensions (no downscale needed in this case)
|
||||||
|
img, err := jpeg.Decode(bytes.NewReader(jpegData))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode JPEG: %v", err)
|
||||||
|
}
|
||||||
|
bounds := img.Bounds()
|
||||||
|
if bounds.Dx() != width || bounds.Dy() != height {
|
||||||
|
t.Errorf("expected %dx%d, got %dx%d", width, height, bounds.Dx(), bounds.Dy())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeScreenshotDownscale(t *testing.T) {
|
||||||
|
srcWidth, srcHeight := 1920, 1080
|
||||||
|
maxWidth, maxHeight := 1280, 720
|
||||||
|
|
||||||
|
// Create a test RGBA buffer (gradient)
|
||||||
|
rgba := make([]byte, srcWidth*srcHeight*4)
|
||||||
|
for y := 0; y < srcHeight; y++ {
|
||||||
|
for x := 0; x < srcWidth; x++ {
|
||||||
|
idx := (y*srcWidth + x) * 4
|
||||||
|
rgba[idx+0] = byte(x % 256) // R
|
||||||
|
rgba[idx+1] = byte(y % 256) // G
|
||||||
|
rgba[idx+2] = byte((x + y) % 256) // B
|
||||||
|
rgba[idx+3] = 255 // A
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jpegData, err := EncodeScreenshot(rgba, srcWidth, srcHeight, maxWidth, maxHeight, 75)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check JPEG magic bytes
|
||||||
|
if jpegData[0] != 0xFF || jpegData[1] != 0xD8 {
|
||||||
|
t.Errorf("expected JPEG magic bytes FF D8, got %02X %02X", jpegData[0], jpegData[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode and verify dimensions are within max bounds
|
||||||
|
img, err := jpeg.Decode(bytes.NewReader(jpegData))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode JPEG: %v", err)
|
||||||
|
}
|
||||||
|
bounds := img.Bounds()
|
||||||
|
|
||||||
|
if bounds.Dx() > maxWidth {
|
||||||
|
t.Errorf("output width %d exceeds max %d", bounds.Dx(), maxWidth)
|
||||||
|
}
|
||||||
|
if bounds.Dy() > maxHeight {
|
||||||
|
t.Errorf("output height %d exceeds max %d", bounds.Dy(), maxHeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should maintain 16:9 ratio approximately
|
||||||
|
expectedWidth := 1280
|
||||||
|
expectedHeight := 720
|
||||||
|
if bounds.Dx() != expectedWidth || bounds.Dy() != expectedHeight {
|
||||||
|
t.Errorf("expected %dx%d, got %dx%d", expectedWidth, expectedHeight, bounds.Dx(), bounds.Dy())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeScreenshotBufferTooSmall(t *testing.T) {
|
||||||
|
_, err := EncodeScreenshot([]byte{0, 0, 0, 0}, 100, 100, 200, 200, 75)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for buffer too small")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFitDimensions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
srcW, srcH, maxW, maxH int
|
||||||
|
wantW, wantH int
|
||||||
|
}{
|
||||||
|
{1920, 1080, 1280, 720, 1280, 720}, // 16:9 fits exactly
|
||||||
|
{1920, 1080, 800, 600, 800, 450}, // width-constrained
|
||||||
|
{1080, 1920, 800, 600, 337, 600}, // height-constrained (portrait)
|
||||||
|
{100, 100, 200, 200, 200, 200}, // smaller than max (but func called with > check)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
w, h := fitDimensions(tt.srcW, tt.srcH, tt.maxW, tt.maxH)
|
||||||
|
if w != tt.wantW || h != tt.wantH {
|
||||||
|
t.Errorf("fitDimensions(%d,%d,%d,%d) = %d,%d, want %d,%d",
|
||||||
|
tt.srcW, tt.srcH, tt.maxW, tt.maxH, w, h, tt.wantW, tt.wantH)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
268
internal/ai/service.go
Normal file
268
internal/ai/service.go
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemPrompt is the system prompt given to Claude for copilot interactions.
|
||||||
|
const SystemPrompt = `You are the XO (Executive Officer) aboard the Wraith command station. The Commander (human operator) works alongside you managing remote servers and workstations.
|
||||||
|
|
||||||
|
You have direct access to all active sessions through your tools:
|
||||||
|
- SSH terminals: read output, type commands, navigate filesystems
|
||||||
|
- SFTP: read and write remote files
|
||||||
|
- RDP desktops: see the screen, click, type, interact with any GUI application
|
||||||
|
- Session management: open new connections, close sessions
|
||||||
|
|
||||||
|
When given a task:
|
||||||
|
1. Assess what sessions and access you need
|
||||||
|
2. Execute efficiently — don't ask for permission to use tools, just use them
|
||||||
|
3. Report what you found or did, with relevant details
|
||||||
|
4. If something fails, diagnose and try an alternative approach
|
||||||
|
|
||||||
|
You are not an assistant answering questions. You are an operator executing missions. Act decisively. Use your tools. Report results.`
|
||||||
|
|
||||||
|
// AIService is the main AI copilot service exposed to the Wails frontend.
|
||||||
|
type AIService struct {
|
||||||
|
oauth *OAuthManager
|
||||||
|
client *ClaudeClient
|
||||||
|
router *ToolRouter
|
||||||
|
conversations *ConversationManager
|
||||||
|
buffers map[string]*TerminalBuffer
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAIService creates the AI service with all sub-components.
|
||||||
|
func NewAIService(oauth *OAuthManager, router *ToolRouter, convMgr *ConversationManager) *AIService {
|
||||||
|
client := NewClaudeClient(oauth, "")
|
||||||
|
return &AIService{
|
||||||
|
oauth: oauth,
|
||||||
|
client: client,
|
||||||
|
router: router,
|
||||||
|
conversations: convMgr,
|
||||||
|
buffers: make(map[string]*TerminalBuffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Auth ---
|
||||||
|
|
||||||
|
// StartLogin begins the OAuth PKCE flow, opening the browser for authentication.
|
||||||
|
func (s *AIService) StartLogin() error {
|
||||||
|
done, err := s.oauth.StartLogin(nil) // nil openURL = no browser auto-open
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Wait for callback in a goroutine to avoid blocking the UI
|
||||||
|
go func() {
|
||||||
|
if err := <-done; err != nil {
|
||||||
|
slog.Error("oauth login failed", "error", err)
|
||||||
|
} else {
|
||||||
|
slog.Info("oauth login completed")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticated returns whether the user has valid OAuth tokens.
|
||||||
|
func (s *AIService) IsAuthenticated() bool {
|
||||||
|
return s.oauth.IsAuthenticated()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout clears stored OAuth tokens.
|
||||||
|
func (s *AIService) Logout() error {
|
||||||
|
return s.oauth.Logout()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Conversations ---
|
||||||
|
|
||||||
|
// NewConversation creates a new AI conversation and returns its ID.
|
||||||
|
func (s *AIService) NewConversation() (string, error) {
|
||||||
|
conv, err := s.conversations.Create(s.client.model)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return conv.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListConversations returns all conversations.
|
||||||
|
func (s *AIService) ListConversations() ([]ConversationSummary, error) {
|
||||||
|
return s.conversations.List()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteConversation removes a conversation.
|
||||||
|
func (s *AIService) DeleteConversation(id string) error {
|
||||||
|
return s.conversations.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Chat ---
|
||||||
|
|
||||||
|
// SendMessage sends a user message in a conversation and processes the AI response.
|
||||||
|
// Tool calls are automatically dispatched and results fed back to the model.
|
||||||
|
// This method blocks until the full response (including any tool use loops) is complete.
|
||||||
|
func (s *AIService) SendMessage(conversationId, text string) error {
|
||||||
|
// Add user message to conversation
|
||||||
|
userMsg := Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{Type: "text", Text: text},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := s.conversations.AddMessage(conversationId, userMsg); err != nil {
|
||||||
|
return fmt.Errorf("store user message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the message loop (handles tool use)
|
||||||
|
return s.messageLoop(conversationId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// messageLoop sends the conversation to Claude and handles tool use loops.
|
||||||
|
func (s *AIService) messageLoop(conversationId string) error {
|
||||||
|
for iterations := 0; iterations < 20; iterations++ { // safety limit
|
||||||
|
messages, err := s.conversations.GetMessages(conversationId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ch, err := s.client.SendMessage(messages, CopilotTools, SystemPrompt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("send to claude: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect the response
|
||||||
|
var textParts []string
|
||||||
|
var toolCalls []ContentBlock
|
||||||
|
var currentToolInput string
|
||||||
|
|
||||||
|
for event := range ch {
|
||||||
|
switch event.Type {
|
||||||
|
case "text_delta":
|
||||||
|
textParts = append(textParts, event.Data)
|
||||||
|
case "tool_use_start":
|
||||||
|
currentToolInput = ""
|
||||||
|
case "tool_use_delta":
|
||||||
|
currentToolInput += event.Data
|
||||||
|
case "done":
|
||||||
|
// Parse usage if available
|
||||||
|
var delta struct {
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal([]byte(event.Data), &delta) == nil && delta.Usage.OutputTokens > 0 {
|
||||||
|
s.conversations.UpdateTokenUsage(conversationId, delta.Usage.InputTokens, delta.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
case "error":
|
||||||
|
return fmt.Errorf("stream error: %s", event.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When a tool_use block completes, we need to check content_block_stop
|
||||||
|
// But since we accumulate, we'll finalize after the stream ends
|
||||||
|
_ = toolCalls // kept for the final assembly below
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the assistant message
|
||||||
|
var assistantContent []ContentBlock
|
||||||
|
if len(textParts) > 0 {
|
||||||
|
fullText := ""
|
||||||
|
for _, p := range textParts {
|
||||||
|
fullText += p
|
||||||
|
}
|
||||||
|
if fullText != "" {
|
||||||
|
assistantContent = append(assistantContent, ContentBlock{
|
||||||
|
Type: "text",
|
||||||
|
Text: fullText,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, tc := range toolCalls {
|
||||||
|
assistantContent = append(assistantContent, tc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(assistantContent) == 0 {
|
||||||
|
return nil // empty response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store assistant message
|
||||||
|
assistantMsg := Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: assistantContent,
|
||||||
|
}
|
||||||
|
if err := s.conversations.AddMessage(conversationId, assistantMsg); err != nil {
|
||||||
|
return fmt.Errorf("store assistant message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there were tool calls, dispatch them and continue the loop
|
||||||
|
hasToolUse := false
|
||||||
|
for _, block := range assistantContent {
|
||||||
|
if block.Type == "tool_use" {
|
||||||
|
hasToolUse = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasToolUse {
|
||||||
|
return nil // done, no tool use to process
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatch tool calls and create tool_result message
|
||||||
|
var toolResults []ContentBlock
|
||||||
|
for _, block := range assistantContent {
|
||||||
|
if block.Type != "tool_use" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.router.Dispatch(block.Name, block.Input)
|
||||||
|
resultBlock := ContentBlock{
|
||||||
|
Type: "tool_result",
|
||||||
|
ToolUseID: block.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
resultBlock.IsError = true
|
||||||
|
resultBlock.Content = []ContentBlock{
|
||||||
|
{Type: "text", Text: err.Error()},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resultJSON, _ := json.Marshal(result)
|
||||||
|
resultBlock.Content = []ContentBlock{
|
||||||
|
{Type: "text", Text: string(resultJSON)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolResults = append(toolResults, resultBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResultMsg := Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: toolResults,
|
||||||
|
}
|
||||||
|
if err := s.conversations.AddMessage(conversationId, toolResultMsg); err != nil {
|
||||||
|
return fmt.Errorf("store tool results: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue the loop to let Claude process tool results
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("exceeded maximum tool use iterations")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Terminal Buffer Management ---
|
||||||
|
|
||||||
|
// GetBuffer returns the terminal output buffer for a session, creating it if needed.
|
||||||
|
func (s *AIService) GetBuffer(sessionId string) *TerminalBuffer {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
buf, ok := s.buffers[sessionId]
|
||||||
|
if !ok {
|
||||||
|
buf = NewTerminalBuffer(200)
|
||||||
|
s.buffers[sessionId] = buf
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveBuffer removes the terminal buffer for a session.
|
||||||
|
func (s *AIService) RemoveBuffer(sessionId string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.buffers, sessionId)
|
||||||
|
}
|
||||||
101
internal/ai/terminal_buffer.go
Normal file
101
internal/ai/terminal_buffer.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TerminalBuffer is a thread-safe ring buffer that captures terminal output lines.
|
||||||
|
// It is written to by SSH read loops and read by the AI tool dispatch for terminal_read.
|
||||||
|
type TerminalBuffer struct {
|
||||||
|
lines []string
|
||||||
|
mu sync.RWMutex
|
||||||
|
max int
|
||||||
|
partial string // accumulates data that doesn't end with \n
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTerminalBuffer creates a buffer that retains at most maxLines lines.
|
||||||
|
func NewTerminalBuffer(maxLines int) *TerminalBuffer {
|
||||||
|
if maxLines <= 0 {
|
||||||
|
maxLines = 200
|
||||||
|
}
|
||||||
|
return &TerminalBuffer{
|
||||||
|
lines: make([]string, 0, maxLines),
|
||||||
|
max: maxLines,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write ingests raw terminal output, splitting on newlines and appending complete lines.
|
||||||
|
// Partial lines (data without a trailing newline) are accumulated until the next Write.
|
||||||
|
func (b *TerminalBuffer) Write(data []byte) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
text := b.partial + string(data)
|
||||||
|
b.partial = ""
|
||||||
|
|
||||||
|
parts := strings.Split(text, "\n")
|
||||||
|
|
||||||
|
// The last element of Split is always either:
|
||||||
|
// - empty string if text ends with \n (discard it)
|
||||||
|
// - a partial line if text doesn't end with \n (save as partial)
|
||||||
|
last := parts[len(parts)-1]
|
||||||
|
parts = parts[:len(parts)-1]
|
||||||
|
if last != "" {
|
||||||
|
b.partial = last
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range parts {
|
||||||
|
b.lines = append(b.lines, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trim to max
|
||||||
|
if len(b.lines) > b.max {
|
||||||
|
excess := len(b.lines) - b.max
|
||||||
|
b.lines = b.lines[excess:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadLast returns the last n lines from the buffer.
|
||||||
|
// If fewer than n lines are available, all lines are returned.
|
||||||
|
func (b *TerminalBuffer) ReadLast(n int) []string {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
|
||||||
|
total := len(b.lines)
|
||||||
|
if n > total {
|
||||||
|
n = total
|
||||||
|
}
|
||||||
|
if n <= 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]string, n)
|
||||||
|
copy(result, b.lines[total-n:])
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadAll returns all lines currently in the buffer.
|
||||||
|
func (b *TerminalBuffer) ReadAll() []string {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]string, len(b.lines))
|
||||||
|
copy(result, b.lines)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all lines from the buffer.
|
||||||
|
func (b *TerminalBuffer) Clear() {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
b.lines = b.lines[:0]
|
||||||
|
b.partial = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of complete lines in the buffer.
|
||||||
|
func (b *TerminalBuffer) Len() int {
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
return len(b.lines)
|
||||||
|
}
|
||||||
175
internal/ai/terminal_buffer_test.go
Normal file
175
internal/ai/terminal_buffer_test.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteAndRead(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(100)
|
||||||
|
|
||||||
|
buf.Write([]byte("line1\nline2\nline3\n"))
|
||||||
|
|
||||||
|
all := buf.ReadAll()
|
||||||
|
if len(all) != 3 {
|
||||||
|
t.Fatalf("expected 3 lines, got %d", len(all))
|
||||||
|
}
|
||||||
|
if all[0] != "line1" || all[1] != "line2" || all[2] != "line3" {
|
||||||
|
t.Errorf("unexpected lines: %v", all)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRingBufferOverflow(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(200)
|
||||||
|
|
||||||
|
// Write 300 lines
|
||||||
|
for i := 0; i < 300; i++ {
|
||||||
|
buf.Write([]byte(fmt.Sprintf("line %d\n", i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
all := buf.ReadAll()
|
||||||
|
if len(all) != 200 {
|
||||||
|
t.Fatalf("expected 200 lines (ring buffer), got %d", len(all))
|
||||||
|
}
|
||||||
|
|
||||||
|
// First line should be "line 100" (oldest retained)
|
||||||
|
if all[0] != "line 100" {
|
||||||
|
t.Errorf("expected first line 'line 100', got %q", all[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last line should be "line 299"
|
||||||
|
if all[199] != "line 299" {
|
||||||
|
t.Errorf("expected last line 'line 299', got %q", all[199])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLastSubset(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(100)
|
||||||
|
|
||||||
|
buf.Write([]byte("a\nb\nc\nd\ne\n"))
|
||||||
|
|
||||||
|
last3 := buf.ReadLast(3)
|
||||||
|
if len(last3) != 3 {
|
||||||
|
t.Fatalf("expected 3 lines, got %d", len(last3))
|
||||||
|
}
|
||||||
|
if last3[0] != "c" || last3[1] != "d" || last3[2] != "e" {
|
||||||
|
t.Errorf("unexpected lines: %v", last3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request more than available
|
||||||
|
last10 := buf.ReadLast(10)
|
||||||
|
if len(last10) != 5 {
|
||||||
|
t.Fatalf("expected 5 lines (all available), got %d", len(last10))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(1000)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Spawn 10 writers
|
||||||
|
for w := 0; w < 10; w++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
buf.Write([]byte(fmt.Sprintf("writer %d line %d\n", id, i)))
|
||||||
|
}
|
||||||
|
}(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spawn 5 readers
|
||||||
|
for r := 0; r < 5; r++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
_ = buf.ReadLast(10)
|
||||||
|
_ = buf.Len()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
total := buf.Len()
|
||||||
|
if total != 1000 {
|
||||||
|
t.Errorf("expected 1000 lines (10 writers * 100), got %d", total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWritePartialLines(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(100)
|
||||||
|
|
||||||
|
// Write data without trailing newline
|
||||||
|
buf.Write([]byte("partial"))
|
||||||
|
|
||||||
|
// Should have no complete lines yet
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
t.Errorf("expected 0 lines for partial write, got %d", buf.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the line
|
||||||
|
buf.Write([]byte(" line\n"))
|
||||||
|
|
||||||
|
if buf.Len() != 1 {
|
||||||
|
t.Fatalf("expected 1 line after completing partial, got %d", buf.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
all := buf.ReadAll()
|
||||||
|
if all[0] != "partial line" {
|
||||||
|
t.Errorf("expected 'partial line', got %q", all[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteMultiplePartials(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(100)
|
||||||
|
|
||||||
|
buf.Write([]byte("hello "))
|
||||||
|
buf.Write([]byte("world"))
|
||||||
|
buf.Write([]byte("!\nfoo\n"))
|
||||||
|
|
||||||
|
all := buf.ReadAll()
|
||||||
|
if len(all) != 2 {
|
||||||
|
t.Fatalf("expected 2 lines, got %d: %v", len(all), all)
|
||||||
|
}
|
||||||
|
if all[0] != "hello world!" {
|
||||||
|
t.Errorf("expected 'hello world!', got %q", all[0])
|
||||||
|
}
|
||||||
|
if all[1] != "foo" {
|
||||||
|
t.Errorf("expected 'foo', got %q", all[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClear(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(100)
|
||||||
|
|
||||||
|
buf.Write([]byte("line1\nline2\n"))
|
||||||
|
buf.Clear()
|
||||||
|
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
t.Errorf("expected 0 lines after clear, got %d", buf.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLastZero(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(100)
|
||||||
|
buf.Write([]byte("line\n"))
|
||||||
|
|
||||||
|
result := buf.ReadLast(0)
|
||||||
|
if len(result) != 0 {
|
||||||
|
t.Errorf("expected empty result for ReadLast(0), got %d lines", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLastNegative(t *testing.T) {
|
||||||
|
buf := NewTerminalBuffer(100)
|
||||||
|
buf.Write([]byte("line\n"))
|
||||||
|
|
||||||
|
result := buf.ReadLast(-1)
|
||||||
|
if len(result) != 0 {
|
||||||
|
t.Errorf("expected empty result for ReadLast(-1), got %d lines", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
207
internal/ai/tools.go
Normal file
207
internal/ai/tools.go
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
// CopilotTools defines all tools available to the AI copilot.
|
||||||
|
var CopilotTools = []Tool{
|
||||||
|
// ── SSH Terminal Tools ────────────────────────────────────────
|
||||||
|
{
|
||||||
|
Name: "terminal_write",
|
||||||
|
Description: "Type text into an active SSH terminal session. Use this to execute commands by including a trailing newline character.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The SSH session ID"},
|
||||||
|
"text": {"type": "string", "description": "Text to type into the terminal (include \\n for Enter)"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "text"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "terminal_read",
|
||||||
|
Description: "Read recent output from an SSH terminal session. Returns the last N lines from the terminal buffer.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The SSH session ID"},
|
||||||
|
"lines": {"type": "integer", "description": "Number of recent lines to return (default 50)", "default": 50}
|
||||||
|
},
|
||||||
|
"required": ["sessionId"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "terminal_cwd",
|
||||||
|
Description: "Get the current working directory of an SSH terminal session by executing pwd.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The SSH session ID"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
|
||||||
|
// ── SFTP Tools ───────────────────────────────────────────────
|
||||||
|
{
|
||||||
|
Name: "sftp_list",
|
||||||
|
Description: "List files and directories at a remote path via SFTP.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The SSH/SFTP session ID"},
|
||||||
|
"path": {"type": "string", "description": "Remote directory path to list"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "path"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "sftp_read",
|
||||||
|
Description: "Read the contents of a remote file via SFTP. Limited to files under 5MB.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The SSH/SFTP session ID"},
|
||||||
|
"path": {"type": "string", "description": "Remote file path to read"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "path"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "sftp_write",
|
||||||
|
Description: "Write content to a remote file via SFTP. Creates or overwrites the file.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The SSH/SFTP session ID"},
|
||||||
|
"path": {"type": "string", "description": "Remote file path to write"},
|
||||||
|
"content": {"type": "string", "description": "File content to write"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "path", "content"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
|
||||||
|
// ── RDP Tools ────────────────────────────────────────────────
|
||||||
|
{
|
||||||
|
Name: "rdp_screenshot",
|
||||||
|
Description: "Capture a screenshot of the current RDP desktop. Returns a base64 JPEG image.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The RDP session ID"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "rdp_click",
|
||||||
|
Description: "Click at a specific position on the RDP desktop.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The RDP session ID"},
|
||||||
|
"x": {"type": "integer", "description": "X coordinate in pixels"},
|
||||||
|
"y": {"type": "integer", "description": "Y coordinate in pixels"},
|
||||||
|
"button": {"type": "string", "enum": ["left", "right", "middle"], "default": "left", "description": "Mouse button to click"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "x", "y"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "rdp_doubleclick",
|
||||||
|
Description: "Double-click at a specific position on the RDP desktop.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The RDP session ID"},
|
||||||
|
"x": {"type": "integer", "description": "X coordinate in pixels"},
|
||||||
|
"y": {"type": "integer", "description": "Y coordinate in pixels"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "x", "y"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "rdp_type",
|
||||||
|
Description: "Type text into the focused element on the RDP desktop. Uses clipboard paste for reliability.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The RDP session ID"},
|
||||||
|
"text": {"type": "string", "description": "Text to type"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "text"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "rdp_keypress",
|
||||||
|
Description: "Press a key or key combination on the RDP desktop (e.g., 'Enter', 'Tab', 'Ctrl+C').",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The RDP session ID"},
|
||||||
|
"key": {"type": "string", "description": "Key name or combination (e.g., 'Enter', 'Ctrl+C', 'Alt+F4')"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "key"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "rdp_scroll",
|
||||||
|
Description: "Scroll the mouse wheel at a position on the RDP desktop.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The RDP session ID"},
|
||||||
|
"x": {"type": "integer", "description": "X coordinate in pixels"},
|
||||||
|
"y": {"type": "integer", "description": "Y coordinate in pixels"},
|
||||||
|
"direction": {"type": "string", "enum": ["up", "down"], "description": "Scroll direction"},
|
||||||
|
"clicks": {"type": "integer", "description": "Number of scroll clicks (default 3)", "default": 3}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "x", "y", "direction"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "rdp_move",
|
||||||
|
Description: "Move the mouse cursor to a position on the RDP desktop without clicking.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The RDP session ID"},
|
||||||
|
"x": {"type": "integer", "description": "X coordinate in pixels"},
|
||||||
|
"y": {"type": "integer", "description": "Y coordinate in pixels"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId", "x", "y"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
|
||||||
|
// ── Session Management Tools ─────────────────────────────────
|
||||||
|
{
|
||||||
|
Name: "list_sessions",
|
||||||
|
Description: "List all active SSH and RDP sessions with their IDs, connection info, and state.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "connect_ssh",
|
||||||
|
Description: "Open a new SSH connection to a saved connection by its ID.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"connectionId": {"type": "integer", "description": "The saved connection ID from the connection manager"}
|
||||||
|
},
|
||||||
|
"required": ["connectionId"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "disconnect",
|
||||||
|
Description: "Disconnect an active session.",
|
||||||
|
InputSchema: json.RawMessage(`{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sessionId": {"type": "string", "description": "The session ID to disconnect"}
|
||||||
|
},
|
||||||
|
"required": ["sessionId"]
|
||||||
|
}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
75
internal/ai/types.go
Normal file
75
internal/ai/types.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message represents a single message in a conversation with Claude.
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"` // "user" or "assistant"
|
||||||
|
Content []ContentBlock `json:"content"` // one or more content blocks
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentBlock is a polymorphic block within a Message.
|
||||||
|
// Only one of the content fields will be populated depending on Type.
|
||||||
|
type ContentBlock struct {
|
||||||
|
Type string `json:"type"` // "text", "image", "tool_use", "tool_result"
|
||||||
|
|
||||||
|
// text
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
|
||||||
|
// image (base64 source)
|
||||||
|
Source *ImageSource `json:"source,omitempty"`
|
||||||
|
|
||||||
|
// tool_use
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Input json.RawMessage `json:"input,omitempty"`
|
||||||
|
|
||||||
|
// tool_result
|
||||||
|
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||||
|
Content []ContentBlock `json:"content,omitempty"`
|
||||||
|
IsError bool `json:"is_error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageSource holds a base64-encoded image for vision requests.
|
||||||
|
type ImageSource struct {
|
||||||
|
Type string `json:"type"` // "base64"
|
||||||
|
MediaType string `json:"media_type"` // "image/jpeg", "image/png", etc.
|
||||||
|
Data string `json:"data"` // base64-encoded image data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool describes a tool available to the model.
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
InputSchema json.RawMessage `json:"input_schema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamEvent represents a single event from the SSE stream.
|
||||||
|
type StreamEvent struct {
|
||||||
|
Type string `json:"type"` // "text_delta", "tool_use_start", "tool_use_delta", "tool_result", "done", "error"
|
||||||
|
Data string `json:"data"` // event payload
|
||||||
|
|
||||||
|
// Populated for tool_use_start events
|
||||||
|
ToolName string `json:"tool_name,omitempty"`
|
||||||
|
ToolID string `json:"tool_id,omitempty"`
|
||||||
|
ToolInput string `json:"tool_input,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage tracks token consumption for a request.
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationSummary is a lightweight view of a conversation for listing.
|
||||||
|
type ConversationSummary struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
TokensIn int `json:"tokensIn"`
|
||||||
|
TokensOut int `json:"tokensOut"`
|
||||||
|
}
|
||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/vstockwell/wraith/internal/ai"
|
||||||
"github.com/vstockwell/wraith/internal/connections"
|
"github.com/vstockwell/wraith/internal/connections"
|
||||||
"github.com/vstockwell/wraith/internal/credentials"
|
"github.com/vstockwell/wraith/internal/credentials"
|
||||||
"github.com/vstockwell/wraith/internal/db"
|
"github.com/vstockwell/wraith/internal/db"
|
||||||
@ -35,6 +36,8 @@ type WraithApp struct {
|
|||||||
SFTP *sftp.SFTPService
|
SFTP *sftp.SFTPService
|
||||||
RDP *rdp.RDPService
|
RDP *rdp.RDPService
|
||||||
Credentials *credentials.CredentialService
|
Credentials *credentials.CredentialService
|
||||||
|
AI *ai.AIService
|
||||||
|
oauthMgr *ai.OAuthManager
|
||||||
unlocked bool
|
unlocked bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,6 +81,15 @@ func New() (*WraithApp, error) {
|
|||||||
// CredentialService requires the vault to be unlocked, so it starts nil.
|
// CredentialService requires the vault to be unlocked, so it starts nil.
|
||||||
// It is created lazily after the vault is unlocked via initCredentials().
|
// It is created lazily after the vault is unlocked via initCredentials().
|
||||||
|
|
||||||
|
// AI copilot services — OAuthManager starts without a vault reference;
|
||||||
|
// it will be wired once the vault is unlocked.
|
||||||
|
oauthMgr := ai.NewOAuthManager(settingsSvc, nil)
|
||||||
|
toolRouter := ai.NewToolRouter()
|
||||||
|
toolRouter.SetServices(sshSvc, sftpSvc, rdpSvc, sessionMgr, connSvc)
|
||||||
|
convMgr := ai.NewConversationManager(database)
|
||||||
|
aiSvc := ai.NewAIService(oauthMgr, toolRouter, convMgr)
|
||||||
|
toolRouter.SetAIService(aiSvc)
|
||||||
|
|
||||||
// Seed built-in themes on every startup (INSERT OR IGNORE keeps it idempotent)
|
// Seed built-in themes on every startup (INSERT OR IGNORE keeps it idempotent)
|
||||||
if err := themeSvc.SeedBuiltins(); err != nil {
|
if err := themeSvc.SeedBuiltins(); err != nil {
|
||||||
slog.Warn("failed to seed themes", "error", err)
|
slog.Warn("failed to seed themes", "error", err)
|
||||||
@ -93,6 +105,8 @@ func New() (*WraithApp, error) {
|
|||||||
SSH: sshSvc,
|
SSH: sshSvc,
|
||||||
SFTP: sftpSvc,
|
SFTP: sftpSvc,
|
||||||
RDP: rdpSvc,
|
RDP: rdpSvc,
|
||||||
|
AI: aiSvc,
|
||||||
|
oauthMgr: oauthMgr,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -197,9 +211,13 @@ func (a *WraithApp) IsUnlocked() bool {
|
|||||||
return a.unlocked
|
return a.unlocked
|
||||||
}
|
}
|
||||||
|
|
||||||
// initCredentials creates the CredentialService after the vault is unlocked.
|
// initCredentials creates the CredentialService after the vault is unlocked
|
||||||
|
// and wires the vault to services that need it (OAuth, etc.).
|
||||||
func (a *WraithApp) initCredentials() {
|
func (a *WraithApp) initCredentials() {
|
||||||
if a.Vault != nil {
|
if a.Vault != nil {
|
||||||
a.Credentials = credentials.NewCredentialService(a.db, a.Vault)
|
a.Credentials = credentials.NewCredentialService(a.db, a.Vault)
|
||||||
|
if a.oauthMgr != nil {
|
||||||
|
a.oauthMgr.SetVault(a.Vault)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
10
internal/db/migrations/002_ai_copilot.sql
Normal file
10
internal/db/migrations/002_ai_copilot.sql
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
title TEXT,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
messages TEXT NOT NULL DEFAULT '[]',
|
||||||
|
tokens_in INTEGER DEFAULT 0,
|
||||||
|
tokens_out INTEGER DEFAULT 0,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
1
main.go
1
main.go
@ -34,6 +34,7 @@ func main() {
|
|||||||
application.NewService(wraith.SSH),
|
application.NewService(wraith.SSH),
|
||||||
application.NewService(wraith.SFTP),
|
application.NewService(wraith.SFTP),
|
||||||
application.NewService(wraith.RDP),
|
application.NewService(wraith.RDP),
|
||||||
|
application.NewService(wraith.AI),
|
||||||
},
|
},
|
||||||
Assets: application.AssetOptions{
|
Assets: application.AssetOptions{
|
||||||
Handler: application.BundledAssetFileServer(assets),
|
Handler: application.BundledAssetFileServer(assets),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user