From 7b8202b50bc15ff77d0bbfdc61296d017ef75ccc Mon Sep 17 00:00:00 2001 From: Elias Kohout Date: Fri, 12 Jun 2026 00:55:09 +0200 Subject: [PATCH] feat: harden HTTP server with rate limiting, request timeouts, and sanitized error messages --- src/cmd/serve.go | 10 ++++- src/serve/auth.go | 10 ++--- src/serve/ratelimit.go | 86 ++++++++++++++++++++++++++++++++++++++++++ src/serve/server.go | 12 +++--- 4 files changed, 107 insertions(+), 11 deletions(-) create mode 100644 src/serve/ratelimit.go diff --git a/src/cmd/serve.go b/src/cmd/serve.go index 7c9bbe3..1df02cc 100644 --- a/src/cmd/serve.go +++ b/src/cmd/serve.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "os" + "time" "github.com/spf13/cobra" ) @@ -38,7 +39,14 @@ var serveCmd = &cobra.Command{ os.Exit(1) } fmt.Fprintf(os.Stdout, "listening on %s\n", addr) - if err := http.ListenAndServe(addr, handler); err != nil { + srv := &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: 5 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + if err := srv.ListenAndServe(); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } diff --git a/src/serve/auth.go b/src/serve/auth.go index 2391211..036be0f 100644 --- a/src/serve/auth.go +++ b/src/serve/auth.go @@ -144,13 +144,13 @@ func (h *authHandler) callback(w http.ResponseWriter, r *http.Request) { oauth2.SetAuthURLParam("code_verifier", pending.verifier), ) if err != nil { - http.Error(w, "token exchange failed: "+err.Error(), http.StatusBadRequest) + http.Error(w, "token exchange failed", http.StatusBadRequest) return } username, err := h.extractUsername(r.Context(), token) if err != nil { - http.Error(w, "failed to identify user: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "failed to identify user", http.StatusInternalServerError) return } @@ -177,7 +177,7 @@ func (h *authHandler) deviceStart(w http.ResponseWriter, r *http.Request) { oauth2.SetAuthURLParam("client_secret", h.cfg.ClientSecret), ) if err != nil { - writeError(w, http.StatusBadGateway, "device authorization request failed: "+err.Error()) + writeError(w, http.StatusBadGateway, "device authorization request failed") return } @@ -196,7 +196,7 @@ func (h *authHandler) deviceStart(w http.ResponseWriter, r *http.Request) { if err != nil { h.mu.Lock() if p := h.pendingDevice[loginID]; p != nil { - p.err = err.Error() + p.err = "device token exchange failed" } h.mu.Unlock() return @@ -206,7 +206,7 @@ func (h *authHandler) deviceStart(w http.ResponseWriter, r *http.Request) { if err != nil { h.mu.Lock() if p := h.pendingDevice[loginID]; p != nil { - p.err = "failed to identify user: " + err.Error() + p.err = "failed to identify user" } h.mu.Unlock() return diff --git a/src/serve/ratelimit.go b/src/serve/ratelimit.go new file mode 100644 index 0000000..4fdc178 --- /dev/null +++ b/src/serve/ratelimit.go @@ -0,0 +1,86 @@ +package serve + +import ( + "net/http" + "sync" + "time" +) + +type visitor struct { + tokens float64 + lastSeen time.Time +} + +type rateLimiter struct { + mu sync.Mutex + visitors map[string]*visitor + rate float64 // tokens per second + burst float64 // max tokens +} + +func newRateLimiter(rate float64, burst int) *rateLimiter { + rl := &rateLimiter{ + visitors: make(map[string]*visitor), + rate: rate, + burst: float64(burst), + } + go rl.cleanup() + return rl +} + +func (rl *rateLimiter) allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + v, exists := rl.visitors[ip] + now := time.Now() + if !exists { + rl.visitors[ip] = &visitor{tokens: rl.burst - 1, lastSeen: now} + return true + } + + elapsed := now.Sub(v.lastSeen).Seconds() + v.lastSeen = now + v.tokens += elapsed * rl.rate + if v.tokens > rl.burst { + v.tokens = rl.burst + } + + if v.tokens < 1 { + return false + } + v.tokens-- + return true +} + +func (rl *rateLimiter) cleanup() { + for range time.Tick(time.Minute) { + rl.mu.Lock() + for ip, v := range rl.visitors { + if time.Since(v.lastSeen) > 5*time.Minute { + delete(rl.visitors, ip) + } + } + rl.mu.Unlock() + } +} + +func clientIP(r *http.Request) string { + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + return ip + } + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + return r.RemoteAddr +} + +func withRateLimit(rl *rateLimiter, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !rl.allow(clientIP(r)) { + writeError(w, http.StatusTooManyRequests, "rate limit exceeded") + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/src/serve/server.go b/src/serve/server.go index a871d3e..d176089 100644 --- a/src/serve/server.go +++ b/src/serve/server.go @@ -24,6 +24,8 @@ func New(newSvc func(user string) (service.NodeService, error), oidcCfg *store.O mux.HandleFunc("DELETE /nodes/{id}", s.deleteNode) mux.HandleFunc("GET /users", s.listUsers) mux.HandleFunc("POST /users", s.addUser) + rl := newRateLimiter(10, 30) // 10 req/s sustained, burst of 30 + if oidcCfg != nil { ah, err := newAuthHandler(*oidcCfg) if err != nil { @@ -33,9 +35,9 @@ func New(newSvc func(user string) (service.NodeService, error), oidcCfg *store.O mux.HandleFunc("POST /auth/device/start", ah.deviceStart) mux.HandleFunc("GET /auth/callback", ah.callback) mux.HandleFunc("GET /auth/poll", ah.poll) - return withSessionAuth(ah, mux), nil + return withRateLimit(rl, withSessionAuth(ah, mux)), nil } - return mux, nil + return withRateLimit(rl, mux), nil } type server struct { @@ -53,7 +55,7 @@ func (s *server) svc(w http.ResponseWriter, r *http.Request) (service.NodeServic } svc, err := s.newSvc(user) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + writeError(w, http.StatusInternalServerError, "internal error") return nil, false } return svc, true @@ -96,7 +98,7 @@ func (s *server) listNodes(w http.ResponseWriter, r *http.Request) { } nodes, err := svc.List(filter) if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + writeError(w, http.StatusInternalServerError, "internal error") return } writeJSON(w, nodes) @@ -171,7 +173,7 @@ func (s *server) listUsers(w http.ResponseWriter, r *http.Request) { } users, err := svc.ListUsers() if err != nil { - writeError(w, http.StatusInternalServerError, err.Error()) + writeError(w, http.StatusInternalServerError, "internal error") return } writeJSON(w, users)