wraith/internal/ai/oauth.go
Vantz Stockwell 7ee5321d69
Some checks failed
Build & Sign Wraith / Build Windows + Sign (push) Has been cancelled
feat: AI copilot backend — OAuth PKCE, Claude API streaming, 16 tools, conversations
- OAuth PKCE flow for Max subscription auth (no API key needed)
- Claude API client with SSE streaming (Messages API v1)
- 16 tool definitions: terminal, SFTP, RDP, session management
- Tool dispatch router mapping to existing Wraith services
- Conversation manager with SQLite persistence
- Terminal output ring buffer for AI context
- RDP screenshot encoder (RGBA → JPEG with downscaling)
- Wired into Wails app as AIService

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 09:09:23 -04:00

372 lines
10 KiB
Go

package ai
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/vstockwell/wraith/internal/settings"
"github.com/vstockwell/wraith/internal/vault"
)
// OAuth client identity and endpoints for the Claude PKCE login flow.
// NewOAuthManager copies these into the OAuthManager fields that the rest of
// this file reads (clientID, authorizeURL, tokenURL).
const (
	// oauthClientID is the default client_id sent on authorize and token requests.
	oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
	// oauthAuthorizeURL is the browser-facing authorization endpoint opened by StartLogin.
	oauthAuthorizeURL = "https://claude.ai/oauth/authorize"
	// oauthTokenURL receives both the authorization_code and refresh_token grants.
	oauthTokenURL = "https://platform.claude.com/v1/oauth/token"
)
// OAuthManager handles OAuth PKCE authentication for Claude Max subscriptions.
// Tokens are encrypted with the vault and persisted through the settings
// service under the keys ai_access_token, ai_refresh_token and
// ai_token_expires_at.
type OAuthManager struct {
	// settings is the persistent store the encrypted tokens and expiry are written to.
	settings *settings.SettingsService
	// vault encrypts/decrypts tokens. It may be nil, in which case token
	// operations fail with "vault is not unlocked"; SetVault replaces it.
	vault *vault.VaultService
	// OAuth client configuration; NewOAuthManager fills these from the package constants.
	clientID     string
	authorizeURL string
	tokenURL     string
	// mu guards vault: SetVault and storeTokens take the write lock, while
	// GetAccessToken and refreshToken snapshot vault under the read lock.
	mu sync.RWMutex
}
// tokenResponse is the JSON body returned by the token endpoint. The same
// shape is decoded for both the authorization_code and refresh_token grants.
type tokenResponse struct {
	AccessToken string `json:"access_token"`
	// RefreshToken may be empty on a refresh; refreshToken then keeps the old one.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is the token lifetime in seconds; storeTokens turns it into
	// an absolute Unix time stored as ai_token_expires_at.
	ExpiresIn int `json:"expires_in"`
	// TokenType is decoded but not read anywhere in the code visible here.
	TokenType string `json:"token_type"`
}
// NewOAuthManager builds an OAuthManager that persists tokens through s and
// encrypts them with v, using the package's Claude OAuth client ID and
// endpoints. If v is nil, token operations report "vault is not unlocked"
// until SetVault supplies a vault.
func NewOAuthManager(s *settings.SettingsService, v *vault.VaultService) *OAuthManager {
	m := new(OAuthManager)
	m.settings, m.vault = s, v
	m.clientID = oauthClientID
	m.authorizeURL = oauthAuthorizeURL
	m.tokenURL = oauthTokenURL
	return m
}
// StartLogin begins the OAuth PKCE flow. It starts a local HTTP server on a
// random 127.0.0.1 port and opens the authorization URL in the user's browser
// via openURL (skipped when openURL is nil). It returns a channel that
// receives exactly one value: nil on success or an error on failure.
//
// Only the first callback is processed. The callback server shuts itself
// down after answering it. If no callback arrives within loginTimeout, the
// channel receives a timeout error and the server shuts down, so an abandoned
// login does not leave a listener running.
func (o *OAuthManager) StartLogin(openURL func(string) error) (<-chan error, error) {
	// How long to wait for the browser to hit the callback before giving up.
	const loginTimeout = 5 * time.Minute

	verifier, err := generateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("generate code verifier: %w", err)
	}
	challenge := generateCodeChallenge(verifier)
	state, err := generateState()
	if err != nil {
		return nil, fmt.Errorf("generate state: %w", err)
	}

	// Start a local HTTP server on a random port for the callback.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("listen for callback: %w", err)
	}
	port := listener.Addr().(*net.TCPAddr).Port
	redirectURI := fmt.Sprintf("http://127.0.0.1:%d/callback", port)

	// done has room for exactly one result. Every send goes through finish,
	// which drops later results instead of blocking. A late server error or
	// the timeout racing a callback therefore cannot wedge a goroutine on a
	// full channel.
	done := make(chan error, 1)
	finish := func(err error) {
		select {
		case done <- err:
		default:
		}
	}

	mux := http.NewServeMux()
	server := &http.Server{
		Handler: mux,
		// Bound how long a client may take to send request headers.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// shutdown stops the server without blocking the caller. This matters
	// because it is triggered from inside a handler, and Shutdown waits for
	// active handlers to return.
	shutdown := func() {
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer cancel()
			server.Shutdown(ctx)
		}()
	}

	// Give up if the browser never calls back.
	timer := time.AfterFunc(loginTimeout, func() {
		finish(fmt.Errorf("login timed out after %s", loginTimeout))
		shutdown()
	})

	mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
		// timer.Stop succeeds for exactly one caller. It claims the login for
		// this request and keeps the timeout from preempting it. Anything
		// arriving after a timeout or an earlier callback is rejected without
		// touching done, so it cannot overwrite the real outcome.
		if !timer.Stop() {
			w.Header().Set("Content-Type", "text/plain; charset=utf-8")
			w.WriteHeader(http.StatusGone)
			fmt.Fprint(w, "This login attempt is no longer pending — please try logging in again.")
			return
		}
		defer shutdown()

		// fail writes a plain-text error page and reports err to the caller.
		// Setting Content-Type explicitly keeps reflected query values (such
		// as the error parameter) from being content-sniffed as HTML.
		fail := func(status int, msg string, err error) {
			w.Header().Set("Content-Type", "text/plain; charset=utf-8")
			w.WriteHeader(status)
			fmt.Fprint(w, msg)
			finish(err)
		}

		query := r.URL.Query()

		// Verify state before trusting anything else in the request.
		if query.Get("state") != state {
			fail(http.StatusBadRequest, "State mismatch — please try logging in again.", fmt.Errorf("state mismatch"))
			return
		}
		if errParam := query.Get("error"); errParam != "" {
			fail(http.StatusBadRequest, "Authorization error: "+errParam, fmt.Errorf("authorization error: %s", errParam))
			return
		}
		code := query.Get("code")
		if code == "" {
			fail(http.StatusBadRequest, "No authorization code received.", fmt.Errorf("no authorization code"))
			return
		}

		// Exchange code for tokens.
		if err := o.exchangeCode(code, verifier, redirectURI); err != nil {
			fail(http.StatusInternalServerError, "Failed to exchange authorization code.", fmt.Errorf("exchange code: %w", err))
			return
		}

		w.Header().Set("Content-Type", "text/html")
		fmt.Fprint(w, `<html><body><h2>Authenticated!</h2><p>You can close this window and return to Wraith.</p></body></html>`)
		finish(nil)
	})

	go func() {
		if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
			slog.Error("oauth callback server error", "error", err)
			finish(fmt.Errorf("callback server: %w", err))
		}
	}()

	// Build the authorize URL.
	params := url.Values{
		"response_type":         {"code"},
		"client_id":             {o.clientID},
		"redirect_uri":          {redirectURI},
		"scope":                 {"user:inference"},
		"state":                 {state},
		"code_challenge":        {challenge},
		"code_challenge_method": {"S256"},
	}
	authURL := o.authorizeURL + "?" + params.Encode()

	if openURL != nil {
		if err := openURL(authURL); err != nil {
			// Close the server rather than just the listener. Serve then
			// returns ErrServerClosed instead of reporting a spurious error.
			timer.Stop()
			server.Close()
			return nil, fmt.Errorf("open browser: %w", err)
		}
	}
	return done, nil
}
// exchangeCode exchanges an authorization code for access and refresh tokens
// at the token endpoint and persists them via storeTokens.
//
// The request uses a client with an overall timeout, because this runs
// inside the login callback handler and a stalled token endpoint would
// otherwise hang it. The response body is capped since a token response is
// a small JSON document. A response with no access_token is rejected rather
// than stored.
func (o *OAuthManager) exchangeCode(code, verifier, redirectURI string) error {
	const (
		tokenRequestTimeout = 30 * time.Second
		maxTokenResponse    = 1 << 20 // 1 MiB
	)

	data := url.Values{
		"grant_type":    {"authorization_code"},
		"code":          {code},
		"redirect_uri":  {redirectURI},
		"client_id":     {o.clientID},
		"code_verifier": {verifier},
	}

	client := &http.Client{Timeout: tokenRequestTimeout}
	resp, err := client.PostForm(o.tokenURL, data)
	if err != nil {
		return fmt.Errorf("post token request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponse))
	if err != nil {
		return fmt.Errorf("read token response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body))
	}

	var tokens tokenResponse
	if err := json.Unmarshal(body, &tokens); err != nil {
		return fmt.Errorf("unmarshal token response: %w", err)
	}
	if tokens.AccessToken == "" {
		return fmt.Errorf("token response contained no access token")
	}
	return o.storeTokens(tokens)
}
// storeTokens encrypts the access and refresh tokens with the vault and
// persists them, plus the absolute expiry (Unix seconds), in settings. It
// holds the write lock on o.mu for the whole operation.
//
// If the token endpoint gave no lifetime (expires_in absent or <= 0;
// expires_in is optional in RFC 6749), no expiry is recorded. Storing "now"
// would make GetAccessToken refresh on every call. Any previously stored
// expiry is cleared on a best-effort basis.
func (o *OAuthManager) storeTokens(tokens tokenResponse) error {
	o.mu.Lock()
	defer o.mu.Unlock()

	if o.vault == nil {
		return fmt.Errorf("vault is not unlocked — cannot store tokens")
	}

	encAccess, err := o.vault.Encrypt(tokens.AccessToken)
	if err != nil {
		return fmt.Errorf("encrypt access token: %w", err)
	}
	if err := o.settings.Set("ai_access_token", encAccess); err != nil {
		return fmt.Errorf("store access token: %w", err)
	}

	encRefresh, err := o.vault.Encrypt(tokens.RefreshToken)
	if err != nil {
		return fmt.Errorf("encrypt refresh token: %w", err)
	}
	if err := o.settings.Set("ai_refresh_token", encRefresh); err != nil {
		return fmt.Errorf("store refresh token: %w", err)
	}

	if tokens.ExpiresIn > 0 {
		expiresAt := time.Now().Add(time.Duration(tokens.ExpiresIn) * time.Second).Unix()
		if err := o.settings.Set("ai_token_expires_at", strconv.FormatInt(expiresAt, 10)); err != nil {
			return fmt.Errorf("store token expiry: %w", err)
		}
	} else if err := o.settings.Delete("ai_token_expires_at"); err != nil {
		// Best-effort: a stale expiry only causes extra refresh attempts, so it
		// should not fail an otherwise successful login.
		slog.Warn("oauth token response had no expiry; could not clear previous expiry", "error", err)
	}

	slog.Info("oauth tokens stored successfully")
	return nil
}
// GetAccessToken returns the decrypted access token, refreshing it first when
// the stored expiry has passed or falls within tokenExpirySkew.
//
// Refreshing slightly early avoids handing out a token that expires while
// the caller's request is in flight. If that early refresh fails but the
// current token has not actually expired yet, the current token is still
// returned and refresh is retried on the next call. If no expiry is stored,
// the stored token is returned without a refresh check. An unparseable
// expiry is treated as already expired.
func (o *OAuthManager) GetAccessToken() (string, error) {
	const tokenExpirySkew = 60 * time.Second

	o.mu.RLock()
	vlt := o.vault
	o.mu.RUnlock()
	if vlt == nil {
		return "", fmt.Errorf("vault is not unlocked")
	}

	// A read error is deliberately treated the same as "no expiry stored".
	if expiresStr, _ := o.settings.Get("ai_token_expires_at"); expiresStr != "" {
		expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
		if err != nil {
			// Corrupt value: force a refresh rather than trust it.
			expiresAt = 0
		}
		now := time.Now()
		if now.Add(tokenExpirySkew).Unix() >= expiresAt {
			if err := o.refreshToken(); err != nil {
				if now.Unix() >= expiresAt {
					return "", fmt.Errorf("refresh expired token: %w", err)
				}
				// The current token is still valid; use it and try again next call.
				slog.Warn("oauth token refresh failed; using current token until it expires", "error", err)
			}
		}
	}

	encAccess, err := o.settings.Get("ai_access_token")
	if err != nil || encAccess == "" {
		return "", fmt.Errorf("no access token stored")
	}
	accessToken, err := vlt.Decrypt(encAccess)
	if err != nil {
		return "", fmt.Errorf("decrypt access token: %w", err)
	}
	return accessToken, nil
}
// refreshToken exchanges the stored refresh token for a new token set and
// persists it via storeTokens. If the endpoint omits refresh_token, the
// existing refresh token is kept.
//
// The request has an overall timeout and the response read is capped, so a
// stalled or misbehaving token endpoint cannot hang GetAccessToken. A stored
// refresh token that decrypts to "" is treated as missing rather than sent.
// A response without an access_token is rejected.
//
// NOTE(review): concurrent GetAccessToken callers may each trigger a refresh.
// If the server rotates refresh tokens, the losing request may be rejected.
// Consider serializing refreshes if that shows up in practice.
func (o *OAuthManager) refreshToken() error {
	const (
		tokenRequestTimeout = 30 * time.Second
		maxTokenResponse    = 1 << 20 // 1 MiB
	)

	o.mu.RLock()
	vlt := o.vault
	o.mu.RUnlock()
	if vlt == nil {
		return fmt.Errorf("vault is not unlocked")
	}

	encRefresh, err := o.settings.Get("ai_refresh_token")
	if err != nil || encRefresh == "" {
		return fmt.Errorf("no refresh token stored")
	}
	refreshTok, err := vlt.Decrypt(encRefresh)
	if err != nil {
		return fmt.Errorf("decrypt refresh token: %w", err)
	}
	if refreshTok == "" {
		// storeTokens encrypts whatever it is given, including an empty token.
		return fmt.Errorf("no refresh token stored")
	}

	data := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshTok},
		"client_id":     {o.clientID},
	}

	client := &http.Client{Timeout: tokenRequestTimeout}
	resp, err := client.PostForm(o.tokenURL, data)
	if err != nil {
		return fmt.Errorf("post refresh request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponse))
	if err != nil {
		return fmt.Errorf("read refresh response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body))
	}

	var tokens tokenResponse
	if err := json.Unmarshal(body, &tokens); err != nil {
		return fmt.Errorf("unmarshal refresh response: %w", err)
	}
	if tokens.AccessToken == "" {
		return fmt.Errorf("refresh response contained no access token")
	}
	// If the refresh endpoint doesn't return a new refresh token, keep the old one.
	if tokens.RefreshToken == "" {
		tokens.RefreshToken = refreshTok
	}
	return o.storeTokens(tokens)
}
// IsAuthenticated reports whether an encrypted access token is present in
// settings. It does not check the token's expiry or whether the vault can
// decrypt it.
func (o *OAuthManager) IsAuthenticated() bool {
	stored, err := o.settings.Get("ai_access_token")
	if err != nil {
		return false
	}
	return stored != ""
}
// Logout deletes the stored access token, refresh token and expiry from
// settings, in that order. It stops at the first failing deletion and
// returns that error.
func (o *OAuthManager) Logout() error {
	keys := [...]string{"ai_access_token", "ai_refresh_token", "ai_token_expires_at"}
	for i := range keys {
		if err := o.settings.Delete(keys[i]); err != nil {
			return fmt.Errorf("delete %s: %w", keys[i], err)
		}
	}
	return nil
}
// generateCodeVerifier returns a fresh PKCE code verifier. It encodes 32
// bytes from crypto/rand as unpadded base64url, giving a 43-character string.
func generateCodeVerifier() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
// generateCodeChallenge creates a PKCE S256 code challenge from a verifier.
func generateCodeChallenge(verifier string) string {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:])
}
// generateState returns a fresh OAuth state value. It encodes 32 bytes from
// crypto/rand as unpadded base64url. StartLogin compares the callback's state
// parameter against this value.
func generateState() (string, error) {
	var nonce [32]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		return "", err
	}
	state := base64.RawURLEncoding.EncodeToString(nonce[:])
	return state, nil
}
// SetVault installs v as the vault used to encrypt and decrypt tokens. It is
// intended to be called after the vault is unlocked. The assignment happens
// under the write lock on o.mu, the same lock storeTokens holds and that
// GetAccessToken and refreshToken read under.
func (o *OAuthManager) SetVault(v *vault.VaultService) {
	o.mu.Lock()
	o.vault = v
	o.mu.Unlock()
}
// isBase64URL reports whether every rune of s is in the unpadded base64url
// alphabet: A-Z, a-z, 0-9, '-' and '_'. Padding ('='), non-ASCII runes and
// invalid UTF-8 all yield false. The empty string yields true.
func isBase64URL(s string) bool {
	for _, r := range s {
		switch {
		case 'A' <= r && r <= 'Z', 'a' <= r && r <= 'z', '0' <= r && r <= '9':
			// alphanumeric: allowed
		case r == '-', r == '_':
			// URL-safe replacements for '+' and '/': allowed
		default:
			return false
		}
	}
	return true
}