211 lines
5.2 KiB
Go
211 lines
5.2 KiB
Go
package serve
|
|
|
|
import (
|
|
"axolotl/service"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// pendingLogin is the server-side record of one authorization-code login
// that /auth/start has begun and that has not yet been consumed by a
// successful /auth/poll or expired by cleanup.
type pendingLogin struct {
	// verifier is the PKCE code_verifier sent with the token exchange;
	// state is the OAuth2 state value that the callback matches on.
	verifier, state string

	// created records when the login began; cleanup expires old entries by it.
	created time.Time

	// serverToken is empty while the browser login is in progress and is
	// filled in by the callback with the newly minted session token.
	serverToken string
}
|
|
|
|
// authHandler owns the OIDC provider connection, the pending login store,
|
|
// and the active server-side session map.
|
|
type authHandler struct {
|
|
mu sync.Mutex
|
|
pending map[string]*pendingLogin // loginID → pending state
|
|
sessions map[string]string // serverToken → username
|
|
|
|
cfg service.OIDCConfig
|
|
provider *oidc.Provider
|
|
oauth2 oauth2.Config
|
|
}
|
|
|
|
func newAuthHandler(cfg service.OIDCConfig) (*authHandler, error) {
|
|
if cfg.PublicURL == "" {
|
|
return nil, fmt.Errorf("oidc.public_url must be set to the externally reachable base URL of this server")
|
|
}
|
|
provider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("OIDC provider: %w", err)
|
|
}
|
|
h := &authHandler{
|
|
pending: make(map[string]*pendingLogin),
|
|
sessions: make(map[string]string),
|
|
cfg: cfg,
|
|
provider: provider,
|
|
oauth2: oauth2.Config{
|
|
ClientID: cfg.ClientID,
|
|
ClientSecret: cfg.ClientSecret,
|
|
Endpoint: provider.Endpoint(),
|
|
RedirectURL: cfg.PublicURL + "/auth/callback",
|
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"},
|
|
},
|
|
}
|
|
go h.cleanup()
|
|
return h, nil
|
|
}
|
|
|
|
func (h *authHandler) cleanup() {
|
|
for range time.Tick(5 * time.Minute) {
|
|
h.mu.Lock()
|
|
for id, p := range h.pending {
|
|
if time.Since(p.created) > 15*time.Minute {
|
|
delete(h.pending, id)
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// lookupSession returns the username for a server-issued token, or "".
|
|
func (h *authHandler) lookupSession(token string) string {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
return h.sessions[token]
|
|
}
|
|
|
|
// POST /auth/start → {url, session_id}
|
|
func (h *authHandler) start(w http.ResponseWriter, r *http.Request) {
|
|
loginID := randomToken(16)
|
|
verifier := randomToken(32)
|
|
state := randomToken(16)
|
|
|
|
authURL := h.oauth2.AuthCodeURL(state,
|
|
oauth2.SetAuthURLParam("code_challenge", pkceChallenge(verifier)),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
)
|
|
|
|
h.mu.Lock()
|
|
h.pending[loginID] = &pendingLogin{
|
|
verifier: verifier,
|
|
state: state,
|
|
created: time.Now(),
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
writeJSON(w, map[string]string{"url": authURL, "session_id": loginID})
|
|
}
|
|
|
|
// GET /auth/callback — OIDC provider redirects here after user authenticates.
|
|
func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) {
|
|
stateParam := r.URL.Query().Get("state")
|
|
code := r.URL.Query().Get("code")
|
|
|
|
h.mu.Lock()
|
|
var loginID string
|
|
var pending *pendingLogin
|
|
for id, p := range h.pending {
|
|
if p.state == stateParam {
|
|
loginID, pending = id, p
|
|
break
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
if pending == nil {
|
|
http.Error(w, "invalid or expired state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
token, err := h.oauth2.Exchange(r.Context(), code,
|
|
oauth2.SetAuthURLParam("code_verifier", pending.verifier),
|
|
)
|
|
if err != nil {
|
|
http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
username, err := h.extractUsername(r.Context(), token)
|
|
if err != nil {
|
|
http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
serverToken := randomToken(32)
|
|
|
|
h.mu.Lock()
|
|
h.sessions[serverToken] = username
|
|
if p := h.pending[loginID]; p != nil {
|
|
p.serverToken = serverToken
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
fmt.Fprintln(w, "Login successful! You can close this tab.")
|
|
}
|
|
|
|
// GET /auth/poll?session_id=...
|
|
// Returns 202 while pending, 200 {token, username} when done, 404 if expired.
|
|
func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) {
|
|
loginID := r.URL.Query().Get("session_id")
|
|
|
|
h.mu.Lock()
|
|
p := h.pending[loginID]
|
|
h.mu.Unlock()
|
|
|
|
if p == nil {
|
|
writeError(w, http.StatusNotFound, "session not found or expired")
|
|
return
|
|
}
|
|
|
|
h.mu.Lock()
|
|
serverToken := p.serverToken
|
|
if serverToken != "" {
|
|
delete(h.pending, loginID) // consume once delivered
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
if serverToken == "" {
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
username := h.lookupSession(serverToken)
|
|
writeJSON(w, map[string]string{"token": serverToken, "username": username})
|
|
}
|
|
|
|
func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) {
|
|
rawID, ok := token.Extra("id_token").(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("no id_token in response")
|
|
}
|
|
idToken, err := h.provider.Verifier(&oidc.Config{ClientID: h.cfg.ClientID}).Verify(ctx, rawID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
var claims map[string]any
|
|
if err := idToken.Claims(&claims); err != nil {
|
|
return "", err
|
|
}
|
|
user, _ := claims[h.cfg.UserClaim].(string)
|
|
if user == "" {
|
|
return "", fmt.Errorf("claim %q not found in token", h.cfg.UserClaim)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
// randomToken returns n bytes read from crypto/rand, encoded as unpadded
// base64url. The result is URL-safe and has length ceil(4n/3).
//
// It panics if the system random source fails. Every caller in this file
// uses the result as a secret: login IDs, PKCE verifiers, OAuth2 state and
// session tokens. Continuing after a failed read could hand out predictable
// (possibly all-zero) values.
func randomToken(n int) string {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("randomToken: reading crypto/rand: %v", err))
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
|
|
|
|
// pkceChallenge derives the RFC 7636 "S256" code_challenge for a PKCE
// code_verifier: BASE64URL(SHA256(verifier)) with no padding.
func pkceChallenge(verifier string) string {
	hasher := sha256.New()
	hasher.Write([]byte(verifier)) // hash.Hash.Write never returns an error
	digest := hasher.Sum(nil)
	return base64.RawURLEncoding.EncodeToString(digest)
}
|