feat: add OIDC authentication for server mode
This commit is contained in:
84
cmd/login.go
Normal file
84
cmd/login.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"axolotl/service"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var loginCmd = &cobra.Command{
|
||||||
|
Use: "login",
|
||||||
|
Short: "Authenticate with the remote server via OIDC",
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
rc, ok := cfg.GetRemoteConfig()
|
||||||
|
if !ok {
|
||||||
|
fmt.Fprintln(os.Stderr, "no remote server configured; set remote.host in your config")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
base := fmt.Sprintf("http://%s:%d", rc.Host, rc.Port)
|
||||||
|
|
||||||
|
resp, err := http.Post(base+"/auth/start", "application/json", nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "failed to contact server: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
var start struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&start)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if start.URL == "" {
|
||||||
|
fmt.Fprintln(os.Stderr, "server did not return an auth URL; is OIDC configured on the server?")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Open this URL in your browser:\n\n %s\n\nWaiting for login...\n", start.URL)
|
||||||
|
|
||||||
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
resp, err := http.Get(fmt.Sprintf("%s/auth/poll?session_id=%s", base, start.SessionID))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusAccepted {
|
||||||
|
resp.Body.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
resp.Body.Close()
|
||||||
|
fmt.Fprintln(os.Stderr, "login failed")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if err := service.SaveSession(&service.Session{Token: result.Token}); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "failed to save session: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("Logged in as %s\n", result.Username)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(os.Stderr, "login timed out")
|
||||||
|
os.Exit(1)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(loginCmd)
|
||||||
|
}
|
||||||
12
cmd/serve.go
12
cmd/serve.go
@@ -16,7 +16,17 @@ var serveCmd = &cobra.Command{
|
|||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
sc := cfg.GetServerConfig()
|
sc := cfg.GetServerConfig()
|
||||||
addr := fmt.Sprintf("%s:%d", sc.Host, sc.Port)
|
addr := fmt.Sprintf("%s:%d", sc.Host, sc.Port)
|
||||||
handler := serve.New(service.GetNodeServiceForUser)
|
|
||||||
|
var oidcCfg *service.OIDCConfig
|
||||||
|
if oc, ok := cfg.GetOIDCConfig(); ok {
|
||||||
|
oidcCfg = &oc
|
||||||
|
}
|
||||||
|
|
||||||
|
handler, err := serve.New(service.GetNodeServiceForUser, oidcCfg)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
fmt.Fprintf(os.Stdout, "listening on %s\n", addr)
|
fmt.Fprintf(os.Stdout, "listening on %s\n", addr)
|
||||||
if err := http.ListenAndServe(addr, handler); err != nil {
|
if err := http.ListenAndServe(addr, handler); err != nil {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -9,7 +9,9 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/coreos/go-oidc/v3 v3.17.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
@@ -18,6 +20,7 @@ require (
|
|||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/spf13/pflag v1.0.9 // indirect
|
github.com/spf13/pflag v1.0.9 // indirect
|
||||||
|
golang.org/x/oauth2 v0.36.0 // indirect
|
||||||
golang.org/x/sys v0.42.0 // indirect
|
golang.org/x/sys v0.42.0 // indirect
|
||||||
modernc.org/gc/v3 v3.1.2 // indirect
|
modernc.org/gc/v3 v3.1.2 // indirect
|
||||||
modernc.org/libc v1.70.0 // indirect
|
modernc.org/libc v1.70.0 // indirect
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -1,8 +1,12 @@
|
|||||||
|
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
|
||||||
|
github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
@@ -29,6 +33,8 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
|
|||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||||
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|||||||
210
serve/auth.go
Normal file
210
serve/auth.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package serve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"axolotl/service"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pendingLogin tracks an in-progress authorization code flow.
|
||||||
|
type pendingLogin struct {
|
||||||
|
verifier string
|
||||||
|
state string
|
||||||
|
created time.Time
|
||||||
|
serverToken string // set by callback when complete; empty while pending
|
||||||
|
}
|
||||||
|
|
||||||
|
// authHandler owns the OIDC provider connection, the pending login store,
|
||||||
|
// and the active server-side session map.
|
||||||
|
type authHandler struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
pending map[string]*pendingLogin // loginID → pending state
|
||||||
|
sessions map[string]string // serverToken → username
|
||||||
|
|
||||||
|
cfg service.OIDCConfig
|
||||||
|
provider *oidc.Provider
|
||||||
|
oauth2 oauth2.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthHandler(cfg service.OIDCConfig) (*authHandler, error) {
|
||||||
|
if cfg.PublicURL == "" {
|
||||||
|
return nil, fmt.Errorf("oidc.public_url must be set to the externally reachable base URL of this server")
|
||||||
|
}
|
||||||
|
provider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("OIDC provider: %w", err)
|
||||||
|
}
|
||||||
|
h := &authHandler{
|
||||||
|
pending: make(map[string]*pendingLogin),
|
||||||
|
sessions: make(map[string]string),
|
||||||
|
cfg: cfg,
|
||||||
|
provider: provider,
|
||||||
|
oauth2: oauth2.Config{
|
||||||
|
ClientID: cfg.ClientID,
|
||||||
|
ClientSecret: cfg.ClientSecret,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
RedirectURL: cfg.PublicURL + "/auth/callback",
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
go h.cleanup()
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) cleanup() {
|
||||||
|
for range time.Tick(5 * time.Minute) {
|
||||||
|
h.mu.Lock()
|
||||||
|
for id, p := range h.pending {
|
||||||
|
if time.Since(p.created) > 15*time.Minute {
|
||||||
|
delete(h.pending, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupSession returns the username for a server-issued token, or "".
|
||||||
|
func (h *authHandler) lookupSession(token string) string {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
return h.sessions[token]
|
||||||
|
}
|
||||||
|
|
||||||
|
// POST /auth/start → {url, session_id}
|
||||||
|
func (h *authHandler) start(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginID := randomToken(16)
|
||||||
|
verifier := randomToken(32)
|
||||||
|
state := randomToken(16)
|
||||||
|
|
||||||
|
authURL := h.oauth2.AuthCodeURL(state,
|
||||||
|
oauth2.SetAuthURLParam("code_challenge", pkceChallenge(verifier)),
|
||||||
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
h.pending[loginID] = &pendingLogin{
|
||||||
|
verifier: verifier,
|
||||||
|
state: state,
|
||||||
|
created: time.Now(),
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
writeJSON(w, map[string]string{"url": authURL, "session_id": loginID})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET /auth/callback — OIDC provider redirects here after user authenticates.
|
||||||
|
func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
stateParam := r.URL.Query().Get("state")
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
var loginID string
|
||||||
|
var pending *pendingLogin
|
||||||
|
for id, p := range h.pending {
|
||||||
|
if p.state == stateParam {
|
||||||
|
loginID, pending = id, p
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
if pending == nil {
|
||||||
|
http.Error(w, "invalid or expired state", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := h.oauth2.Exchange(r.Context(), code,
|
||||||
|
oauth2.SetAuthURLParam("code_verifier", pending.verifier),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
username, err := h.extractUsername(r.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverToken := randomToken(32)
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
h.sessions[serverToken] = username
|
||||||
|
if p := h.pending[loginID]; p != nil {
|
||||||
|
p.serverToken = serverToken
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
fmt.Fprintln(w, "Login successful! You can close this tab.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET /auth/poll?session_id=...
|
||||||
|
// Returns 202 while pending, 200 {token, username} when done, 404 if expired.
|
||||||
|
func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginID := r.URL.Query().Get("session_id")
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
p := h.pending[loginID]
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
if p == nil {
|
||||||
|
writeError(w, http.StatusNotFound, "session not found or expired")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
serverToken := p.serverToken
|
||||||
|
if serverToken != "" {
|
||||||
|
delete(h.pending, loginID) // consume once delivered
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
if serverToken == "" {
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
username := h.lookupSession(serverToken)
|
||||||
|
writeJSON(w, map[string]string{"token": serverToken, "username": username})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) {
|
||||||
|
rawID, ok := token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("no id_token in response")
|
||||||
|
}
|
||||||
|
idToken, err := h.provider.Verifier(&oidc.Config{ClientID: h.cfg.ClientID}).Verify(ctx, rawID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var claims map[string]any
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
user, _ := claims[h.cfg.UserClaim].(string)
|
||||||
|
if user == "" {
|
||||||
|
return "", fmt.Errorf("claim %q not found in token", h.cfg.UserClaim)
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomToken(n int) string {
|
||||||
|
b := make([]byte, n)
|
||||||
|
rand.Read(b) //nolint:errcheck
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pkceChallenge(verifier string) string {
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
}
|
||||||
41
serve/oidc.go
Normal file
41
serve/oidc.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package serve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const userContextKey contextKey = "ax_user"
|
||||||
|
|
||||||
|
// withSessionAuth wraps a handler with ax session token authentication.
|
||||||
|
// Auth endpoints (/auth/*) are passed through without a token check.
|
||||||
|
// All other requests must supply Authorization: Bearer <server_token>.
|
||||||
|
func withSessionAuth(ah *authHandler, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/auth/") {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth := r.Header.Get("Authorization")
|
||||||
|
if !strings.HasPrefix(auth, "Bearer ") {
|
||||||
|
writeError(w, http.StatusUnauthorized, "Bearer token required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
username := ah.lookupSession(token)
|
||||||
|
if username == "" {
|
||||||
|
writeError(w, http.StatusUnauthorized, "invalid or expired session; run 'ax login'")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(r.Context(), userContextKey, username)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func userFromContext(r *http.Request) string {
|
||||||
|
v, _ := r.Context().Value(userContextKey).(string)
|
||||||
|
return v
|
||||||
|
}
|
||||||
@@ -9,8 +9,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// New returns an HTTP handler that exposes NodeService as a JSON API.
|
// New returns an HTTP handler that exposes NodeService as a JSON API.
|
||||||
// Every request must supply an X-Ax-User header identifying the acting user.
|
// When oidcCfg is non-nil, every request must carry a valid Bearer token;
|
||||||
func New(newSvc func(user string) (service.NodeService, error)) http.Handler {
|
// the authenticated username is derived from the token claim configured in
|
||||||
|
// OIDCConfig.UserClaim. Without OIDC, the X-Ax-User header is used instead.
|
||||||
|
func New(newSvc func(user string) (service.NodeService, error), oidcCfg *service.OIDCConfig) (http.Handler, error) {
|
||||||
s := &server{newSvc: newSvc}
|
s := &server{newSvc: newSvc}
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("GET /nodes", s.listNodes)
|
mux.HandleFunc("GET /nodes", s.listNodes)
|
||||||
@@ -20,7 +22,17 @@ func New(newSvc func(user string) (service.NodeService, error)) http.Handler {
|
|||||||
mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode)
|
mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode)
|
||||||
mux.HandleFunc("GET /users", s.listUsers)
|
mux.HandleFunc("GET /users", s.listUsers)
|
||||||
mux.HandleFunc("POST /users", s.addUser)
|
mux.HandleFunc("POST /users", s.addUser)
|
||||||
return mux
|
if oidcCfg != nil {
|
||||||
|
ah, err := newAuthHandler(*oidcCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mux.HandleFunc("POST /auth/start", ah.start)
|
||||||
|
mux.HandleFunc("GET /auth/callback", ah.callback)
|
||||||
|
mux.HandleFunc("GET /auth/poll", ah.poll)
|
||||||
|
return withSessionAuth(ah, mux), nil
|
||||||
|
}
|
||||||
|
return mux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
@@ -28,7 +40,10 @@ type server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeService, bool) {
|
func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeService, bool) {
|
||||||
user := r.Header.Get("X-Ax-User")
|
user := userFromContext(r)
|
||||||
|
if user == "" {
|
||||||
|
user = r.Header.Get("X-Ax-User")
|
||||||
|
}
|
||||||
if user == "" {
|
if user == "" {
|
||||||
writeError(w, http.StatusUnauthorized, "X-Ax-User header required")
|
writeError(w, http.StatusUnauthorized, "X-Ax-User header required")
|
||||||
return nil, false
|
return nil, false
|
||||||
|
|||||||
@@ -28,13 +28,27 @@ func (c *apiClient) do(method, path string, body any) (*http.Response, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("X-Ax-User", c.user)
|
if err := c.setAuth(req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if body != nil {
|
if body != nil {
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
}
|
}
|
||||||
return c.http.Do(req)
|
return c.http.Do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setAuth attaches either a Bearer token (when a session exists) or the
|
||||||
|
// X-Ax-User header (no session / non-OIDC servers).
|
||||||
|
func (c *apiClient) setAuth(req *http.Request) error {
|
||||||
|
sess, err := LoadSession()
|
||||||
|
if err != nil || sess == nil || sess.Token == "" {
|
||||||
|
req.Header.Set("X-Ax-User", c.user)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+sess.Token)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func apiDecode[T any](resp *http.Response) (T, error) {
|
func apiDecode[T any](resp *http.Response) (T, error) {
|
||||||
var v T
|
var v T
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|||||||
@@ -11,6 +11,16 @@ type ServerConfig struct {
|
|||||||
Port int `json:"port"`
|
Port int `json:"port"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OIDCConfig struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
// PublicURL is the externally reachable base URL of this server, used to
|
||||||
|
// construct the OIDC redirect URI (e.g. "https://ax.example.com:7000").
|
||||||
|
PublicURL string `json:"public_url"`
|
||||||
|
UserClaim string `json:"user_claim"` // default "preferred_username"
|
||||||
|
}
|
||||||
|
|
||||||
type Config interface {
|
type Config interface {
|
||||||
GetUser() string
|
GetUser() string
|
||||||
SetUser(username string) error
|
SetUser(username string) error
|
||||||
@@ -21,5 +31,7 @@ type Config interface {
|
|||||||
GetServerConfig() ServerConfig
|
GetServerConfig() ServerConfig
|
||||||
// GetRemoteConfig returns the remote server address and whether remote mode is enabled.
|
// GetRemoteConfig returns the remote server address and whether remote mode is enabled.
|
||||||
GetRemoteConfig() (ServerConfig, bool)
|
GetRemoteConfig() (ServerConfig, bool)
|
||||||
|
// GetOIDCConfig returns the OIDC configuration and whether OIDC is enabled.
|
||||||
|
GetOIDCConfig() (OIDCConfig, bool)
|
||||||
Save() error
|
Save() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ type fileConfig struct {
|
|||||||
UserAliases []*Alias `json:"aliases"`
|
UserAliases []*Alias `json:"aliases"`
|
||||||
Serve ServerConfig `json:"serve"`
|
Serve ServerConfig `json:"serve"`
|
||||||
Remote ServerConfig `json:"remote"`
|
Remote ServerConfig `json:"remote"`
|
||||||
|
OIDC OIDCConfig `json:"oidc"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultAliases = []*Alias{
|
var defaultAliases = []*Alias{
|
||||||
@@ -142,6 +143,17 @@ func (c *fileConfig) ListAliases() ([]*Alias, error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *fileConfig) GetOIDCConfig() (OIDCConfig, bool) {
|
||||||
|
if c.OIDC.Issuer == "" {
|
||||||
|
return OIDCConfig{}, false
|
||||||
|
}
|
||||||
|
cfg := c.OIDC
|
||||||
|
if cfg.UserClaim == "" {
|
||||||
|
cfg.UserClaim = "preferred_username"
|
||||||
|
}
|
||||||
|
return cfg, true
|
||||||
|
}
|
||||||
|
|
||||||
func (c *fileConfig) GetRemoteConfig() (ServerConfig, bool) {
|
func (c *fileConfig) GetRemoteConfig() (ServerConfig, bool) {
|
||||||
if c.Remote.Host == "" {
|
if c.Remote.Host == "" {
|
||||||
return ServerConfig{}, false
|
return ServerConfig{}, false
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func GetNodeServiceForUser(user string) (NodeService, error) {
|
|||||||
if user == "" {
|
if user == "" {
|
||||||
return nil, fmt.Errorf("user is required")
|
return nil, fmt.Errorf("user is required")
|
||||||
}
|
}
|
||||||
st, err := store.FindAndOpenSQLiteStore()
|
st, err := store.FindOrInitSQLiteStore()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
67
service/session.go
Normal file
67
service/session.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Session holds the server-issued token returned by POST /auth/poll.
|
||||||
|
// The ax server owns the full OIDC flow; the client only needs this token.
|
||||||
|
type Session struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionPath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".config", "ax", "session.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadSession() (*Session, error) {
|
||||||
|
path, err := sessionPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var s Session
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveSession(s *Session) error {
|
||||||
|
path, err := sessionPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := json.MarshalIndent(s, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, data, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClearSession() error {
|
||||||
|
path, err := sessionPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = os.Remove(path)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -83,6 +83,40 @@ func FindAndOpenSQLiteStore() (Store, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FindOrInitSQLiteStore is like FindAndOpenSQLiteStore but intended for server
|
||||||
|
// mode: if no .ax.db is found it creates and initialises one in the current
|
||||||
|
// working directory instead of returning an error.
|
||||||
|
func FindOrInitSQLiteStore() (Store, error) {
|
||||||
|
if dbpath := os.Getenv("AX_DB_PATH"); dbpath != "" {
|
||||||
|
if err := InitSQLiteStore(dbpath); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewSQLiteStore(dbpath)
|
||||||
|
}
|
||||||
|
dir, err := filepath.Abs(".")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
dbpath := filepath.Join(dir, ".ax.db")
|
||||||
|
if _, err := os.Stat(dbpath); err == nil {
|
||||||
|
return NewSQLiteStore(dbpath)
|
||||||
|
}
|
||||||
|
if parent := filepath.Dir(dir); parent == dir {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
dir = parent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Not found — create and initialise in CWD.
|
||||||
|
cwd, _ := filepath.Abs(".")
|
||||||
|
dbpath := filepath.Join(cwd, ".ax.db")
|
||||||
|
if err := InitSQLiteStore(dbpath); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewSQLiteStore(dbpath)
|
||||||
|
}
|
||||||
|
|
||||||
// NewSQLiteStore opens a SQLite database at the given path, runs a one-time
|
// NewSQLiteStore opens a SQLite database at the given path, runs a one-time
|
||||||
// schema migration if needed, then applies per-connection PRAGMAs.
|
// schema migration if needed, then applies per-connection PRAGMAs.
|
||||||
func NewSQLiteStore(path string) (Store, error) {
|
func NewSQLiteStore(path string) (Store, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user