feat: add OIDC authentication for server mode

This commit is contained in:
2026-04-01 19:33:15 +02:00
parent 7bce56384f
commit 52a975b66d
13 changed files with 515 additions and 7 deletions

84
cmd/login.go Normal file
View File

@@ -0,0 +1,84 @@
package cmd
import (
"axolotl/service"
"encoding/json"
"fmt"
"net/http"
"os"
"time"
"github.com/spf13/cobra"
)
var loginCmd = &cobra.Command{
Use: "login",
Short: "Authenticate with the remote server via OIDC",
Run: func(cmd *cobra.Command, args []string) {
rc, ok := cfg.GetRemoteConfig()
if !ok {
fmt.Fprintln(os.Stderr, "no remote server configured; set remote.host in your config")
os.Exit(1)
}
base := fmt.Sprintf("http://%s:%d", rc.Host, rc.Port)
resp, err := http.Post(base+"/auth/start", "application/json", nil)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to contact server: %v\n", err)
os.Exit(1)
}
var start struct {
URL string `json:"url"`
SessionID string `json:"session_id"`
}
json.NewDecoder(resp.Body).Decode(&start)
resp.Body.Close()
if start.URL == "" {
fmt.Fprintln(os.Stderr, "server did not return an auth URL; is OIDC configured on the server?")
os.Exit(1)
}
fmt.Printf("Open this URL in your browser:\n\n %s\n\nWaiting for login...\n", start.URL)
deadline := time.Now().Add(5 * time.Minute)
for time.Now().Before(deadline) {
time.Sleep(2 * time.Second)
resp, err := http.Get(fmt.Sprintf("%s/auth/poll?session_id=%s", base, start.SessionID))
if err != nil {
continue
}
if resp.StatusCode == http.StatusAccepted {
resp.Body.Close()
continue
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
fmt.Fprintln(os.Stderr, "login failed")
os.Exit(1)
}
var result struct {
Token string `json:"token"`
Username string `json:"username"`
}
json.NewDecoder(resp.Body).Decode(&result)
resp.Body.Close()
if err := service.SaveSession(&service.Session{Token: result.Token}); err != nil {
fmt.Fprintf(os.Stderr, "failed to save session: %v\n", err)
os.Exit(1)
}
fmt.Printf("Logged in as %s\n", result.Username)
return
}
fmt.Fprintln(os.Stderr, "login timed out")
os.Exit(1)
},
}
func init() {
rootCmd.AddCommand(loginCmd)
}

View File

@@ -16,7 +16,17 @@ var serveCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
sc := cfg.GetServerConfig() sc := cfg.GetServerConfig()
addr := fmt.Sprintf("%s:%d", sc.Host, sc.Port) addr := fmt.Sprintf("%s:%d", sc.Host, sc.Port)
handler := serve.New(service.GetNodeServiceForUser)
var oidcCfg *service.OIDCConfig
if oc, ok := cfg.GetOIDCConfig(); ok {
oidcCfg = &oc
}
handler, err := serve.New(service.GetNodeServiceForUser, oidcCfg)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
fmt.Fprintf(os.Stdout, "listening on %s\n", addr) fmt.Fprintf(os.Stdout, "listening on %s\n", addr)
if err := http.ListenAndServe(addr, handler); err != nil { if err := http.ListenAndServe(addr, handler); err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)

3
go.mod
View File

