package ai import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "io" "log/slog" "net" "net/http" "net/url" "strconv" "sync" "time" "github.com/vstockwell/wraith/internal/settings" "github.com/vstockwell/wraith/internal/vault" ) const ( oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" oauthAuthorizeURL = "https://claude.ai/oauth/authorize" oauthTokenURL = "https://platform.claude.com/v1/oauth/token" ) // OAuthManager handles OAuth PKCE authentication for Claude Max subscriptions. type OAuthManager struct { settings *settings.SettingsService vault *vault.VaultService clientID string authorizeURL string tokenURL string mu sync.RWMutex } // tokenResponse is the JSON body returned by the token endpoint. type tokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` TokenType string `json:"token_type"` } // NewOAuthManager creates an OAuthManager wired to settings and vault services. func NewOAuthManager(s *settings.SettingsService, v *vault.VaultService) *OAuthManager { return &OAuthManager{ settings: s, vault: v, clientID: oauthClientID, authorizeURL: oauthAuthorizeURL, tokenURL: oauthTokenURL, } } // StartLogin begins the OAuth PKCE flow. It starts a local HTTP server, // opens the authorization URL in the user's browser, and returns a channel // that receives nil on success or an error on failure. func (o *OAuthManager) StartLogin(openURL func(string) error) (<-chan error, error) { verifier, err := generateCodeVerifier() if err != nil { return nil, fmt.Errorf("generate code verifier: %w", err) } challenge := generateCodeChallenge(verifier) state, err := generateState() if err != nil { return nil, fmt.Errorf("generate state: %w", err) } // Start a local HTTP server on a random port for the callback listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, fmt.Errorf("listen for callback: %w", err) } port := listener.Addr().(*net.TCPAddr).Port redirectURI := fmt.Sprintf("http://127.0.0.1:%d/callback", port) done := make(chan error, 1) mux := http.NewServeMux() server := &http.Server{Handler: mux} mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { defer func() { go func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() server.Shutdown(ctx) }() }() // Verify state if r.URL.Query().Get("state") != state { w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "State mismatch — please try logging in again.") done <- fmt.Errorf("state mismatch") return } if errParam := r.URL.Query().Get("error"); errParam != "" { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "Authorization error: %s", errParam) done <- fmt.Errorf("authorization error: %s", errParam) return } code := r.URL.Query().Get("code") if code == "" { w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, "No authorization code received.") done <- fmt.Errorf("no authorization code") return } // Exchange code for tokens if err := o.exchangeCode(code, verifier, redirectURI); err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Fprint(w, "Failed to exchange authorization code.") done <- fmt.Errorf("exchange code: %w", err) return } w.Header().Set("Content-Type", "text/html") fmt.Fprint(w, `
You can close this window and return to Wraith.
`) done <- nil }) go func() { if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { slog.Error("oauth callback server error", "error", err) done <- fmt.Errorf("callback server: %w", err) } }() // Build the authorize URL params := url.Values{ "response_type": {"code"}, "client_id": {o.clientID}, "redirect_uri": {redirectURI}, "scope": {"user:inference"}, "state": {state}, "code_challenge": {challenge}, "code_challenge_method": {"S256"}, } authURL := o.authorizeURL + "?" + params.Encode() if openURL != nil { if err := openURL(authURL); err != nil { listener.Close() return nil, fmt.Errorf("open browser: %w", err) } } return done, nil } // exchangeCode exchanges an authorization code for access and refresh tokens. func (o *OAuthManager) exchangeCode(code, verifier, redirectURI string) error { data := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, "redirect_uri": {redirectURI}, "client_id": {o.clientID}, "code_verifier": {verifier}, } resp, err := http.PostForm(o.tokenURL, data) if err != nil { return fmt.Errorf("post token request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("read token response: %w", err) } if resp.StatusCode != http.StatusOK { return fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body)) } var tokens tokenResponse if err := json.Unmarshal(body, &tokens); err != nil { return fmt.Errorf("unmarshal token response: %w", err) } return o.storeTokens(tokens) } // storeTokens encrypts and persists OAuth tokens in settings. func (o *OAuthManager) storeTokens(tokens tokenResponse) error { o.mu.Lock() defer o.mu.Unlock() if o.vault == nil { return fmt.Errorf("vault is not unlocked — cannot store tokens") } encAccess, err := o.vault.Encrypt(tokens.AccessToken) if err != nil { return fmt.Errorf("encrypt access token: %w", err) } if err := o.settings.Set("ai_access_token", encAccess); err != nil { return fmt.Errorf("store access token: %w", err) } encRefresh, err := o.vault.Encrypt(tokens.RefreshToken) if err != nil { return fmt.Errorf("encrypt refresh token: %w", err) } if err := o.settings.Set("ai_refresh_token", encRefresh); err != nil { return fmt.Errorf("store refresh token: %w", err) } expiresAt := time.Now().Add(time.Duration(tokens.ExpiresIn) * time.Second).Unix() if err := o.settings.Set("ai_token_expires_at", strconv.FormatInt(expiresAt, 10)); err != nil { return fmt.Errorf("store token expiry: %w", err) } slog.Info("oauth tokens stored successfully") return nil } // GetAccessToken returns a valid access token, refreshing if expired. func (o *OAuthManager) GetAccessToken() (string, error) { o.mu.RLock() vlt := o.vault o.mu.RUnlock() if vlt == nil { return "", fmt.Errorf("vault is not unlocked") } // Check expiry expiresStr, _ := o.settings.Get("ai_token_expires_at") if expiresStr != "" { expiresAt, _ := strconv.ParseInt(expiresStr, 10, 64) if time.Now().Unix() >= expiresAt { // Token expired, try to refresh if err := o.refreshToken(); err != nil { return "", fmt.Errorf("refresh expired token: %w", err) } } } encAccess, err := o.settings.Get("ai_access_token") if err != nil || encAccess == "" { return "", fmt.Errorf("no access token stored") } accessToken, err := vlt.Decrypt(encAccess) if err != nil { return "", fmt.Errorf("decrypt access token: %w", err) } return accessToken, nil } // refreshToken uses the stored refresh token to obtain new tokens. func (o *OAuthManager) refreshToken() error { o.mu.RLock() vlt := o.vault o.mu.RUnlock() if vlt == nil { return fmt.Errorf("vault is not unlocked") } encRefresh, err := o.settings.Get("ai_refresh_token") if err != nil || encRefresh == "" { return fmt.Errorf("no refresh token stored") } refreshTok, err := vlt.Decrypt(encRefresh) if err != nil { return fmt.Errorf("decrypt refresh token: %w", err) } data := url.Values{ "grant_type": {"refresh_token"}, "refresh_token": {refreshTok}, "client_id": {o.clientID}, } resp, err := http.PostForm(o.tokenURL, data) if err != nil { return fmt.Errorf("post refresh request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("read refresh response: %w", err) } if resp.StatusCode != http.StatusOK { return fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body)) } var tokens tokenResponse if err := json.Unmarshal(body, &tokens); err != nil { return fmt.Errorf("unmarshal refresh response: %w", err) } // If the refresh endpoint doesn't return a new refresh token, keep the old one if tokens.RefreshToken == "" { tokens.RefreshToken = refreshTok } return o.storeTokens(tokens) } // IsAuthenticated checks if we have stored tokens. func (o *OAuthManager) IsAuthenticated() bool { encAccess, err := o.settings.Get("ai_access_token") return err == nil && encAccess != "" } // Logout clears all stored OAuth tokens. func (o *OAuthManager) Logout() error { for _, key := range []string{"ai_access_token", "ai_refresh_token", "ai_token_expires_at"} { if err := o.settings.Delete(key); err != nil { return fmt.Errorf("delete %s: %w", key, err) } } return nil } // generateCodeVerifier creates a PKCE code verifier (32 random bytes, base64url no padding). func generateCodeVerifier() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil } // generateCodeChallenge creates a PKCE S256 code challenge from a verifier. func generateCodeChallenge(verifier string) string { h := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(h[:]) } // generateState creates a random state parameter (32 random bytes, base64url no padding). func generateState() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil } // SetVault updates the vault reference (called after vault unlock). func (o *OAuthManager) SetVault(v *vault.VaultService) { o.mu.Lock() defer o.mu.Unlock() o.vault = v } // isBase64URL checks if a string contains only base64url characters (no padding). func isBase64URL(s string) bool { for _, c := range s { if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_') { return false } } return true }