wraith/internal/ai/oauth.go
Vantz Stockwell 999f8f0539
All checks were successful
Build & Sign Wraith / Build Windows + Sign (push) Successful in 1m2s
fix: OAuth token exchange — try JSON then form-encoded, show actual error in browser
The callback page now shows the real error message instead of a generic
"Failed to exchange" message. Token exchange tries JSON Content-Type first
(matching Claude Code's pattern) with form-encoded fallback. Full response
body logged for debugging.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 11:48:54 -04:00

420 lines
12 KiB
Go

package ai
import (
	"bytes"
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"sync"
	"time"

	"github.com/vstockwell/wraith/internal/settings"
	"github.com/vstockwell/wraith/internal/vault"
)
// Default OAuth endpoints and client identity used by NewOAuthManager.
const (
	// oauthClientID is sent as client_id on authorize, token and refresh requests.
	oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"

	// oauthAuthorizeURL is the page the user's browser is sent to for consent.
	oauthAuthorizeURL = "https://claude.ai/oauth/authorize"

	// oauthTokenURL is the endpoint used for code exchange and token refresh.
	oauthTokenURL = "https://platform.claude.com/v1/oauth/token"
)
// OAuthManager drives the OAuth PKCE login flow and keeps the resulting
// tokens, encrypted through the vault, in the settings store.
type OAuthManager struct {
	// settings persists the encrypted tokens and the expiry timestamp.
	settings *settings.SettingsService
	// vault encrypts and decrypts token values; it may be nil until SetVault
	// is called, in which case token operations return an error.
	vault *vault.VaultService

	// OAuth client identity and endpoints, filled from the package defaults
	// by NewOAuthManager.
	clientID, authorizeURL, tokenURL string

	// mu guards the vault pointer; storeTokens also holds it while writing.
	mu sync.RWMutex
}
// tokenResponse mirrors the JSON document the token endpoint returns for both
// the authorization-code exchange and the refresh grant.
type tokenResponse struct {
	// AccessToken is the bearer credential used for API calls.
	AccessToken string `json:"access_token"`
	// RefreshToken is used to obtain a new access token; the refresh path
	// tolerates it being absent from the response.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is the access token lifetime in seconds from the response time.
	ExpiresIn int `json:"expires_in"`
	// TokenType is decoded but not otherwise used in this file.
	TokenType string `json:"token_type"`
}
// NewOAuthManager returns an OAuthManager that persists through s and
// encrypts through v, configured with the package's default client ID and
// endpoint URLs. v may be nil and supplied later via SetVault.
func NewOAuthManager(s *settings.SettingsService, v *vault.VaultService) *OAuthManager {
	mgr := new(OAuthManager)
	mgr.settings = s
	mgr.vault = v
	mgr.clientID = oauthClientID
	mgr.authorizeURL = oauthAuthorizeURL
	mgr.tokenURL = oauthTokenURL
	return mgr
}
// StartLogin begins the OAuth authorization-code flow with PKCE.
//
// It generates a fresh code verifier and state value, starts a single-use HTTP
// server on a random localhost port to receive the redirect, and hands the
// authorization URL to openURL (skipped when openURL is nil).
//
// The returned channel is buffered and receives at most one result: nil once
// the code has been exchanged and the tokens stored, or an error describing
// why the callback (or the callback server) failed. Any further results, for
// example from a repeated callback request, are dropped rather than blocking
// the handler. A non-nil error return means the flow could not be started.
func (o *OAuthManager) StartLogin(openURL func(string) error) (<-chan error, error) {
	verifier, err := generateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("generate code verifier: %w", err)
	}
	challenge := generateCodeChallenge(verifier)

	state, err := generateState()
	if err != nil {
		return nil, fmt.Errorf("generate state: %w", err)
	}

	// Start a local HTTP server on a random port for the callback.
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return nil, fmt.Errorf("listen for callback: %w", err)
	}
	port := listener.Addr().(*net.TCPAddr).Port
	redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)

	done := make(chan error, 1)
	// report publishes the outcome without ever blocking. The channel holds a
	// single result; a caller typically reads it once, so a second send would
	// otherwise park the sending goroutine forever.
	report := func(result error) {
		select {
		case done <- result:
		default:
		}
	}

	mux := http.NewServeMux()
	server := &http.Server{
		Handler: mux,
		// Bound how long a client may take to send request headers.
		ReadHeaderTimeout: 10 * time.Second,
	}

	mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
		// The server is single-use: shut it down once this response is sent.
		// Shutdown runs in its own goroutine because it waits for active
		// handlers, including this one, to return.
		defer func() {
			go func() {
				ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
				defer cancel()
				_ = server.Shutdown(ctx) // best effort; nothing useful to do on failure
			}()
		}()

		// fail answers with a plain-text error page and reports the failure.
		fail := func(msg string, result error) {
			w.Header().Set("Content-Type", "text/plain; charset=utf-8")
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprint(w, msg)
			report(result)
		}

		query := r.URL.Query()

		// Verify state to make sure this redirect belongs to our request.
		if query.Get("state") != state {
			fail("State mismatch — please try logging in again.", fmt.Errorf("state mismatch"))
			return
		}
		if errParam := query.Get("error"); errParam != "" {
			fail(fmt.Sprintf("Authorization error: %s", errParam), fmt.Errorf("authorization error: %s", errParam))
			return
		}
		code := query.Get("code")
		if code == "" {
			fail("No authorization code received.", fmt.Errorf("no authorization code"))
			return
		}

		// Exchange code for tokens.
		if err := o.exchangeCode(code, verifier, redirectURI); err != nil {
			slog.Error("oauth token exchange failed", "error", err)
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.WriteHeader(http.StatusInternalServerError)
			// The error can embed the token endpoint's response body, so it
			// must be escaped before being placed into the HTML page.
			fmt.Fprintf(w, "<html><body><h2>Authentication Failed</h2><pre>%s</pre><p>Check Wraith logs for details.</p></body></html>", html.EscapeString(err.Error()))
			report(fmt.Errorf("exchange code: %w", err))
			return
		}

		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		fmt.Fprint(w, `<html><body><h2>Authenticated!</h2><p>You can close this window and return to Wraith.</p></body></html>`)
		report(nil)
	})

	go func() {
		if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
			slog.Error("oauth callback server error", "error", err)
			report(fmt.Errorf("callback server: %w", err))
		}
	}()

	// Build the authorize URL.
	params := url.Values{
		"response_type":         {"code"},
		"client_id":             {o.clientID},
		"redirect_uri":          {redirectURI},
		"scope":                 {"user:inference"},
		"state":                 {state},
		"code_challenge":        {challenge},
		"code_challenge_method": {"S256"},
	}
	authURL := o.authorizeURL + "?" + params.Encode()

	if openURL != nil {
		if err := openURL(authURL); err != nil {
			// Close via the server so Serve returns http.ErrServerClosed
			// instead of logging a spurious listener error; closing the
			// listener as well releases the port immediately.
			_ = server.Close()
			_ = listener.Close()
			return nil, fmt.Errorf("open browser: %w", err)
		}
	}
	return done, nil
}
// exchangeCode trades an authorization code for access and refresh tokens and
// stores them via storeTokens.
//
// The request is first sent as a JSON body; if the endpoint answers with
// anything other than 200 OK, the same parameters are retried as a
// form-encoded POST. Response bodies are only logged for failed attempts,
// because a successful body contains the tokens themselves.
//
// It returns an error when a request cannot be sent or read, when both
// attempts are rejected, when the response cannot be decoded or carries no
// access token, or when storing the tokens fails.
func (o *OAuthManager) exchangeCode(code, verifier, redirectURI string) error {
	// A dedicated client with a timeout keeps a stalled token endpoint from
	// hanging the callback handler. The zero Transport still uses the shared
	// default transport.
	client := &http.Client{Timeout: 30 * time.Second}

	// snippet caps how much of an error body ends up in the log.
	snippet := func(b []byte) string {
		if len(b) > 500 {
			b = b[:500]
		}
		return string(b)
	}

	// Try JSON format first (what Claude Code appears to use), fall back to form-encoded.
	payload := map[string]string{
		"grant_type":    "authorization_code",
		"code":          code,
		"redirect_uri":  redirectURI,
		"client_id":     o.clientID,
		"code_verifier": verifier,
	}
	jsonBody, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal token request: %w", err)
	}

	slog.Info("exchanging auth code", "tokenURL", o.tokenURL, "redirectURI", redirectURI)

	// Pass the *bytes.Reader directly: http.NewRequest only fills in
	// ContentLength (and GetBody) for body types it recognises, so wrapping
	// it in io.NopCloser would send the request with an unknown length.
	req, err := http.NewRequest(http.MethodPost, o.tokenURL, bytes.NewReader(jsonBody))
	if err != nil {
		return fmt.Errorf("create token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("post token request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("read token response: %w", err)
	}

	if resp.StatusCode == http.StatusOK {
		// Success bodies hold the tokens; log the status only.
		slog.Info("token endpoint response", "status", resp.StatusCode)
	} else {
		slog.Info("token endpoint response", "status", resp.StatusCode, "body", snippet(body))

		// If JSON failed, try form-encoded.
		slog.Info("JSON token exchange failed, trying form-encoded")
		data := url.Values{
			"grant_type":    {"authorization_code"},
			"code":          {code},
			"redirect_uri":  {redirectURI},
			"client_id":     {o.clientID},
			"code_verifier": {verifier},
		}
		resp2, err := client.PostForm(o.tokenURL, data)
		if err != nil {
			return fmt.Errorf("post form token request: %w", err)
		}
		defer resp2.Body.Close()

		body, err = io.ReadAll(resp2.Body)
		if err != nil {
			return fmt.Errorf("read form token response: %w", err)
		}
		if resp2.StatusCode != http.StatusOK {
			slog.Info("form-encoded token response", "status", resp2.StatusCode, "body", snippet(body))
			return fmt.Errorf("token endpoint returned %d (json) and %d (form): %s", resp.StatusCode, resp2.StatusCode, string(body))
		}
		slog.Info("form-encoded token response", "status", resp2.StatusCode)
	}

	var tokens tokenResponse
	if err := json.Unmarshal(body, &tokens); err != nil {
		return fmt.Errorf("unmarshal token response: %w", err)
	}
	// A 200 without an access token must not be stored as a valid login.
	if tokens.AccessToken == "" {
		return fmt.Errorf("token response missing access_token")
	}
	return o.storeTokens(tokens)
}
// min returns the smaller of a and b; when they are equal it returns that
// shared value. This package-level definition takes precedence over the
// built-in min inside this package.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
// storeTokens encrypts the access and refresh tokens with the vault and writes
// them, together with the absolute expiry time (Unix seconds), to settings.
//
// The manager's lock is held for the whole operation. Values are written one
// at a time in the order access token, refresh token, expiry; the first
// failure is returned immediately and earlier writes are not undone. It
// returns an error if no vault has been set.
func (o *OAuthManager) storeTokens(tokens tokenResponse) error {
	o.mu.Lock()
	defer o.mu.Unlock()

	if o.vault == nil {
		return fmt.Errorf("vault is not unlocked — cannot store tokens")
	}

	// Access token: encrypt, then persist.
	sealedAccess, err := o.vault.Encrypt(tokens.AccessToken)
	if err != nil {
		return fmt.Errorf("encrypt access token: %w", err)
	}
	err = o.settings.Set("ai_access_token", sealedAccess)
	if err != nil {
		return fmt.Errorf("store access token: %w", err)
	}

	// Refresh token: encrypt, then persist.
	sealedRefresh, err := o.vault.Encrypt(tokens.RefreshToken)
	if err != nil {
		return fmt.Errorf("encrypt refresh token: %w", err)
	}
	err = o.settings.Set("ai_refresh_token", sealedRefresh)
	if err != nil {
		return fmt.Errorf("store refresh token: %w", err)
	}

	// Convert the relative lifetime into an absolute deadline.
	lifetime := time.Duration(tokens.ExpiresIn) * time.Second
	deadline := time.Now().Add(lifetime).Unix()
	err = o.settings.Set("ai_token_expires_at", strconv.FormatInt(deadline, 10))
	if err != nil {
		return fmt.Errorf("store token expiry: %w", err)
	}

	slog.Info("oauth tokens stored successfully")
	return nil
}
// GetAccessToken returns the decrypted access token, refreshing it first when
// the stored expiry time has been reached.
//
// The expiry lookup is best effort: if no expiry is stored the check is
// skipped, and a parse error is ignored, so a malformed value typically reads
// as 0 and triggers a refresh. It returns an error if no vault has been set,
// the refresh fails, no access token is stored, or decryption fails.
func (o *OAuthManager) GetAccessToken() (string, error) {
	// Snapshot the vault pointer so the lock is not held across I/O.
	o.mu.RLock()
	v := o.vault
	o.mu.RUnlock()
	if v == nil {
		return "", fmt.Errorf("vault is not unlocked")
	}

	// Refresh when the recorded deadline has passed.
	if rawExpiry, _ := o.settings.Get("ai_token_expires_at"); rawExpiry != "" {
		deadline, _ := strconv.ParseInt(rawExpiry, 10, 64)
		if now := time.Now().Unix(); now >= deadline {
			if err := o.refreshToken(); err != nil {
				return "", fmt.Errorf("refresh expired token: %w", err)
			}
		}
	}

	sealed, err := o.settings.Get("ai_access_token")
	if err != nil || sealed == "" {
		return "", fmt.Errorf("no access token stored")
	}

	plain, err := v.Decrypt(sealed)
	if err != nil {
		return "", fmt.Errorf("decrypt access token: %w", err)
	}
	return plain, nil
}
// refreshToken exchanges the stored refresh token for a new token set and
// persists it via storeTokens.
//
// The request is a form-encoded refresh_token grant sent with a bounded
// timeout. If the response omits a refresh token, the current one is kept.
// It returns an error if no vault has been set, no refresh token is stored or
// it cannot be decrypted, the request fails, the endpoint answers with a
// non-200 status, the response cannot be decoded or carries no access token,
// or storing the tokens fails.
func (o *OAuthManager) refreshToken() error {
	// Snapshot the vault pointer so the lock is not held across network I/O.
	o.mu.RLock()
	vlt := o.vault
	o.mu.RUnlock()
	if vlt == nil {
		return fmt.Errorf("vault is not unlocked")
	}

	encRefresh, err := o.settings.Get("ai_refresh_token")
	if err != nil || encRefresh == "" {
		return fmt.Errorf("no refresh token stored")
	}
	refreshTok, err := vlt.Decrypt(encRefresh)
	if err != nil {
		return fmt.Errorf("decrypt refresh token: %w", err)
	}

	data := url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshTok},
		"client_id":     {o.clientID},
	}

	// http.PostForm uses the default client, which never times out; a stalled
	// endpoint would block every GetAccessToken caller indefinitely. The zero
	// Transport still uses the shared default transport.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.PostForm(o.tokenURL, data)
	if err != nil {
		return fmt.Errorf("post refresh request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("read refresh response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body))
	}

	var tokens tokenResponse
	if err := json.Unmarshal(body, &tokens); err != nil {
		return fmt.Errorf("unmarshal refresh response: %w", err)
	}
	// Never overwrite a stored token with an empty one.
	if tokens.AccessToken == "" {
		return fmt.Errorf("token refresh response missing access_token")
	}
	// If the refresh endpoint doesn't return a new refresh token, keep the old one.
	if tokens.RefreshToken == "" {
		tokens.RefreshToken = refreshTok
	}
	return o.storeTokens(tokens)
}
// IsAuthenticated reports whether a non-empty access token entry exists in
// settings. It does not decrypt the token or check whether it has expired.
func (o *OAuthManager) IsAuthenticated() bool {
	stored, err := o.settings.Get("ai_access_token")
	if err != nil {
		return false
	}
	return stored != ""
}
// Logout removes the stored access token, refresh token and expiry from
// settings, in that order. It stops at the first deletion that fails and
// returns that error wrapped with the key name.
func (o *OAuthManager) Logout() error {
	tokenKeys := [...]string{"ai_access_token", "ai_refresh_token", "ai_token_expires_at"}
	for i := 0; i < len(tokenKeys); i++ {
		name := tokenKeys[i]
		err := o.settings.Delete(name)
		if err != nil {
			return fmt.Errorf("delete %s: %w", name, err)
		}
	}
	return nil
}
// generateCodeVerifier returns a PKCE code verifier: 32 bytes from the
// cryptographic random source, encoded as unpadded base64url (43 characters).
// It returns an error if the random source cannot be read.
func generateCodeVerifier() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(entropy[:])
	return encoded, nil
}
// generateCodeChallenge derives the PKCE S256 challenge for verifier: the
// SHA-256 digest of the verifier's bytes, encoded as unpadded base64url.
func generateCodeChallenge(verifier string) string {
	hasher := sha256.New()
	hasher.Write([]byte(verifier))
	digest := hasher.Sum(nil)
	return base64.RawURLEncoding.EncodeToString(digest)
}
// generateState returns a random OAuth state value: 32 bytes from the
// cryptographic random source, encoded as unpadded base64url (43 characters).
// It returns an error if the random source cannot be read.
func generateState() (string, error) {
	nonce := make([]byte, 32)
	_, err := rand.Read(nonce)
	if err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(nonce), nil
}
// SetVault replaces the vault used for encrypting and decrypting tokens. It
// is intended to be called once the vault has been unlocked; the swap is done
// under the manager's write lock.
func (o *OAuthManager) SetVault(v *vault.VaultService) {
	o.mu.Lock()
	o.vault = v
	o.mu.Unlock()
}
// isBase64URL reports whether s consists solely of characters from the
// base64url alphabet (A-Z, a-z, 0-9, '-' and '_'). The padding character '='
// is rejected, and the empty string is accepted.
func isBase64URL(s string) bool {
	for _, r := range s {
		switch {
		case 'A' <= r && r <= 'Z':
		case 'a' <= r && r <= 'z':
		case '0' <= r && r <= '9':
		case r == '-', r == '_':
		default:
			return false
		}
	}
	return true
}