From 52a975b66de43e105fbdde49569e55bc83c67a50 Mon Sep 17 00:00:00 2001 From: Elias Kohout Date: Wed, 1 Apr 2026 19:33:15 +0200 Subject: [PATCH] feat: add OIDC authentication for server mode --- cmd/login.go | 84 ++++++++++++++++ cmd/serve.go | 12 ++- go.mod | 3 + go.sum | 6 ++ serve/auth.go | 210 ++++++++++++++++++++++++++++++++++++++++ serve/oidc.go | 41 ++++++++ serve/server.go | 23 ++++- service/api_client.go | 16 ++- service/config.go | 12 +++ service/config_file.go | 12 +++ service/node_service.go | 2 +- service/session.go | 67 +++++++++++++ store/sqlite.go | 34 +++++++ 13 files changed, 515 insertions(+), 7 deletions(-) create mode 100644 cmd/login.go create mode 100644 serve/auth.go create mode 100644 serve/oidc.go create mode 100644 service/session.go diff --git a/cmd/login.go b/cmd/login.go new file mode 100644 index 0000000..7444da5 --- /dev/null +++ b/cmd/login.go @@ -0,0 +1,84 @@ +package cmd + +import ( + "axolotl/service" + "encoding/json" + "fmt" + "net/http" + "os" + "time" + + "github.com/spf13/cobra" +) + +var loginCmd = &cobra.Command{ + Use: "login", + Short: "Authenticate with the remote server via OIDC", + Run: func(cmd *cobra.Command, args []string) { + rc, ok := cfg.GetRemoteConfig() + if !ok { + fmt.Fprintln(os.Stderr, "no remote server configured; set remote.host in your config") + os.Exit(1) + } + base := fmt.Sprintf("http://%s:%d", rc.Host, rc.Port) + + resp, err := http.Post(base+"/auth/start", "application/json", nil) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to contact server: %v\n", err) + os.Exit(1) + } + var start struct { + URL string `json:"url"` + SessionID string `json:"session_id"` + } + json.NewDecoder(resp.Body).Decode(&start) + resp.Body.Close() + + if start.URL == "" { + fmt.Fprintln(os.Stderr, "server did not return an auth URL; is OIDC configured on the server?") + os.Exit(1) + } + + fmt.Printf("Open this URL in your browser:\n\n %s\n\nWaiting for login...\n", start.URL) + + deadline := time.Now().Add(5 * time.Minute) + for time.Now().Before(deadline) { + time.Sleep(2 * time.Second) + + resp, err := http.Get(fmt.Sprintf("%s/auth/poll?session_id=%s", base, start.SessionID)) + if err != nil { + continue + } + if resp.StatusCode == http.StatusAccepted { + resp.Body.Close() + continue + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + fmt.Fprintln(os.Stderr, "login failed") + os.Exit(1) + } + + var result struct { + Token string `json:"token"` + Username string `json:"username"` + } + json.NewDecoder(resp.Body).Decode(&result) + resp.Body.Close() + + if err := service.SaveSession(&service.Session{Token: result.Token}); err != nil { + fmt.Fprintf(os.Stderr, "failed to save session: %v\n", err) + os.Exit(1) + } + fmt.Printf("Logged in as %s\n", result.Username) + return + } + + fmt.Fprintln(os.Stderr, "login timed out") + os.Exit(1) + }, +} + +func init() { + rootCmd.AddCommand(loginCmd) +} diff --git a/cmd/serve.go b/cmd/serve.go index c539a39..04bb054 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -16,7 +16,17 @@ var serveCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { sc := cfg.GetServerConfig() addr := fmt.Sprintf("%s:%d", sc.Host, sc.Port) - handler := serve.New(service.GetNodeServiceForUser) + + var oidcCfg *service.OIDCConfig + if oc, ok := cfg.GetOIDCConfig(); ok { + oidcCfg = &oc + } + + handler, err := serve.New(service.GetNodeServiceForUser, oidcCfg) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } fmt.Fprintf(os.Stdout, "listening on %s\n", addr) if err := http.ListenAndServe(addr, handler); err != nil { fmt.Fprintln(os.Stderr, err) diff --git a/go.mod b/go.mod index 0eefd19..dd8cdfc 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,9 @@ require ( ) require ( + github.com/coreos/go-oidc/v3 v3.17.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -18,6 +20,7 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/spf13/pflag v1.0.9 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sys v0.42.0 // indirect modernc.org/gc/v3 v3.1.2 // indirect modernc.org/libc v1.70.0 // indirect diff --git a/go.sum b/go.sum index 3aa285b..ff46ffd 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,12 @@ +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -29,6 +33,8 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/serve/auth.go b/serve/auth.go new file mode 100644 index 0000000..58c7138 --- /dev/null +++ b/serve/auth.go @@ -0,0 +1,210 @@ +package serve + +import ( + "axolotl/service" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "sync" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +// pendingLogin tracks an in-progress authorization code flow. +type pendingLogin struct { + verifier string + state string + created time.Time + serverToken string // set by callback when complete; empty while pending +} + +// authHandler owns the OIDC provider connection, the pending login store, +// and the active server-side session map. +type authHandler struct { + mu sync.Mutex + pending map[string]*pendingLogin // loginID → pending state + sessions map[string]string // serverToken → username + + cfg service.OIDCConfig + provider *oidc.Provider + oauth2 oauth2.Config +} + +func newAuthHandler(cfg service.OIDCConfig) (*authHandler, error) { + if cfg.PublicURL == "" { + return nil, fmt.Errorf("oidc.public_url must be set to the externally reachable base URL of this server") + } + provider, err := oidc.NewProvider(context.Background(), cfg.Issuer) + if err != nil { + return nil, fmt.Errorf("OIDC provider: %w", err) + } + h := &authHandler{ + pending: make(map[string]*pendingLogin), + sessions: make(map[string]string), + cfg: cfg, + provider: provider, + oauth2: oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: cfg.PublicURL + "/auth/callback", + Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}, + }, + } + go h.cleanup() + return h, nil +} + +func (h *authHandler) cleanup() { + for range time.Tick(5 * time.Minute) { + h.mu.Lock() + for id, p := range h.pending { + if time.Since(p.created) > 15*time.Minute { + delete(h.pending, id) + } + } + h.mu.Unlock() + } +} + +// lookupSession returns the username for a server-issued token, or "". +func (h *authHandler) lookupSession(token string) string { + h.mu.Lock() + defer h.mu.Unlock() + return h.sessions[token] +} + +// POST /auth/start → {url, session_id} +func (h *authHandler) start(w http.ResponseWriter, r *http.Request) { + loginID := randomToken(16) + verifier := randomToken(32) + state := randomToken(16) + + authURL := h.oauth2.AuthCodeURL(state, + oauth2.SetAuthURLParam("code_challenge", pkceChallenge(verifier)), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + ) + + h.mu.Lock() + h.pending[loginID] = &pendingLogin{ + verifier: verifier, + state: state, + created: time.Now(), + } + h.mu.Unlock() + + writeJSON(w, map[string]string{"url": authURL, "session_id": loginID}) +} + +// GET /auth/callback — OIDC provider redirects here after user authenticates. +func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) { + stateParam := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + + h.mu.Lock() + var loginID string + var pending *pendingLogin + for id, p := range h.pending { + if p.state == stateParam { + loginID, pending = id, p + break + } + } + h.mu.Unlock() + + if pending == nil { + http.Error(w, "invalid or expired state", http.StatusBadRequest) + return + } + + token, err := h.oauth2.Exchange(r.Context(), code, + oauth2.SetAuthURLParam("code_verifier", pending.verifier), + ) + if err != nil { + http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest) + return + } + + username, err := h.extractUsername(r.Context(), token) + if err != nil { + http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError) + return + } + + serverToken := randomToken(32) + + h.mu.Lock() + h.sessions[serverToken] = username + if p := h.pending[loginID]; p != nil { + p.serverToken = serverToken + } + h.mu.Unlock() + + fmt.Fprintln(w, "Login successful! You can close this tab.") +} + +// GET /auth/poll?session_id=... +// Returns 202 while pending, 200 {token, username} when done, 404 if expired. +func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) { + loginID := r.URL.Query().Get("session_id") + + h.mu.Lock() + p := h.pending[loginID] + h.mu.Unlock() + + if p == nil { + writeError(w, http.StatusNotFound, "session not found or expired") + return + } + + h.mu.Lock() + serverToken := p.serverToken + if serverToken != "" { + delete(h.pending, loginID) // consume once delivered + } + h.mu.Unlock() + + if serverToken == "" { + w.WriteHeader(http.StatusAccepted) + return + } + + username := h.lookupSession(serverToken) + writeJSON(w, map[string]string{"token": serverToken, "username": username}) +} + +func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) { + rawID, ok := token.Extra("id_token").(string) + if !ok { + return "", fmt.Errorf("no id_token in response") + } + idToken, err := h.provider.Verifier(&oidc.Config{ClientID: h.cfg.ClientID}).Verify(ctx, rawID) + if err != nil { + return "", err + } + var claims map[string]any + if err := idToken.Claims(&claims); err != nil { + return "", err + } + user, _ := claims[h.cfg.UserClaim].(string) + if user == "" { + return "", fmt.Errorf("claim %q not found in token", h.cfg.UserClaim) + } + return user, nil +} + +func randomToken(n int) string { + b := make([]byte, n) + rand.Read(b) //nolint:errcheck + return base64.RawURLEncoding.EncodeToString(b) +} + +func pkceChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} diff --git a/serve/oidc.go b/serve/oidc.go new file mode 100644 index 0000000..b55cd41 --- /dev/null +++ b/serve/oidc.go @@ -0,0 +1,41 @@ +package serve + +import ( + "context" + "net/http" + "strings" +) + +type contextKey string + +const userContextKey contextKey = "ax_user" + +// withSessionAuth wraps a handler with ax session token authentication. +// Auth endpoints (/auth/*) are passed through without a token check. +// All other requests must supply Authorization: Bearer . +func withSessionAuth(ah *authHandler, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/auth/") { + next.ServeHTTP(w, r) + return + } + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + writeError(w, http.StatusUnauthorized, "Bearer token required") + return + } + token := strings.TrimPrefix(auth, "Bearer ") + username := ah.lookupSession(token) + if username == "" { + writeError(w, http.StatusUnauthorized, "invalid or expired session; run 'ax login'") + return + } + ctx := context.WithValue(r.Context(), userContextKey, username) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func userFromContext(r *http.Request) string { + v, _ := r.Context().Value(userContextKey).(string) + return v +} diff --git a/serve/server.go b/serve/server.go index 64827ae..0d50301 100644 --- a/serve/server.go +++ b/serve/server.go @@ -9,8 +9,10 @@ import ( ) // New returns an HTTP handler that exposes NodeService as a JSON API. -// Every request must supply an X-Ax-User header identifying the acting user. -func New(newSvc func(user string) (service.NodeService, error)) http.Handler { +// When oidcCfg is non-nil, every request must carry a valid Bearer token; +// the authenticated username is derived from the token claim configured in +// OIDCConfig.UserClaim. Without OIDC, the X-Ax-User header is used instead. +func New(newSvc func(user string) (service.NodeService, error), oidcCfg *service.OIDCConfig) (http.Handler, error) { s := &server{newSvc: newSvc} mux := http.NewServeMux() mux.HandleFunc("GET /nodes", s.listNodes) @@ -20,7 +22,17 @@ func New(newSvc func(user string) (service.NodeService, error)) http.Handler { mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode) mux.HandleFunc("GET /users", s.listUsers) mux.HandleFunc("POST /users", s.addUser) - return mux + if oidcCfg != nil { + ah, err := newAuthHandler(*oidcCfg) + if err != nil { + return nil, err + } + mux.HandleFunc("POST /auth/start", ah.start) + mux.HandleFunc("GET /auth/callback", ah.callback) + mux.HandleFunc("GET /auth/poll", ah.poll) + return withSessionAuth(ah, mux), nil + } + return mux, nil } type server struct { @@ -28,7 +40,10 @@ type server struct { } func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeService, bool) { - user := r.Header.Get("X-Ax-User") + user := userFromContext(r) + if user == "" { + user = r.Header.Get("X-Ax-User") + } if user == "" { writeError(w, http.StatusUnauthorized, "X-Ax-User header required") return nil, false diff --git a/service/api_client.go b/service/api_client.go index 6640994..0ab5a48 100644 --- a/service/api_client.go +++ b/service/api_client.go @@ -28,13 +28,27 @@ func (c *apiClient) do(method, path string, body any) (*http.Response, error) { if err != nil { return nil, err } - req.Header.Set("X-Ax-User", c.user) + if err := c.setAuth(req); err != nil { + return nil, err + } if body != nil { req.Header.Set("Content-Type", "application/json") } return c.http.Do(req) } +// setAuth attaches either a Bearer token (when a session exists) or the +// X-Ax-User header (no session / non-OIDC servers). +func (c *apiClient) setAuth(req *http.Request) error { + sess, err := LoadSession() + if err != nil || sess == nil || sess.Token == "" { + req.Header.Set("X-Ax-User", c.user) + return nil + } + req.Header.Set("Authorization", "Bearer "+sess.Token) + return nil +} + func apiDecode[T any](resp *http.Response) (T, error) { var v T defer resp.Body.Close() diff --git a/service/config.go b/service/config.go index 798f5b5..e646a99 100644 --- a/service/config.go +++ b/service/config.go @@ -11,6 +11,16 @@ type ServerConfig struct { Port int `json:"port"` } +type OIDCConfig struct { + Issuer string `json:"issuer"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + // PublicURL is the externally reachable base URL of this server, used to + // construct the OIDC redirect URI (e.g. "https://ax.example.com:7000"). + PublicURL string `json:"public_url"` + UserClaim string `json:"user_claim"` // default "preferred_username" +} + type Config interface { GetUser() string SetUser(username string) error @@ -21,5 +31,7 @@ type Config interface { GetServerConfig() ServerConfig // GetRemoteConfig returns the remote server address and whether remote mode is enabled. GetRemoteConfig() (ServerConfig, bool) + // GetOIDCConfig returns the OIDC configuration and whether OIDC is enabled. + GetOIDCConfig() (OIDCConfig, bool) Save() error } diff --git a/service/config_file.go b/service/config_file.go index 4f6f20d..caaf19c 100644 --- a/service/config_file.go +++ b/service/config_file.go @@ -15,6 +15,7 @@ type fileConfig struct { UserAliases []*Alias `json:"aliases"` Serve ServerConfig `json:"serve"` Remote ServerConfig `json:"remote"` + OIDC OIDCConfig `json:"oidc"` } var defaultAliases = []*Alias{ @@ -142,6 +143,17 @@ func (c *fileConfig) ListAliases() ([]*Alias, error) { return result, nil } +func (c *fileConfig) GetOIDCConfig() (OIDCConfig, bool) { + if c.OIDC.Issuer == "" { + return OIDCConfig{}, false + } + cfg := c.OIDC + if cfg.UserClaim == "" { + cfg.UserClaim = "preferred_username" + } + return cfg, true +} + func (c *fileConfig) GetRemoteConfig() (ServerConfig, bool) { if c.Remote.Host == "" { return ServerConfig{}, false diff --git a/service/node_service.go b/service/node_service.go index e5f7929..10f4061 100644 --- a/service/node_service.go +++ b/service/node_service.go @@ -92,7 +92,7 @@ func GetNodeServiceForUser(user string) (NodeService, error) { if user == "" { return nil, fmt.Errorf("user is required") } - st, err := store.FindAndOpenSQLiteStore() + st, err := store.FindOrInitSQLiteStore() if err != nil { return nil, err } diff --git a/service/session.go b/service/session.go new file mode 100644 index 0000000..d0455ea --- /dev/null +++ b/service/session.go @@ -0,0 +1,67 @@ +package service + +import ( + "encoding/json" + "os" + "path/filepath" +) + +// Session holds the server-issued token returned by POST /auth/poll. +// The ax server owns the full OIDC flow; the client only needs this token. +type Session struct { + Token string `json:"token"` +} + +func sessionPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".config", "ax", "session.json"), nil +} + +func LoadSession() (*Session, error) { + path, err := sessionPath() + if err != nil { + return nil, err + } + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + var s Session + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return &s, nil +} + +func SaveSession(s *Session) error { + path, err := sessionPath() + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return err + } + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func ClearSession() error { + path, err := sessionPath() + if err != nil { + return err + } + err = os.Remove(path) + if os.IsNotExist(err) { + return nil + } + return err +} diff --git a/store/sqlite.go b/store/sqlite.go index 2401a89..8a97cc1 100644 --- a/store/sqlite.go +++ b/store/sqlite.go @@ -83,6 +83,40 @@ func FindAndOpenSQLiteStore() (Store, error) { } } +// FindOrInitSQLiteStore is like FindAndOpenSQLiteStore but intended for server +// mode: if no .ax.db is found it creates and initialises one in the current +// working directory instead of returning an error. +func FindOrInitSQLiteStore() (Store, error) { + if dbpath := os.Getenv("AX_DB_PATH"); dbpath != "" { + if err := InitSQLiteStore(dbpath); err != nil { + return nil, err + } + return NewSQLiteStore(dbpath) + } + dir, err := filepath.Abs(".") + if err != nil { + return nil, err + } + for { + dbpath := filepath.Join(dir, ".ax.db") + if _, err := os.Stat(dbpath); err == nil { + return NewSQLiteStore(dbpath) + } + if parent := filepath.Dir(dir); parent == dir { + break + } else { + dir = parent + } + } + // Not found — create and initialise in CWD. + cwd, _ := filepath.Abs(".") + dbpath := filepath.Join(cwd, ".ax.db") + if err := InitSQLiteStore(dbpath); err != nil { + return nil, err + } + return NewSQLiteStore(dbpath) +} + // NewSQLiteStore opens a SQLite database at the given path, runs a one-time // schema migration if needed, then applies per-connection PRAGMAs. func NewSQLiteStore(path string) (Store, error) {