All checks were successful
Build & Sign Wraith / Build Windows + Sign (push) Successful in 1m2s
The callback page now shows the real error message instead of a generic "Failed to exchange" message. Token exchange tries JSON Content-Type first (matching Claude Code's pattern) with form-encoded fallback. Full response body logged for debugging. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
420 lines
12 KiB
Go
420 lines
12 KiB
Go
package ai
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"sync"
	"time"

	"github.com/vstockwell/wraith/internal/settings"
	"github.com/vstockwell/wraith/internal/vault"
)
|
|
|
|
// Default OAuth client registration and endpoints used by NewOAuthManager.
const (
	// oauthClientID is sent as client_id on the authorize URL and on every
	// token-endpoint request.
	oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"

	// oauthAuthorizeURL is the page the user's browser is sent to in order
	// to approve access.
	oauthAuthorizeURL = "https://claude.ai/oauth/authorize"

	// oauthTokenURL is the endpoint used both for the authorization-code
	// exchange and for refresh-token grants.
	oauthTokenURL = "https://platform.claude.com/v1/oauth/token"
)
|
|
|
|
// OAuthManager handles OAuth PKCE authentication for Claude Max subscriptions.
|
|
type OAuthManager struct {
|
|
settings *settings.SettingsService
|
|
vault *vault.VaultService
|
|
clientID string
|
|
authorizeURL string
|
|
tokenURL string
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// tokenResponse mirrors the JSON object returned by the token endpoint, for
// both the authorization-code grant and the refresh-token grant.
type tokenResponse struct {
	// AccessToken is the credential later returned by GetAccessToken.
	AccessToken string `json:"access_token"`
	// RefreshToken may be absent from refresh responses; refreshToken then
	// keeps the previous one.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is the token lifetime in seconds, counted from receipt.
	ExpiresIn int `json:"expires_in"`
	// TokenType is decoded but not read by the code shown in this file.
	TokenType string `json:"token_type"`
}
|
|
|
|
// NewOAuthManager creates an OAuthManager wired to settings and vault services.
|
|
func NewOAuthManager(s *settings.SettingsService, v *vault.VaultService) *OAuthManager {
|
|
return &OAuthManager{
|
|
settings: s,
|
|
vault: v,
|
|
clientID: oauthClientID,
|
|
authorizeURL: oauthAuthorizeURL,
|
|
tokenURL: oauthTokenURL,
|
|
}
|
|
}
|
|
|
|
// StartLogin begins the OAuth PKCE flow. It starts a local HTTP server,
|
|
// opens the authorization URL in the user's browser, and returns a channel
|
|
// that receives nil on success or an error on failure.
|
|
func (o *OAuthManager) StartLogin(openURL func(string) error) (<-chan error, error) {
|
|
verifier, err := generateCodeVerifier()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate code verifier: %w", err)
|
|
}
|
|
challenge := generateCodeChallenge(verifier)
|
|
state, err := generateState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate state: %w", err)
|
|
}
|
|
|
|
// Start a local HTTP server on a random port for the callback
|
|
listener, err := net.Listen("tcp", "localhost:0")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("listen for callback: %w", err)
|
|
}
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
redirectURI := fmt.Sprintf("http://localhost:%d/callback", port)
|
|
|
|
done := make(chan error, 1)
|
|
|
|
mux := http.NewServeMux()
|
|
server := &http.Server{Handler: mux}
|
|
|
|
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
|
defer func() {
|
|
go func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
server.Shutdown(ctx)
|
|
}()
|
|
}()
|
|
|
|
// Verify state
|
|
if r.URL.Query().Get("state") != state {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprint(w, "State mismatch — please try logging in again.")
|
|
done <- fmt.Errorf("state mismatch")
|
|
return
|
|
}
|
|
|
|
if errParam := r.URL.Query().Get("error"); errParam != "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "Authorization error: %s", errParam)
|
|
done <- fmt.Errorf("authorization error: %s", errParam)
|
|
return
|
|
}
|
|
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprint(w, "No authorization code received.")
|
|
done <- fmt.Errorf("no authorization code")
|
|
return
|
|
}
|
|
|
|
// Exchange code for tokens
|
|
if err := o.exchangeCode(code, verifier, redirectURI); err != nil {
|
|
slog.Error("oauth token exchange failed", "error", err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
fmt.Fprintf(w, "<html><body><h2>Authentication Failed</h2><pre>%s</pre><p>Check Wraith logs for details.</p></body></html>", err.Error())
|
|
done <- fmt.Errorf("exchange code: %w", err)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/html")
|
|
fmt.Fprint(w, `<html><body><h2>Authenticated!</h2><p>You can close this window and return to Wraith.</p></body></html>`)
|
|
done <- nil
|
|
})
|
|
|
|
go func() {
|
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
|
slog.Error("oauth callback server error", "error", err)
|
|
done <- fmt.Errorf("callback server: %w", err)
|
|
}
|
|
}()
|
|
|
|
// Build the authorize URL
|
|
params := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {o.clientID},
|
|
"redirect_uri": {redirectURI},
|
|
"scope": {"user:inference"},
|
|
"state": {state},
|
|
"code_challenge": {challenge},
|
|
"code_challenge_method": {"S256"},
|
|
}
|
|
authURL := o.authorizeURL + "?" + params.Encode()
|
|
|
|
if openURL != nil {
|
|
if err := openURL(authURL); err != nil {
|
|
listener.Close()
|
|
return nil, fmt.Errorf("open browser: %w", err)
|
|
}
|
|
}
|
|
|
|
return done, nil
|
|
}
|
|
|
|
// exchangeCode exchanges an authorization code for access and refresh tokens.
|
|
func (o *OAuthManager) exchangeCode(code, verifier, redirectURI string) error {
|
|
// Try JSON format first (what Claude Code appears to use), fall back to form-encoded
|
|
payload := map[string]string{
|
|
"grant_type": "authorization_code",
|
|
"code": code,
|
|
"redirect_uri": redirectURI,
|
|
"client_id": o.clientID,
|
|
"code_verifier": verifier,
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal token request: %w", err)
|
|
}
|
|
|
|
slog.Info("exchanging auth code", "tokenURL", o.tokenURL, "redirectURI", redirectURI)
|
|
|
|
req, err := http.NewRequest("POST", o.tokenURL, io.NopCloser(
|
|
bytes.NewReader(jsonBody),
|
|
))
|
|
if err != nil {
|
|
return fmt.Errorf("create token request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("post token request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("read token response: %w", err)
|
|
}
|
|
|
|
slog.Info("token endpoint response", "status", resp.StatusCode, "body", string(body)[:min(len(body), 500)])
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
// If JSON failed, try form-encoded
|
|
slog.Info("JSON token exchange failed, trying form-encoded")
|
|
data := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {code},
|
|
"redirect_uri": {redirectURI},
|
|
"client_id": {o.clientID},
|
|
"code_verifier": {verifier},
|
|
}
|
|
resp2, err := http.PostForm(o.tokenURL, data)
|
|
if err != nil {
|
|
return fmt.Errorf("post form token request: %w", err)
|
|
}
|
|
defer resp2.Body.Close()
|
|
body, err = io.ReadAll(resp2.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("read form token response: %w", err)
|
|
}
|
|
slog.Info("form-encoded token response", "status", resp2.StatusCode, "body", string(body)[:min(len(body), 500)])
|
|
if resp2.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("token endpoint returned %d (json) and %d (form): %s", resp.StatusCode, resp2.StatusCode, string(body))
|
|
}
|
|
}
|
|
|
|
var tokens tokenResponse
|
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
|
return fmt.Errorf("unmarshal token response: %w", err)
|
|
}
|
|
|
|
return o.storeTokens(tokens)
|
|
}
|
|
|
|
// min returns the smaller of a and b.
//
// NOTE(review): this package-level helper shadows the Go 1.21 builtin min
// throughout package ai. It can be deleted once the module's go directive is
// confirmed to be >= 1.21.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|
|
|
// storeTokens encrypts and persists OAuth tokens in settings.
|
|
func (o *OAuthManager) storeTokens(tokens tokenResponse) error {
|
|
o.mu.Lock()
|
|
defer o.mu.Unlock()
|
|
|
|
if o.vault == nil {
|
|
return fmt.Errorf("vault is not unlocked — cannot store tokens")
|
|
}
|
|
|
|
encAccess, err := o.vault.Encrypt(tokens.AccessToken)
|
|
if err != nil {
|
|
return fmt.Errorf("encrypt access token: %w", err)
|
|
}
|
|
if err := o.settings.Set("ai_access_token", encAccess); err != nil {
|
|
return fmt.Errorf("store access token: %w", err)
|
|
}
|
|
|
|
encRefresh, err := o.vault.Encrypt(tokens.RefreshToken)
|
|
if err != nil {
|
|
return fmt.Errorf("encrypt refresh token: %w", err)
|
|
}
|
|
if err := o.settings.Set("ai_refresh_token", encRefresh); err != nil {
|
|
return fmt.Errorf("store refresh token: %w", err)
|
|
}
|
|
|
|
expiresAt := time.Now().Add(time.Duration(tokens.ExpiresIn) * time.Second).Unix()
|
|
if err := o.settings.Set("ai_token_expires_at", strconv.FormatInt(expiresAt, 10)); err != nil {
|
|
return fmt.Errorf("store token expiry: %w", err)
|
|
}
|
|
|
|
slog.Info("oauth tokens stored successfully")
|
|
return nil
|
|
}
|
|
|
|
// GetAccessToken returns a valid access token, refreshing if expired.
|
|
func (o *OAuthManager) GetAccessToken() (string, error) {
|
|
o.mu.RLock()
|
|
vlt := o.vault
|
|
o.mu.RUnlock()
|
|
|
|
if vlt == nil {
|
|
return "", fmt.Errorf("vault is not unlocked")
|
|
}
|
|
|
|
// Check expiry
|
|
expiresStr, _ := o.settings.Get("ai_token_expires_at")
|
|
if expiresStr != "" {
|
|
expiresAt, _ := strconv.ParseInt(expiresStr, 10, 64)
|
|
if time.Now().Unix() >= expiresAt {
|
|
// Token expired, try to refresh
|
|
if err := o.refreshToken(); err != nil {
|
|
return "", fmt.Errorf("refresh expired token: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
encAccess, err := o.settings.Get("ai_access_token")
|
|
if err != nil || encAccess == "" {
|
|
return "", fmt.Errorf("no access token stored")
|
|
}
|
|
|
|
accessToken, err := vlt.Decrypt(encAccess)
|
|
if err != nil {
|
|
return "", fmt.Errorf("decrypt access token: %w", err)
|
|
}
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
// refreshToken uses the stored refresh token to obtain new tokens.
|
|
func (o *OAuthManager) refreshToken() error {
|
|
o.mu.RLock()
|
|
vlt := o.vault
|
|
o.mu.RUnlock()
|
|
|
|
if vlt == nil {
|
|
return fmt.Errorf("vault is not unlocked")
|
|
}
|
|
|
|
encRefresh, err := o.settings.Get("ai_refresh_token")
|
|
if err != nil || encRefresh == "" {
|
|
return fmt.Errorf("no refresh token stored")
|
|
}
|
|
|
|
refreshTok, err := vlt.Decrypt(encRefresh)
|
|
if err != nil {
|
|
return fmt.Errorf("decrypt refresh token: %w", err)
|
|
}
|
|
|
|
data := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {refreshTok},
|
|
"client_id": {o.clientID},
|
|
}
|
|
|
|
resp, err := http.PostForm(o.tokenURL, data)
|
|
if err != nil {
|
|
return fmt.Errorf("post refresh request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("read refresh response: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var tokens tokenResponse
|
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
|
return fmt.Errorf("unmarshal refresh response: %w", err)
|
|
}
|
|
|
|
// If the refresh endpoint doesn't return a new refresh token, keep the old one
|
|
if tokens.RefreshToken == "" {
|
|
tokens.RefreshToken = refreshTok
|
|
}
|
|
|
|
return o.storeTokens(tokens)
|
|
}
|
|
|
|
// IsAuthenticated checks if we have stored tokens.
|
|
func (o *OAuthManager) IsAuthenticated() bool {
|
|
encAccess, err := o.settings.Get("ai_access_token")
|
|
return err == nil && encAccess != ""
|
|
}
|
|
|
|
// Logout clears all stored OAuth tokens.
|
|
func (o *OAuthManager) Logout() error {
|
|
for _, key := range []string{"ai_access_token", "ai_refresh_token", "ai_token_expires_at"} {
|
|
if err := o.settings.Delete(key); err != nil {
|
|
return fmt.Errorf("delete %s: %w", key, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// generateCodeVerifier returns a fresh PKCE code verifier. It consists of 32
// bytes from crypto/rand, base64url-encoded without padding (43 characters).
func generateCodeVerifier() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
// generateCodeChallenge creates a PKCE S256 code challenge from a verifier.
|
|
func generateCodeChallenge(verifier string) string {
|
|
h := sha256.Sum256([]byte(verifier))
|
|
return base64.RawURLEncoding.EncodeToString(h[:])
|
|
}
|
|
|
|
// generateState returns a random OAuth state value. It consists of 32 bytes
// from crypto/rand, base64url-encoded without padding (43 characters).
func generateState() (string, error) {
	const size = 32
	raw := make([]byte, size)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
|
|
|
|
// SetVault updates the vault reference (called after vault unlock).
|
|
func (o *OAuthManager) SetVault(v *vault.VaultService) {
|
|
o.mu.Lock()
|
|
defer o.mu.Unlock()
|
|
o.vault = v
|
|
}
|
|
|
|
// isBase64URL reports whether s consists solely of the unpadded base64url
// alphabet (A-Z, a-z, 0-9, '-', '_'). The empty string reports true.
func isBase64URL(s string) bool {
	for i := 0; i < len(s); i++ {
		// Any byte >= 0x80 (i.e. any non-ASCII rune) falls through to default.
		switch c := s[i]; {
		case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9', c == '-', c == '_':
		default:
			return false
		}
	}
	return true
}
|
|
|