move src file to seperate direcotry
This commit is contained in:
@@ -0,0 +1,210 @@
|
||||
package serve
|
||||
|
||||
import (
|
||||
"axolotl/service"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// pendingLogin tracks an in-progress authorization code flow.
|
||||
type pendingLogin struct {
|
||||
verifier string
|
||||
state string
|
||||
created time.Time
|
||||
serverToken string // set by callback when complete; empty while pending
|
||||
}
|
||||
|
||||
// authHandler owns the OIDC provider connection, the pending login store,
|
||||
// and the active server-side session map.
|
||||
type authHandler struct {
|
||||
mu sync.Mutex
|
||||
pending map[string]*pendingLogin // loginID → pending state
|
||||
sessions map[string]string // serverToken → username
|
||||
|
||||
cfg service.OIDCConfig
|
||||
provider *oidc.Provider
|
||||
oauth2 oauth2.Config
|
||||
}
|
||||
|
||||
func newAuthHandler(cfg service.OIDCConfig) (*authHandler, error) {
|
||||
if cfg.PublicURL == "" {
|
||||
return nil, fmt.Errorf("oidc.public_url must be set to the externally reachable base URL of this server")
|
||||
}
|
||||
provider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OIDC provider: %w", err)
|
||||
}
|
||||
h := &authHandler{
|
||||
pending: make(map[string]*pendingLogin),
|
||||
sessions: make(map[string]string),
|
||||
cfg: cfg,
|
||||
provider: provider,
|
||||
oauth2: oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: cfg.PublicURL + "/auth/callback",
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"},
|
||||
},
|
||||
}
|
||||
go h.cleanup()
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *authHandler) cleanup() {
|
||||
for range time.Tick(5 * time.Minute) {
|
||||
h.mu.Lock()
|
||||
for id, p := range h.pending {
|
||||
if time.Since(p.created) > 15*time.Minute {
|
||||
delete(h.pending, id)
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// lookupSession returns the username for a server-issued token, or "".
|
||||
func (h *authHandler) lookupSession(token string) string {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.sessions[token]
|
||||
}
|
||||
|
||||
// POST /auth/start → {url, session_id}
|
||||
func (h *authHandler) start(w http.ResponseWriter, r *http.Request) {
|
||||
loginID := randomToken(16)
|
||||
verifier := randomToken(32)
|
||||
state := randomToken(16)
|
||||
|
||||
authURL := h.oauth2.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", pkceChallenge(verifier)),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
|
||||
h.mu.Lock()
|
||||
h.pending[loginID] = &pendingLogin{
|
||||
verifier: verifier,
|
||||
state: state,
|
||||
created: time.Now(),
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
writeJSON(w, map[string]string{"url": authURL, "session_id": loginID})
|
||||
}
|
||||
|
||||
// GET /auth/callback — OIDC provider redirects here after user authenticates.
|
||||
func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) {
|
||||
stateParam := r.URL.Query().Get("state")
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
h.mu.Lock()
|
||||
var loginID string
|
||||
var pending *pendingLogin
|
||||
for id, p := range h.pending {
|
||||
if p.state == stateParam {
|
||||
loginID, pending = id, p
|
||||
break
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if pending == nil {
|
||||
http.Error(w, "invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.oauth2.Exchange(r.Context(), code,
|
||||
oauth2.SetAuthURLParam("code_verifier", pending.verifier),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
username, err := h.extractUsername(r.Context(), token)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
serverToken := randomToken(32)
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[serverToken] = username
|
||||
if p := h.pending[loginID]; p != nil {
|
||||
p.serverToken = serverToken
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
fmt.Fprintln(w, "Login successful! You can close this tab.")
|
||||
}
|
||||
|
||||
// GET /auth/poll?session_id=...
|
||||
// Returns 202 while pending, 200 {token, username} when done, 404 if expired.
|
||||
func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) {
|
||||
loginID := r.URL.Query().Get("session_id")
|
||||
|
||||
h.mu.Lock()
|
||||
p := h.pending[loginID]
|
||||
h.mu.Unlock()
|
||||
|
||||
if p == nil {
|
||||
writeError(w, http.StatusNotFound, "session not found or expired")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
serverToken := p.serverToken
|
||||
if serverToken != "" {
|
||||
delete(h.pending, loginID) // consume once delivered
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if serverToken == "" {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
username := h.lookupSession(serverToken)
|
||||
writeJSON(w, map[string]string{"token": serverToken, "username": username})
|
||||
}
|
||||
|
||||
func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) {
|
||||
rawID, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no id_token in response")
|
||||
}
|
||||
idToken, err := h.provider.Verifier(&oidc.Config{ClientID: h.cfg.ClientID}).Verify(ctx, rawID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var claims map[string]any
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
user, _ := claims[h.cfg.UserClaim].(string)
|
||||
if user == "" {
|
||||
return "", fmt.Errorf("claim %q not found in token", h.cfg.UserClaim)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func randomToken(n int) string {
|
||||
b := make([]byte, n)
|
||||
rand.Read(b) //nolint:errcheck
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func pkceChallenge(verifier string) string {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:])
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package serve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const userContextKey contextKey = "ax_user"
|
||||
|
||||
// withSessionAuth wraps a handler with ax session token authentication.
|
||||
// Auth endpoints (/auth/*) are passed through without a token check.
|
||||
// All other requests must supply Authorization: Bearer <server_token>.
|
||||
func withSessionAuth(ah *authHandler, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/auth/") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
auth := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
writeError(w, http.StatusUnauthorized, "Bearer token required")
|
||||
return
|
||||
}
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
username := ah.lookupSession(token)
|
||||
if username == "" {
|
||||
writeError(w, http.StatusUnauthorized, "invalid or expired session; run 'ax login'")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), userContextKey, username)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func userFromContext(r *http.Request) string {
|
||||
v, _ := r.Context().Value(userContextKey).(string)
|
||||
return v
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package serve
|
||||
|
||||
import (
|
||||
"axolotl/models"
|
||||
"axolotl/service"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// New returns an HTTP handler that exposes NodeService as a JSON API.
|
||||
// When oidcCfg is non-nil, every request must carry a valid Bearer token;
|
||||
// the authenticated username is derived from the token claim configured in
|
||||
// OIDCConfig.UserClaim. Without OIDC, the X-Ax-User header is used instead.
|
||||
func New(newSvc func(user string) (service.NodeService, error), oidcCfg *service.OIDCConfig) (http.Handler, error) {
|
||||
s := &server{newSvc: newSvc}
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /nodes", s.listNodes)
|
||||
mux.HandleFunc("POST /nodes", s.addNode)
|
||||
mux.HandleFunc("GET /nodes/{id}", s.getNode)
|
||||
mux.HandleFunc("PATCH /nodes/{id}", s.updateNode)
|
||||
mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode)
|
||||
mux.HandleFunc("GET /users", s.listUsers)
|
||||
mux.HandleFunc("POST /users", s.addUser)
|
||||
if oidcCfg != nil {
|
||||
ah, err := newAuthHandler(*oidcCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mux.HandleFunc("POST /auth/start", ah.start)
|
||||
mux.HandleFunc("GET /auth/callback", ah.callback)
|
||||
mux.HandleFunc("GET /auth/poll", ah.poll)
|
||||
return withSessionAuth(ah, mux), nil
|
||||
}
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
type server struct {
|
||||
newSvc func(user string) (service.NodeService, error)
|
||||
}
|
||||
|
||||
func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeService, bool) {
|
||||
user := userFromContext(r)
|
||||
if user == "" {
|
||||
user = r.Header.Get("X-Ax-User")
|
||||
}
|
||||
if user == "" {
|
||||
writeError(w, http.StatusUnauthorized, "X-Ax-User header required")
|
||||
return nil, false
|
||||
}
|
||||
svc, err := s.newSvc(user)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return nil, false
|
||||
}
|
||||
return svc, true
|
||||
}
|
||||
|
||||
func (s *server) listNodes(w http.ResponseWriter, r *http.Request) {
|
||||
svc, ok := s.svc(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
q := r.URL.Query()
|
||||
var filter service.ListFilter
|
||||
for _, tag := range q["tag"] {
|
||||
filter.Rels = append(filter.Rels, service.RelInput{Type: models.RelType(tag)})
|
||||
}
|
||||
for _, rel := range q["rel"] {
|
||||
filter.Rels = append(filter.Rels, parseRel(rel))
|
||||
}
|
||||
for k, prefix := range map[string]string{"type": "_type::", "status": "_status::", "prio": "_prio::"} {
|
||||
if v := q.Get(k); v != "" {
|
||||
filter.Rels = append(filter.Rels, service.RelInput{Type: models.RelType(prefix + v)})
|
||||
}
|
||||
}
|
||||
if v := q.Get("namespace"); v != "" {
|
||||
filter.Rels = append(filter.Rels, service.RelInput{Type: models.RelInNamespace, Target: v})
|
||||
}
|
||||
if v := q.Get("assignee"); v != "" {
|
||||
filter.Rels = append(filter.Rels, service.RelInput{Type: models.RelAssignee, Target: v})
|
||||
}
|
||||
if v := q.Get("mention"); v != "" {
|
||||
filter.Rels = append(filter.Rels, service.RelInput{Type: models.RelMentions, Target: v})
|
||||
}
|
||||
nodes, err := svc.List(filter)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, nodes)
|
||||
}
|
||||
|
||||
func (s *server) addNode(w http.ResponseWriter, r *http.Request) {
|
||||
svc, ok := s.svc(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var input service.AddInput
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
n, err := svc.Add(input)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
writeJSON(w, n)
|
||||
}
|
||||
|
||||
func (s *server) getNode(w http.ResponseWriter, r *http.Request) {
|
||||
svc, ok := s.svc(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
n, err := svc.GetByID(r.PathValue("id"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, n)
|
||||
}
|
||||
|
||||
func (s *server) updateNode(w http.ResponseWriter, r *http.Request) {
|
||||
svc, ok := s.svc(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var input service.UpdateInput
|
||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
n, err := svc.Update(r.PathValue("id"), input)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, n)
|
||||
}
|
||||
|
||||
func (s *server) deleteNode(w http.ResponseWriter, r *http.Request) {
|
||||
svc, ok := s.svc(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := svc.Delete(r.PathValue("id")); err != nil {
|
||||
writeError(w, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (s *server) listUsers(w http.ResponseWriter, r *http.Request) {
|
||||
svc, ok := s.svc(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
users, err := svc.ListUsers()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, users)
|
||||
}
|
||||
|
||||
func (s *server) addUser(w http.ResponseWriter, r *http.Request) {
|
||||
svc, ok := s.svc(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
n, err := svc.AddUser(body.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
writeJSON(w, n)
|
||||
}
|
||||
|
||||
func parseRel(s string) service.RelInput {
|
||||
if strings.Contains(s, "::") {
|
||||
return service.RelInput{Type: models.RelType(s)}
|
||||
}
|
||||
if idx := strings.Index(s, ":"); idx >= 0 {
|
||||
return service.RelInput{Type: models.RelType(s[:idx]), Target: s[idx+1:]}
|
||||
}
|
||||
return service.RelInput{Type: models.RelType(s)}
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, code int, msg string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": msg})
|
||||
}
|
||||
Reference in New Issue
Block a user