diff --git a/ax b/ax new file mode 100755 index 0000000..37b261e Binary files /dev/null and b/ax differ diff --git a/src/cmd/login.go b/src/cmd/login.go index 54cd478..69df479 100644 --- a/src/cmd/login.go +++ b/src/cmd/login.go @@ -22,67 +22,112 @@ var loginCmd = &cobra.Command{ } base := fmt.Sprintf("http://%s:%d", rc.Host, rc.Port) - resp, err := http.Post(base+"/auth/start", "application/json", nil) + sessionID := tryDeviceFlow(base) + if sessionID == "" { + sessionID = tryCallbackFlow(base) + } + + pollForToken(base, sessionID) + }, +} + +// tryDeviceFlow attempts the device authorization flow. Returns a session ID +// on success, or "" if the server does not support it. +func tryDeviceFlow(base string) string { + resp, err := http.Post(base+"/auth/device/start", "application/json", nil) + if err != nil { + return "" + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "" + } + var start struct { + SessionID string `json:"session_id"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + } + json.NewDecoder(resp.Body).Decode(&start) + if start.SessionID == "" { + return "" + } + + uri := start.VerificationURI + if start.VerificationURIComplete != "" { + uri = start.VerificationURIComplete + } + fmt.Printf("To sign in, open this URL in any browser:\n\n %s\n\nThen enter this code: %s\n\nWaiting for authentication...\n", uri, start.UserCode) + return start.SessionID +} + +// tryCallbackFlow initiates the traditional callback-based OIDC flow. +// Exits the process on failure. +func tryCallbackFlow(base string) string { + resp, err := http.Post(base+"/auth/start", "application/json", nil) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to contact server: %v\n", err) + os.Exit(1) + } + var start struct { + URL string `json:"url"` + SessionID string `json:"session_id"` + } + json.NewDecoder(resp.Body).Decode(&start) + resp.Body.Close() + + if start.URL == "" { + fmt.Fprintln(os.Stderr, "server did not return an auth URL; is OIDC configured on the server?") + os.Exit(1) + } + + fmt.Printf("Open this URL in your browser:\n\n %s\n\nWaiting for login...\n", start.URL) + return start.SessionID +} + +// pollForToken polls the server until the login completes or times out. +func pollForToken(base, sessionID string) { + deadline := time.Now().Add(5 * time.Minute) + for time.Now().Before(deadline) { + time.Sleep(2 * time.Second) + + resp, err := http.Get(fmt.Sprintf("%s/auth/poll?session_id=%s", base, sessionID)) if err != nil { - fmt.Fprintf(os.Stderr, "failed to contact server: %v\n", err) + continue + } + if resp.StatusCode == http.StatusAccepted { + resp.Body.Close() + continue + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + fmt.Fprintln(os.Stderr, "login failed") os.Exit(1) } - var start struct { - URL string `json:"url"` - SessionID string `json:"session_id"` + + var result struct { + Token string `json:"token"` + Username string `json:"username"` } - json.NewDecoder(resp.Body).Decode(&start) + json.NewDecoder(resp.Body).Decode(&result) resp.Body.Close() - if start.URL == "" { - fmt.Fprintln(os.Stderr, "server did not return an auth URL; is OIDC configured on the server?") + session, err := store.LoadSession() + if err != nil { + fmt.Fprintf(os.Stderr, "failed to save session: %v\n", err) os.Exit(1) } - - fmt.Printf("Open this URL in your browser:\n\n %s\n\nWaiting for login...\n", start.URL) - - deadline := time.Now().Add(5 * time.Minute) - for time.Now().Before(deadline) { - time.Sleep(2 * time.Second) - - resp, err := http.Get(fmt.Sprintf("%s/auth/poll?session_id=%s", base, start.SessionID)) - if err != nil { - continue - } - if resp.StatusCode == http.StatusAccepted { - resp.Body.Close() - continue - } - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - fmt.Fprintln(os.Stderr, "login failed") - os.Exit(1) - } - - var result struct { - Token string `json:"token"` - Username string `json:"username"` - } - json.NewDecoder(resp.Body).Decode(&result) - resp.Body.Close() - - session, err := store.LoadSession() - if err != nil { - fmt.Fprintf(os.Stderr, "failed to save session: %v\n", err) - os.Exit(1) - } - session.Token = result.Token - if err := session.Save(); err != nil { - fmt.Fprintf(os.Stderr, "failed to save session: %v\n", err) - os.Exit(1) - } - fmt.Printf("Logged in as %s\n", result.Username) - return + session.Token = result.Token + if err := session.Save(); err != nil { + fmt.Fprintf(os.Stderr, "failed to save session: %v\n", err) + os.Exit(1) } + fmt.Printf("Logged in as %s\n", result.Username) + return + } - fmt.Fprintln(os.Stderr, "login timed out") - os.Exit(1) - }, + fmt.Fprintln(os.Stderr, "login timed out") + os.Exit(1) } func init() { diff --git a/src/serve/auth.go b/src/serve/auth.go index b1444ae..9575c8b 100644 --- a/src/serve/auth.go +++ b/src/serve/auth.go @@ -23,16 +23,26 @@ type pendingLogin struct { serverToken string // set by callback when complete; empty while pending } +// pendingDeviceLogin tracks an in-progress device authorization flow. +type pendingDeviceLogin struct { + created time.Time + serverToken string // set when device token exchange completes + username string // set when device token exchange completes + err string // set if the flow fails +} + // authHandler owns the OIDC provider connection, the pending login store, // and the active server-side session map. type authHandler struct { - mu sync.Mutex - pending map[string]*pendingLogin // loginID → pending state - sessions map[string]string // serverToken → username + mu sync.Mutex + pending map[string]*pendingLogin // loginID → pending state + pendingDevice map[string]*pendingDeviceLogin // loginID → pending device state + sessions map[string]string // serverToken → username - cfg store.OIDCConfig - provider *oidc.Provider - oauth2 oauth2.Config + cfg store.OIDCConfig + provider *oidc.Provider + oauth2 oauth2.Config + deviceFlowAvailable bool } func newAuthHandler(cfg store.OIDCConfig) (*authHandler, error) { @@ -43,18 +53,21 @@ func newAuthHandler(cfg store.OIDCConfig) (*authHandler, error) { if err != nil { return nil, fmt.Errorf("OIDC provider: %w", err) } + endpoint := provider.Endpoint() h := &authHandler{ - pending: make(map[string]*pendingLogin), - sessions: make(map[string]string), - cfg: cfg, - provider: provider, + pending: make(map[string]*pendingLogin), + pendingDevice: make(map[string]*pendingDeviceLogin), + sessions: make(map[string]string), + cfg: cfg, + provider: provider, oauth2: oauth2.Config{ ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, - Endpoint: provider.Endpoint(), + Endpoint: endpoint, RedirectURL: cfg.PublicURL + "/auth/callback", Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}, }, + deviceFlowAvailable: endpoint.DeviceAuthURL != "", } go h.cleanup() return h, nil @@ -68,6 +81,11 @@ func (h *authHandler) cleanup() { delete(h.pending, id) } } + for id, p := range h.pendingDevice { + if time.Since(p.created) > 15*time.Minute { + delete(h.pendingDevice, id) + } + } h.mu.Unlock() } } @@ -148,6 +166,69 @@ func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Login successful! You can close this tab.") } +// POST /auth/device/start → {session_id, user_code, verification_uri, verification_uri_complete} +func (h *authHandler) deviceStart(w http.ResponseWriter, r *http.Request) { + if !h.deviceFlowAvailable { + writeError(w, http.StatusNotFound, "device flow not supported by OIDC provider") + return + } + + da, err := h.oauth2.DeviceAuth(r.Context()) + if err != nil { + writeError(w, http.StatusBadGateway, "device authorization request failed: "+err.Error()) + return + } + + loginID := randomToken(16) + + h.mu.Lock() + h.pendingDevice[loginID] = &pendingDeviceLogin{created: time.Now()} + h.mu.Unlock() + + // Exchange device code for token in the background. + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + token, err := h.oauth2.DeviceAccessToken(ctx, da) + if err != nil { + h.mu.Lock() + if p := h.pendingDevice[loginID]; p != nil { + p.err = err.Error() + } + h.mu.Unlock() + return + } + + username, err := h.extractUsername(ctx, token) + if err != nil { + h.mu.Lock() + if p := h.pendingDevice[loginID]; p != nil { + p.err = "failed to identify user: " + err.Error() + } + h.mu.Unlock() + return + } + + serverToken := randomToken(32) + + h.mu.Lock() + h.sessions[serverToken] = username + if p := h.pendingDevice[loginID]; p != nil { + p.serverToken = serverToken + p.username = username + } + h.mu.Unlock() + }() + + writeJSON(w, map[string]string{ + "session_id": loginID, + "user_code": da.UserCode, + "verification_uri": da.VerificationURI, + "verification_uri_complete": da.VerificationURIComplete, + }) +} + // GET /auth/poll?session_id=... // Returns 202 while pending, 200 {token, username} when done, 404 if expired. func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) { @@ -157,25 +238,52 @@ func (h *authHandler) poll(w http.ResponseWriter, r *http.Request) { p := h.pending[loginID] h.mu.Unlock() - if p == nil { + // Check callback-based flow first. + if p != nil { + h.mu.Lock() + serverToken := p.serverToken + if serverToken != "" { + delete(h.pending, loginID) + } + h.mu.Unlock() + + if serverToken == "" { + w.WriteHeader(http.StatusAccepted) + return + } + username := h.lookupSession(serverToken) + writeJSON(w, map[string]string{"token": serverToken, "username": username}) + return + } + + // Check device flow. + h.mu.Lock() + dp := h.pendingDevice[loginID] + h.mu.Unlock() + + if dp == nil { writeError(w, http.StatusNotFound, "session not found or expired") return } h.mu.Lock() - serverToken := p.serverToken - if serverToken != "" { - delete(h.pending, loginID) // consume once delivered + serverToken := dp.serverToken + errMsg := dp.err + if serverToken != "" || errMsg != "" { + delete(h.pendingDevice, loginID) } h.mu.Unlock() + if errMsg != "" { + writeError(w, http.StatusGone, errMsg) + return + } if serverToken == "" { w.WriteHeader(http.StatusAccepted) return } - username := h.lookupSession(serverToken) - writeJSON(w, map[string]string{"token": serverToken, "username": username}) + writeJSON(w, map[string]string{"token": serverToken, "username": dp.username}) } func (h *authHandler) extractUsername(ctx context.Context, token *oauth2.Token) (string, error) { diff --git a/src/serve/server.go b/src/serve/server.go index add7ac1..a871d3e 100644 --- a/src/serve/server.go +++ b/src/serve/server.go @@ -30,6 +30,7 @@ func New(newSvc func(user string) (service.NodeService, error), oidcCfg *store.O return nil, err } mux.HandleFunc("POST /auth/start", ah.start) + mux.HandleFunc("POST /auth/device/start", ah.deviceStart) mux.HandleFunc("GET /auth/callback", ah.callback) mux.HandleFunc("GET /auth/poll", ah.poll) return withSessionAuth(ah, mux), nil