@@ -9,7 +9,9 @@ require (
) )
require ( require (
github.com/coreos/go-oidc/v3 v3.17.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -18,6 +20,7 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/spf13/pflag v1.0.9 // indirect github.com/spf13/pflag v1.0.9 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
modernc.org/gc/v3 v3.1.2 // indirect modernc.org/gc/v3 v3.1.2 // indirect
modernc.org/libc v1.70.0 // indirect modernc.org/libc v1.70.0 // indirect

6
go.sum
View File

@@ -1,8 +1,12 @@
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -29,6 +33,8 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

210
serve/auth.go Normal file
View File

@@ -0,0 +1,210 @@
package serve
import (
"axolotl/service"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)
// pendingLogin tracks an in-progress authorization code flow.
type pendingLogin struct {
verifier string
state string
created time.Time
serverToken string // set by callback when complete; empty while pending
}
// authHandler owns the OIDC provider connection, the pending login store,
// and the active server-side session map.
type authHandler struct {
mu sync.Mutex
pending map[string]*pendingLogin // loginID → pending state
sessions map[string]string // serverToken → username
cfg service.OIDCConfig
provider *oidc.Provider
oauth2 oauth2.Config
}
func newAuthHandler(cfg service.OIDCConfig) (*authHandler, error) {
if cfg.PublicURL == "" {
return nil, fmt.Errorf("oidc.public_url must be set to the externally reachable base URL of this server")
}
provider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
if err != nil {
return nil, fmt.Errorf("OIDC provider: %w", err)
}
h := &authHandler{
pending: make(map[string]*pendingLogin),
sessions: make(map[string]string),
cfg: cfg,
provider: provider,
oauth2: oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: provider.Endpoint(),
RedirectURL: cfg.PublicURL + "/auth/callback",
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"},
},
}
go h.cleanup()
return h, nil
}
func (h *authHandler) cleanup() {
for range time.Tick(5 * time.Minute) {
h.mu.Lock()
for id, p := range h.pending {
if time.Since(p.created) > 15*time.Minute {
delete(h.pending, id)
}
}
h.mu.Unlock()
}
}
// lookupSession returns the username for a server-issued token, or "".
func (h *authHandler) lookupSession(token string) string {
h.mu.Lock()
defer h.mu.Unlock()
return h.sessions[token]
}
// POST /auth/start → {url, session_id}
func (h *authHandler) start(w http.ResponseWriter, r *http.Request) {
loginID := randomToken(16)
verifier := randomToken(32)
state := randomToken(16)
authURL := h.oauth2.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", pkceChallenge(verifier)),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
h.mu.Lock()
h.pending[loginID] = &pendingLogin{
verifier: verifier,
state: state,
created: time.Now(),
}
h.mu.Unlock()
writeJSON(w, map[string]string{"url": authURL, "session_id": loginID})
}
// GET /auth/callback — OIDC provider redirects here after user authenticates.
func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) {
stateParam := r.URL.Query().Get("state")
code := r.URL.Query().Get("code")
h.mu.Lock()
var loginID string
var pending *pendingLogin
for id, p := range h.pending {
if p.state == stateParam {
loginID, pending = id, p
break
}
}
h.mu.Unlock()
if pending == nil {
http.Error(w, "invalid or expired state", http.StatusBadRequest)
return
}
token, err := h.oauth2.Exchange(r.Context(), code,
oauth2.SetAuthURLParam("code_verifier", pending.verifier),
)
if err != nil {
http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest)
return
}
username, err := h.extractUsername(r.Context(), token)
if err != nil {
http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError)
return
}
serverToken := randomToken(32)
h.mu.Lock()
h.sessions[serverToken] = username
if p := h.pending[loginID]; p != nil {
p.serverToken = serverToken
}
h.mu.Unlock()
fmt.Fprintln(w, "Login successful! You can close this tab.")
}
// GET /auth/poll?session_id=...
// Returns 202 while pending, 200 {token, username} when done, 404 if expired.
func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) {
loginID := r.URL.Query().Get("session_id")
h.mu.Lock()
p := h.pending[loginID]
h.mu.Unlock()
if p == nil {
writeError(w, http.StatusNotFound, "session not found or expired")
return
}
h.mu.Lock()
serverToken := p.serverToken
if serverToken != "" {
delete(h.pending, loginID) // consume once delivered
}
h.mu.Unlock()
if serverToken == "" {
w.WriteHeader(http.StatusAccepted)
return
}
username := h.lookupSession(serverToken)
writeJSON(w, map[string]string{"token": serverToken, "username": username})
}
func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) {
rawID, ok := token.Extra("id_token").(string)
if !ok {
return "", fmt.Errorf("no id_token in response")
}
idToken, err := h.provider.Verifier(&oidc.Config{ClientID: h.cfg.ClientID}).Verify(ctx, rawID)
if err != nil {
return "", err
}
var claims map[string]any
if err := idToken.Claims(&claims); err != nil {
return "", err
}
user, _ := claims[h.cfg.UserClaim].(string)
if user == "" {
return "", fmt.Errorf("claim %q not found in token", h.cfg.UserClaim)
}
return user, nil
}
func randomToken(n int) string {
b := make([]byte, n)
rand.Read(b) //nolint:errcheck
return base64.RawURLEncoding.EncodeToString(b)
}
func pkceChallenge(verifier string) string {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:])
}

41
serve/oidc.go Normal file
View File

