feat: add OIDC authentication for server mode

This commit is contained in:
2026-04-01 19:33:15 +02:00
parent 7bce56384f
commit 52a975b66d
13 changed files with 515 additions and 7 deletions

210
serve/auth.go Normal file
View File

@@ -0,0 +1,210 @@
package serve
import (
"axolotl/service"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)
// pendingLogin tracks an in-progress authorization code flow.
type pendingLogin struct {
verifier string
state string
created time.Time
serverToken string // set by callback when complete; empty while pending
}
// authHandler owns the OIDC provider connection, the pending login store,
// and the active server-side session map.
type authHandler struct {
mu sync.Mutex
pending map[string]*pendingLogin // loginID → pending state
sessions map[string]string // serverToken → username
cfg service.OIDCConfig
provider *oidc.Provider
oauth2 oauth2.Config
}
func newAuthHandler(cfg service.OIDCConfig) (*authHandler, error) {
if cfg.PublicURL == "" {
return nil, fmt.Errorf("oidc.public_url must be set to the externally reachable base URL of this server")
}
provider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
if err != nil {
return nil, fmt.Errorf("OIDC provider: %w", err)
}
h := &authHandler{
pending: make(map[string]*pendingLogin),
sessions: make(map[string]string),
cfg: cfg,
provider: provider,
oauth2: oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: provider.Endpoint(),
RedirectURL: cfg.PublicURL + "/auth/callback",
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"},
},
}
go h.cleanup()
return h, nil
}
func (h *authHandler) cleanup() {
for range time.Tick(5 * time.Minute) {
h.mu.Lock()
for id, p := range h.pending {
if time.Since(p.created) > 15*time.Minute {
delete(h.pending, id)
}
}
h.mu.Unlock()
}
}
// lookupSession returns the username for a server-issued token, or "".
func (h *authHandler) lookupSession(token string) string {
h.mu.Lock()
defer h.mu.Unlock()
return h.sessions[token]
}
// POST /auth/start → {url, session_id}
func (h *authHandler) start(w http.ResponseWriter, r *http.Request) {
loginID := randomToken(16)
verifier := randomToken(32)
state := randomToken(16)
authURL := h.oauth2.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", pkceChallenge(verifier)),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
h.mu.Lock()
h.pending[loginID] = &pendingLogin{
verifier: verifier,
state: state,
created: time.Now(),
}
h.mu.Unlock()
writeJSON(w, map[string]string{"url": authURL, "session_id": loginID})
}
// GET /auth/callback — OIDC provider redirects here after user authenticates.
func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) {
stateParam := r.URL.Query().Get("state")
code := r.URL.Query().Get("code")
h.mu.Lock()
var loginID string
var pending *pendingLogin
for id, p := range h.pending {
if p.state == stateParam {
loginID, pending = id, p
break
}
}
h.mu.Unlock()
if pending == nil {
http.Error(w, "invalid or expired state", http.StatusBadRequest)
return
}
token, err := h.oauth2.Exchange(r.Context(), code,
oauth2.SetAuthURLParam("code_verifier", pending.verifier),
)
if err != nil {
http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest)
return
}
username, err := h.extractUsername(r.Context(), token)
if err != nil {
http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError)
return
}
serverToken := randomToken(32)
h.mu.Lock()
h.sessions[serverToken] = username
if p := h.pending[loginID]; p != nil {
p.serverToken = serverToken
}
h.mu.Unlock()
fmt.Fprintln(w, "Login successful! You can close this tab.")
}
// GET /auth/poll?session_id=...
// Returns 202 while pending, 200 {token, username} when done, 404 if expired.
func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) {
loginID := r.URL.Query().Get("session_id")
h.mu.Lock()
p := h.pending[loginID]
h.mu.Unlock()
if p == nil {
writeError(w, http.StatusNotFound, "session not found or expired")
return
}
h.mu.Lock()
serverToken := p.serverToken
if serverToken != "" {
delete(h.pending, loginID) // consume once delivered
}
h.mu.Unlock()
if serverToken == "" {
w.WriteHeader(http.StatusAccepted)
return
}
username := h.lookupSession(serverToken)
writeJSON(w, map[string]string{"token": serverToken, "username": username})
}
func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) {
rawID, ok := token.Extra("id_token").(string)
if !ok {
return "", fmt.Errorf("no id_token in response")
}
idToken, err := h.provider.Verifier(&oidc.Config{ClientID: h.cfg.ClientID}).Verify(ctx, rawID)
if err != nil {
return "", err
}
var claims map[string]any
if err := idToken.Claims(&claims); err != nil {
return "", err
}
user, _ := claims[h.cfg.UserClaim].(string)
if user == "" {
return "", fmt.Errorf("claim %q not found in token", h.cfg.UserClaim)
}
return user, nil
}
func randomToken(n int) string {
b := make([]byte, n)
rand.Read(b) //nolint:errcheck
return base64.RawURLEncoding.EncodeToString(b)
}
func pkceChallenge(verifier string) string {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:])
}

41
serve/oidc.go Normal file
View File

@@ -0,0 +1,41 @@
package serve
import (
"context"
"net/http"
"strings"
)
type contextKey string
const userContextKey contextKey = "ax_user"
// withSessionAuth wraps a handler with ax session token authentication.
// Auth endpoints (/auth/*) are passed through without a token check.
// All other requests must supply Authorization: Bearer <server_token>.
func withSessionAuth(ah *authHandler, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/auth/") {
next.ServeHTTP(w, r)
return
}
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
writeError(w, http.StatusUnauthorized, "Bearer token required")
return
}
token := strings.TrimPrefix(auth, "Bearer ")
username := ah.lookupSession(token)
if username == "" {
writeError(w, http.StatusUnauthorized, "invalid or expired session; run 'ax login'")
return
}
ctx := context.WithValue(r.Context(), userContextKey, username)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func userFromContext(r *http.Request) string {
v, _ := r.Context().Value(userContextKey).(string)
return v
}

View File

@@ -9,8 +9,10 @@ import (
)
// New returns an HTTP handler that exposes NodeService as a JSON API.
// Every request must supply an X-Ax-User header identifying the acting user.
func New(newSvc func(user string) (service.NodeService, error)) http.Handler {
// When oidcCfg is non-nil, every request must carry a valid Bearer token;
// the authenticated username is derived from the token claim configured in
// OIDCConfig.UserClaim. Without OIDC, the X-Ax-User header is used instead.
func New(newSvc func(user string) (service.NodeService, error), oidcCfg *service.OIDCConfig) (http.Handler, error) {
s := &server{newSvc: newSvc}
mux := http.NewServeMux()
mux.HandleFunc("GET /nodes", s.listNodes)
@@ -20,7 +22,17 @@ func New(newSvc func(user string) (service.NodeService, error)) http.Handler {
mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode)
mux.HandleFunc("GET /users", s.listUsers)
mux.HandleFunc("POST /users", s.addUser)
return mux
if oidcCfg != nil {
ah, err := newAuthHandler(*oidcCfg)
if err != nil {
return nil, err
}
mux.HandleFunc("POST /auth/start", ah.start)
mux.HandleFunc("GET /auth/callback", ah.callback)
mux.HandleFunc("GET /auth/poll", ah.poll)
return withSessionAuth(ah, mux), nil
}
return mux, nil
}
type server struct {
@@ -28,7 +40,10 @@ type server struct {
}
func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeService, bool) {
user := r.Header.Get("X-Ax-User")
user := userFromContext(r)
if user == "" {
user = r.Header.Get("X-Ax-User")
}
if user == "" {
writeError(w, http.StatusUnauthorized, "X-Ax-User header required")
return nil, false