wraith/internal/ai/client.go
Vantz Stockwell 7ee5321d69
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Has been cancelled
feat: AI copilot backend — OAuth PKCE, Claude API streaming, 16 tools, conversations
- OAuth PKCE flow for Max subscription auth (no API key needed)
- Claude API client with SSE streaming (Messages API v1)
- 16 tool definitions: terminal, SFTP, RDP, session management
- Tool dispatch router mapping to existing Wraith services
- Conversation manager with SQLite persistence
- Terminal output ring buffer for AI context
- RDP screenshot encoder (RGBA → JPEG with downscaling)
- Wired into Wails app as AIService

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 09:09:23 -04:00

257 lines
6.6 KiB
Go

package ai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
)
// defaultModel is the model name substituted whenever no model is supplied.
const defaultModel = "claude-sonnet-4-20250514"

// anthropicVersion is sent as the anthropic-version header on every request.
const anthropicVersion = "2023-06-01"

// apiBaseURL is the Messages API endpoint that requests are POSTed to.
const apiBaseURL = "https://api.anthropic.com/v1/messages"
// ClaudeClient talks to the Anthropic Messages API. It authenticates with an
// OAuth bearer token when one is available and otherwise with an API key.
type ClaudeClient struct {
	oauth      *OAuthManager // primary credential source; may be nil
	apiKey     string        // x-api-key fallback for non-OAuth users
	model      string        // model name placed in every request
	httpClient *http.Client  // client used to send every request
}
// NewClaudeClient builds a ClaudeClient that prefers the given OAuth manager
// for authentication; an API key fallback can be added later with SetAPIKey.
// An empty model selects defaultModel.
func NewClaudeClient(oauth *OAuthManager, model string) *ClaudeClient {
	client := &ClaudeClient{
		oauth: oauth,
		model: defaultModel,
		// Timeout is deliberately left unset: http.Client.Timeout also bounds
		// reading the response body, which would cut off long-running streams.
		httpClient: &http.Client{},
	}
	if model != "" {
		client.model = model
	}
	return client
}
// SetAPIKey stores an API key that is sent as x-api-key whenever OAuth is not
// authenticated or cannot produce an access token (e.g. users without a Max
// subscription).
func (c *ClaudeClient) SetAPIKey(key string) {
	c.apiKey = key
}
// apiRequest mirrors the JSON body POSTed to the Messages API. The system
// prompt and tool list are omitted from the encoding when empty.
type apiRequest struct {
	Model     string    `json:"model"`
	Messages  []Message `json:"messages"`
	MaxTokens int       `json:"max_tokens"`
	Stream    bool      `json:"stream"`
	System    string    `json:"system,omitempty"`
	Tools     []Tool    `json:"tools,omitempty"`
}
// SendMessage sends the conversation to the Messages API with streaming
// enabled. It returns a channel of StreamEvents decoded from the response.
//
// tools and systemPrompt are optional; they are omitted from the request when
// empty. The returned channel is closed when the stream ends. The caller must
// drain it until it is closed; otherwise the parsing goroutine and the HTTP
// response stay open. A non-200 status is returned as an error that includes
// up to 64 KiB of the response body.
func (c *ClaudeClient) SendMessage(messages []Message, tools []Tool, systemPrompt string) (<-chan StreamEvent, error) {
	// Build the body through BuildRequestBody so the payload sent here and the
	// one BuildRequestBody produces (model default, max_tokens, stream flag)
	// cannot drift apart.
	jsonData, err := BuildRequestBody(messages, tools, systemPrompt, c.model)
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, apiBaseURL, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("anthropic-version", anthropicVersion)
	if err := c.setAuthHeader(req); err != nil {
		return nil, fmt.Errorf("set auth header: %w", err)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// The body is only used for the error message, so bound how much is
		// buffered. A read error is ignored on purpose: whatever was read
		// before it is still useful context.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
		return nil, fmt.Errorf("API returned %d: %s", resp.StatusCode, string(body))
	}

	// The parser goroutine owns resp.Body from here on and closes it.
	ch := make(chan StreamEvent, 64)
	go c.parseSSEStream(resp.Body, ch)
	return ch, nil
}
// setAuthHeader attaches credentials to req.
//
// An OAuth bearer token is preferred when the OAuth manager is present and
// reports an authenticated session. If that is unavailable, or fetching the
// token fails, the configured API key is sent as x-api-key instead. When
// neither credential can be used, it returns an error. If OAuth was attempted
// and failed, that error wraps the OAuth failure so the cause is not lost.
func (c *ClaudeClient) setAuthHeader(req *http.Request) error {
	var oauthErr error
	if c.oauth != nil && c.oauth.IsAuthenticated() {
		token, err := c.oauth.GetAccessToken()
		if err == nil {
			req.Header.Set("Authorization", "Bearer "+token)
			return nil
		}
		oauthErr = err
	}

	if c.apiKey != "" {
		// Log only when a fallback really happens. If there is no API key, the
		// OAuth error is returned to the caller instead of being logged here.
		if oauthErr != nil {
			slog.Warn("oauth token failed, falling back to api key", "error", oauthErr)
		}
		req.Header.Set("x-api-key", c.apiKey)
		return nil
	}

	if oauthErr != nil {
		return fmt.Errorf("get oauth access token (no API key fallback configured): %w", oauthErr)
	}
	return fmt.Errorf("no authentication method available — log in via OAuth or set an API key")
}
// parseSSEStream reads the server-sent-event body of a streaming Messages API
// call and forwards the relevant parts to ch as StreamEvents. It closes ch and
// then body before returning.
//
// Event mapping:
//   - content_block_start for a tool_use block -> "tool_use_start" with ToolID/ToolName
//   - content_block_start for a non-empty text block, and text_delta -> "text_delta"
//   - input_json_delta -> "tool_use_delta" carrying the partial JSON in Data,
//     tagged with the most recently started tool
//   - message_delta -> "done" carrying the raw event JSON in Data
//   - an "error" event or a read failure -> "error", after which the stream ends
//   - message_stop ends the stream
//
// Sends on ch block once its buffer is full. If the caller stops draining ch,
// this goroutine and the HTTP body stay alive.
func (c *ClaudeClient) parseSSEStream(body io.ReadCloser, ch chan<- StreamEvent) {
	defer body.Close()
	defer close(ch)

	scanner := bufio.NewScanner(body)
	// With the default 64 KiB line limit, a single oversized data line makes
	// Scan fail with ErrTooLong and aborts the whole stream; allow up to 1 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)

	var currentEventType string
	var currentToolID string
	var currentToolName string

	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			// A blank line terminates an SSE event. The event type must not
			// carry over into the next event.
			currentEventType = ""
			continue
		}

		// SSE field syntax is "name:value", with one optional space after the
		// colon. Comment lines (":...") and fields such as id:/retry: are ignored.
		field, value, _ := strings.Cut(line, ":")
		value = strings.TrimPrefix(value, " ")
		if field == "event" {
			currentEventType = value
			continue
		}
		if field != "data" {
			continue
		}
		data := []byte(value)

		switch currentEventType {
		case "content_block_start":
			// content_block.input is deliberately not decoded. For tool_use
			// blocks it is a JSON object (e.g. {}), and decoding it into a
			// string makes json.Unmarshal return an error, which would drop
			// this event. The tool input arrives via input_json_delta anyway.
			var block struct {
				Index        int `json:"index"`
				ContentBlock struct {
					Type string `json:"type"`
					ID   string `json:"id"`
					Name string `json:"name"`
					Text string `json:"text"`
				} `json:"content_block"`
			}
			if err := json.Unmarshal(data, &block); err != nil {
				slog.Warn("failed to parse content_block_start", "error", err)
				continue
			}
			switch {
			case block.ContentBlock.Type == "tool_use":
				currentToolID = block.ContentBlock.ID
				currentToolName = block.ContentBlock.Name
				ch <- StreamEvent{
					Type:     "tool_use_start",
					ToolID:   currentToolID,
					ToolName: currentToolName,
				}
			case block.ContentBlock.Type == "text" && block.ContentBlock.Text != "":
				ch <- StreamEvent{Type: "text_delta", Data: block.ContentBlock.Text}
			}

		case "content_block_delta":
			var delta struct {
				Delta struct {
					Type        string `json:"type"`
					Text        string `json:"text"`
					PartialJSON string `json:"partial_json"`
				} `json:"delta"`
			}
			if err := json.Unmarshal(data, &delta); err != nil {
				slog.Warn("failed to parse content_block_delta", "error", err)
				continue
			}
			switch delta.Delta.Type {
			case "text_delta":
				ch <- StreamEvent{Type: "text_delta", Data: delta.Delta.Text}
			case "input_json_delta":
				ch <- StreamEvent{
					Type:     "tool_use_delta",
					Data:     delta.Delta.PartialJSON,
					ToolID:   currentToolID,
					ToolName: currentToolName,
				}
			}

		case "content_block_stop":
			// Nothing to do; tool input is accumulated by the caller.
		case "message_delta":
			// Carries stop_reason and usage; forwarded raw to the caller.
			ch <- StreamEvent{Type: "done", Data: value}
		case "message_stop":
			// End of the message stream.
			return
		case "message_start":
			// Message metadata; usage is reported via message_delta instead.
		case "ping":
			// Heartbeat; ignore.
		case "error":
			ch <- StreamEvent{Type: "error", Data: value}
			return
		}
	}

	if err := scanner.Err(); err != nil {
		ch <- StreamEvent{Type: "error", Data: err.Error()}
	}
}
// BuildRequestBody returns the JSON encoding of a streaming Messages API
// request (max_tokens 8192) for the given conversation, tools and system
// prompt. An empty model selects defaultModel. It is handy for inspecting the
// wire format in tests.
func BuildRequestBody(messages []Message, tools []Tool, systemPrompt, model string) ([]byte, error) {
	chosenModel := model
	if chosenModel == "" {
		chosenModel = defaultModel
	}
	return json.Marshal(apiRequest{
		Model:     chosenModel,
		Messages:  messages,
		MaxTokens: 8192,
		Stream:    true,
		System:    systemPrompt,
		Tools:     tools,
	})
}
// ParseSSELine splits one server-sent-events line into its field name and
// value. Despite its name, eventType holds the SSE field name: "event" for an
// event line and "data" for a data line. data holds the field value.
//
// Parsing follows the SSE spec:
//   - One optional space after the colon is stripped, so "data:x" and
//     "data: x" both yield "x".
//   - A line without a colon is a field with an empty value.
//   - Every other line (blank lines, ":" comments, and fields such as "id:"
//     or "retry:") yields ("", "").
func ParseSSELine(line string) (eventType, data string) {
	// When there is no colon, Cut returns the whole line as the field name
	// and "" as the value, which matches the spec.
	field, value, _ := strings.Cut(line, ":")
	switch field {
	case "event", "data":
		return field, strings.TrimPrefix(value, " ")
	default:
		return "", ""
	}
}