package serve import ( "axolotl/service" "context" "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" "net/http" "sync" "time" "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" ) // pendingLogin tracks an in-progress authorization code flow. type pendingLogin struct { verifier string state string created time.Time serverToken string // set by callback when complete; empty while pending } // authHandler owns the OIDC provider connection, the pending login store, // and the active server-side session map. type authHandler struct { mu sync.Mutex pending map[string]*pendingLogin // loginID → pending state sessions map[string]string // serverToken → username cfg service.OIDCConfig provider *oidc.Provider oauth2 oauth2.Config } func newAuthHandler(cfg service.OIDCConfig) (*authHandler, error) { if cfg.PublicURL == "" { return nil, fmt.Errorf("oidc.public_url must be set to the externally reachable base URL of this server") } provider, err := oidc.NewProvider(context.Background(), cfg.Issuer) if err != nil { return nil, fmt.Errorf("OIDC provider: %w", err) } h := &authHandler{ pending: make(map[string]*pendingLogin), sessions: make(map[string]string), cfg: cfg, provider: provider, oauth2: oauth2.Config{ ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, Endpoint: provider.Endpoint(), RedirectURL: cfg.PublicURL + "/auth/callback", Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}, }, } go h.cleanup() return h, nil } func (h *authHandler) cleanup() { for range time.Tick(5 * time.Minute) { h.mu.Lock() for id, p := range h.pending { if time.Since(p.created) > 15*time.Minute { delete(h.pending, id) } } h.mu.Unlock() } } // lookupSession returns the username for a server-issued token, or "". func (h *authHandler) lookupSession(token string) string { h.mu.Lock() defer h.mu.Unlock() return h.sessions[token] } // POST /auth/start → {url, session_id} func (h *authHandler) start(w http.ResponseWriter, r *http.Request) { loginID := randomToken(16) verifier := randomToken(32) state := randomToken(16) authURL := h.oauth2.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge", pkceChallenge(verifier)), oauth2.SetAuthURLParam("code_challenge_method", "S256"), ) h.mu.Lock() h.pending[loginID] = &pendingLogin{ verifier: verifier, state: state, created: time.Now(), } h.mu.Unlock() writeJSON(w, map[string]string{"url": authURL, "session_id": loginID}) } // GET /auth/callback — OIDC provider redirects here after user authenticates. func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) { stateParam := r.URL.Query().Get("state") code := r.URL.Query().Get("code") h.mu.Lock() var loginID string var pending *pendingLogin for id, p := range h.pending { if p.state == stateParam { loginID, pending = id, p break } } h.mu.Unlock() if pending == nil { http.Error(w, "invalid or expired state", http.StatusBadRequest) return } token, err := h.oauth2.Exchange(r.Context(), code, oauth2.SetAuthURLParam("code_verifier", pending.verifier), ) if err != nil { http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest) return } username, err := h.extractUsername(r.Context(), token) if err != nil { http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError) return } serverToken := randomToken(32) h.mu.Lock() h.sessions[serverToken] = username if p := h.pending[loginID]; p != nil { p.serverToken = serverToken } h.mu.Unlock() fmt.Fprintln(w, "Login successful! You can close this tab.") } // GET /auth/poll?session_id=... // Returns 202 while pending, 200 {token, username} when done, 404 if expired. func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) { loginID := r.URL.Query().Get("session_id") h.mu.Lock() p := h.pending[loginID] h.mu.Unlock() if p == nil { writeError(w, http.StatusNotFound, "session not found or expired") return } h.mu.Lock() serverToken := p.serverToken if serverToken != "" { delete(h.pending, loginID) // consume once delivered } h.mu.Unlock() if serverToken == "" { w.WriteHeader(http.StatusAccepted) return } username := h.lookupSession(serverToken) writeJSON(w, map[string]string{"token": serverToken, "username": username}) } func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) { rawID, ok := token.Extra("id_token").(string) if !ok { return "", fmt.Errorf("no id_token in response") } idToken, err := h.provider.Verifier(&oidc.Config{ClientID: h.cfg.ClientID}).Verify(ctx, rawID) if err != nil { return "", err } var claims map[string]any if err := idToken.Claims(&claims); err != nil { return "", err } user, _ := claims[h.cfg.UserClaim].(string) if user == "" { return "", fmt.Errorf("claim %q not found in token", h.cfg.UserClaim) } return user, nil } func randomToken(n int) string { b := make([]byte, n) rand.Read(b) //nolint:errcheck return base64.RawURLEncoding.EncodeToString(b) } func pkceChallenge(verifier string) string { h := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(h[:]) }