@@ -0,0 +1,41 @@
package serve
import (
"context"
"net/http"
"strings"
)
type contextKey string
const userContextKey contextKey = "ax_user"
// withSessionAuth wraps a handler with ax session token authentication.
// Auth endpoints (/auth/*) are passed through without a token check.
// All other requests must supply Authorization: Bearer <server_token>.
func withSessionAuth(ah *authHandler, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/auth/") {
next.ServeHTTP(w, r)
return
}
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
writeError(w, http.StatusUnauthorized, "Bearer token required")
return
}
token := strings.TrimPrefix(auth, "Bearer ")
username := ah.lookupSession(token)
if username == "" {
writeError(w, http.StatusUnauthorized, "invalid or expired session; run 'ax login'")
return
}
ctx := context.WithValue(r.Context(), userContextKey, username)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func userFromContext(r *http.Request) string {
v, _ := r.Context().Value(userContextKey).(string)
return v
}

View File

@@ -9,8 +9,10 @@ import (
) )
// New returns an HTTP handler that exposes NodeService as a JSON API. // New returns an HTTP handler that exposes NodeService as a JSON API.
// Every request must supply an X-Ax-User header identifying the acting user. // When oidcCfg is non-nil, every request must carry a valid Bearer token;
func New(newSvc func(user string) (service.NodeService, error)) http.Handler { // the authenticated username is derived from the token claim configured in
// OIDCConfig.UserClaim. Without OIDC, the X-Ax-User header is used instead.
func New(newSvc func(user string) (service.NodeService, error), oidcCfg *service.OIDCConfig) (http.Handler, error) {
s := &server{newSvc: newSvc} s := &server{newSvc: newSvc}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /nodes", s.listNodes) mux.HandleFunc("GET /nodes", s.listNodes)
@@ -20,7 +22,17 @@ func New(newSvc func(user string) (service.NodeService, error)) http.Handler {
mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode) mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode)
mux.HandleFunc("GET /users", s.listUsers) mux.HandleFunc("GET /users", s.listUsers)
mux.HandleFunc("POST /users", s.addUser) mux.HandleFunc("POST /users", s.addUser)
return mux if oidcCfg != nil {
ah, err := newAuthHandler(*oidcCfg)
if err != nil {
return nil, err
}
mux.HandleFunc("POST /auth/start", ah.start)
mux.HandleFunc("GET /auth/callback", ah.callback)
mux.HandleFunc("GET /auth/poll", ah.poll)
return withSessionAuth(ah, mux), nil
}
return mux, nil
} }
type server struct { type server struct {
@@ -28,7 +40,10 @@ type server struct {
} }
func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeService, bool) { func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeService, bool) {
user := r.Header.Get("X-Ax-User") user := userFromContext(r)
if user == "" {
user = r.Header.Get("X-Ax-User")
}
if user == "" { if user == "" {
writeError(w, http.StatusUnauthorized, "X-Ax-User header required") writeError(w, http.StatusUnauthorized, "X-Ax-User header required")
return nil, false return nil, false

View File

@@ -28,13 +28,27 @@ func (c *apiClient) do(method, path string, body any) (*http.Response, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("X-Ax-User", c.user) if err := c.setAuth(req); err != nil {
return nil, err
}
if body != nil { if body != nil {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
} }
return c.http.Do(req) return c.http.Do(req)
} }
// setAuth attaches either a Bearer token (when a session exists) or the
// X-Ax-User header (no session / non-OIDC servers).
func (c *apiClient) setAuth(req *http.Request) error {
sess, err := LoadSession()
if err != nil || sess == nil || sess.Token == "" {
req.Header.Set("X-Ax-User", c.user)
return nil
}
req.Header.Set("Authorization", "Bearer "+sess.Token)
return nil
}
func apiDecode[T any](resp *http.Response) (T, error) { func apiDecode[T any](resp *http.Response) (T, error) {
var v T var v T
defer resp.Body.Close() defer resp.Body.Close()

View File

@@ -11,6 +11,16 @@ type ServerConfig struct {
Port int `json:"port"` Port int `json:"port"`
} }
type OIDCConfig struct {
Issuer string `json:"issuer"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
// PublicURL is the externally reachable base URL of this server, used to
// construct the OIDC redirect URI (e.g. "https://ax.example.com:7000").
PublicURL string `json:"public_url"`
UserClaim string `json:"user_claim"` // default "preferred_username"
}
type Config interface { type Config interface {
GetUser() string GetUser() string
SetUser(username string) error SetUser(username string) error
@@ -21,5 +31,7 @@ type Config interface {
GetServerConfig() ServerConfig GetServerConfig() ServerConfig
// GetRemoteConfig returns the remote server address and whether remote mode is enabled. // GetRemoteConfig returns the remote server address and whether remote mode is enabled.
GetRemoteConfig() (ServerConfig, bool) GetRemoteConfig() (ServerConfig, bool)
// GetOIDCConfig returns the OIDC configuration and whether OIDC is enabled.
GetOIDCConfig() (OIDCConfig, bool)
Save() error Save() error
} }

View File

@@ -15,6 +15,7 @@ type fileConfig struct {
UserAliases []*Alias `json:"aliases"` UserAliases []*Alias `json:"aliases"`
Serve ServerConfig `json:"serve"` Serve ServerConfig `json:"serve"`
Remote ServerConfig `json:"remote"` Remote ServerConfig `json:"remote"`
OIDC OIDCConfig `json:"oidc"`
} }
var defaultAliases = []*Alias{ var defaultAliases = []*Alias{
@@ -142,6 +143,17 @@ func (c *fileConfig) ListAliases() ([]*Alias, error) {
return result, nil return result, nil
} }
func (c *fileConfig) GetOIDCConfig() (OIDCConfig, bool) {
if c.OIDC.Issuer == "" {
return OIDCConfig{}, false
}
cfg := c.OIDC
if cfg.UserClaim == "" {
cfg.UserClaim = "preferred_username"
}
return cfg, true
}
func (c *fileConfig) GetRemoteConfig() (ServerConfig, bool) { func (c *fileConfig) GetRemoteConfig() (ServerConfig, bool) {
if c.Remote.Host == "" { if c.Remote.Host == "" {
return ServerConfig{}, false return ServerConfig{}, false

View File

@@ -92,7 +92,7 @@ func GetNodeServiceForUser(user string) (NodeService, error) {
if user == "" { if user == "" {
return nil, fmt.Errorf("user is required") return nil, fmt.Errorf("user is required")
} }
st, err := store.FindAndOpenSQLiteStore() st, err := store.FindOrInitSQLiteStore()
if err != nil { if err != nil {
return nil, err return nil, err
} }

67
service/session.go Normal file
View File

@@ -0,0 +1,67 @@
package service
import (
"encoding/json"
"os"
"path/filepath"
)
// Session holds the server-issued token returned by POST /auth/poll.
// The ax server owns the full OIDC flow; the client only needs this token.
type Session struct {
Token string `json:"token"`
}
func sessionPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".config", "ax", "session.json"), nil
}
func LoadSession() (*Session, error) {
path, err := sessionPath()
if err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var s Session
if err := json.Unmarshal(data, &s); err != nil {
return nil, err
}
return &s, nil
}
func SaveSession(s *Session) error {
path, err := sessionPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return err
}
data, err := json.MarshalIndent(s, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func ClearSession() error {
path, err := sessionPath()
if err != nil {
return err
}
err = os.Remove(path)
if os.IsNotExist(err) {
return nil
}
return err
}

View File

@@ -83,6 +83,40 @@ func FindAndOpenSQLiteStore() (Store, error) {
} }
} }
// FindOrInitSQLiteStore is like FindAndOpenSQLiteStore but intended for server
// mode: if no .ax.db is found it creates and initialises one in the current
// working directory instead of returning an error.
func FindOrInitSQLiteStore() (Store, error) {
if dbpath := os.Getenv("AX_DB_PATH"); dbpath != "" {
if err := InitSQLiteStore(dbpath); err != nil {
return nil, err
}
return NewSQLiteStore(dbpath)
}
dir, err := filepath.Abs(".")
if err != nil {
return nil, err
}
for {
dbpath := filepath.Join(dir, ".ax.db")
if _, err := os.Stat(dbpath); err == nil {
return NewSQLiteStore(dbpath)
}
if parent := filepath.Dir(dir); parent == dir {
break
} else {
dir = parent
}
}
// Not found — create and initialise in CWD.
cwd, _ := filepath.Abs(".")
dbpath := filepath.Join(cwd, ".ax.db")
if err := InitSQLiteStore(dbpath); err != nil {
return nil, err
}
return NewSQLiteStore(dbpath)
}
// NewSQLiteStore opens a SQLite database at the given path, runs a one-time // NewSQLiteStore opens a SQLite database at the given path, runs a one-time
// schema migration if needed, then applies per-connection PRAGMAs. // schema migration if needed, then applies per-connection PRAGMAs.
func NewSQLiteStore(path string) (Store, error) { func NewSQLiteStore(path string) (Store, error) {