diff --git a/go.mod b/go.mod index 6283146..6fb377c 100644 --- a/go.mod +++ b/go.mod @@ -25,14 +25,12 @@ require ( ) require ( - github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.20.1 // indirect - go.uber.org/ratelimit v0.3.1 golang.org/x/sys v0.42.0 // indirect google.golang.org/protobuf v1.36.11 ) diff --git a/go.sum b/go.sum index c472c43..fe38fbb 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= -github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -42,14 +40,10 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= -go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ= diff --git a/pkg/codec/encoding.go b/pkg/codec/encoding.go index 5bbd515..2696722 100644 --- a/pkg/codec/encoding.go +++ b/pkg/codec/encoding.go @@ -29,6 +29,11 @@ func DecodeBase64(encoded string) ([]byte, error) { // 0–9, the next 26 letters ('A'–'Z') as values 10–35, and the final 26 letters // ('a'–'z') as values 36–61. // +// Leading '0' characters are significant: like base58's leading '1's, each +// leading '0' in the input represents one leading zero byte in the output. +// This preserves binary payloads (e.g. protobuf) whose first bytes are 0x00, +// which a plain big-integer round trip would silently drop. +// // An error is returned if the input string contains invalid characters. // // Example usage: @@ -66,7 +71,59 @@ func DecodeBase62(s string) ([]byte, error) { result.Add(&result, big.NewInt(int64(val))) } - // Convert big.Int to a byte slice. + // Count leading '0' characters: each one encodes a leading zero byte that + // the big.Int representation cannot carry. + leadingZeros := 0 + for _, c := range s { + if c != '0' { + break + } + leadingZeros++ + } + decoded := result.Bytes() + if leadingZeros > 0 { + withZeros := make([]byte, leadingZeros+len(decoded)) + copy(withZeros[leadingZeros:], decoded) + decoded = withZeros + } return decoded, nil } + +// EncodeBase62 encodes bytes to a base62 string using the same alphabet as +// DecodeBase62 ([0-9A-Za-z]). Leading zero bytes are encoded as leading '0' +// characters so that DecodeBase62(EncodeBase62(b)) round-trips exactly. +func EncodeBase62(data []byte) string { + if len(data) == 0 { + return "" + } + + leadingZeros := 0 + for _, b := range data { + if b != 0 { + break + } + leadingZeros++ + } + + const alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + var num big.Int + num.SetBytes(data) + + var digits []byte + base := big.NewInt(62) + mod := new(big.Int) + for num.Sign() > 0 { + num.DivMod(&num, base, mod) + digits = append(digits, alphabet[mod.Int64()]) + } + + encoded := make([]byte, leadingZeros+len(digits)) + for i := range leadingZeros { + encoded[i] = '0' + } + for i, d := range digits { + encoded[leadingZeros+len(digits)-1-i] = d + } + return string(encoded) +} diff --git a/pkg/codec/encoding_test.go b/pkg/codec/encoding_test.go index 0f64886..cf6124a 100644 --- a/pkg/codec/encoding_test.go +++ b/pkg/codec/encoding_test.go @@ -107,3 +107,39 @@ func TestDecodeBase62(t *testing.T) { }) } } + +// TestBase62RoundTrip verifies EncodeBase62/DecodeBase62 round-trip exactly, +// including binary payloads with leading zero bytes (regression for the +// big.Int round trip silently dropping leading 0x00 bytes). +func TestBase62RoundTrip(t *testing.T) { + testCases := []struct { + name string + data []byte + }{ + {"simple ascii", []byte("Hello World")}, + {"single zero byte", []byte{0x00}}, + {"leading zero bytes", []byte{0x00, 0x00, 0x01, 0x02}}, + {"all zero bytes", []byte{0x00, 0x00, 0x00}}, + {"binary with embedded zeros", []byte{0x0a, 0x00, 0xff, 0x00}}, + {"empty", []byte{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encoded := EncodeBase62(tc.data) + decoded, err := DecodeBase62(encoded) + if err != nil { + t.Fatalf("DecodeBase62(%q) failed: %v", encoded, err) + } + if len(decoded) != len(tc.data) { + t.Fatalf("round trip changed length: %d -> %d (encoded %q, decoded %x, original %x)", + len(tc.data), len(decoded), encoded, decoded, tc.data) + } + for i := range decoded { + if decoded[i] != tc.data[i] { + t.Fatalf("round trip mismatch at byte %d: got %x, want %x", i, decoded, tc.data) + } + } + }) + } +} diff --git a/pkg/common/middleware_chain.go b/pkg/common/middleware_chain.go index 56a661e..61dad4e 100644 --- a/pkg/common/middleware_chain.go +++ b/pkg/common/middleware_chain.go @@ -13,9 +13,14 @@ func NewMiddlewareChain(middlewares ...Middleware) MiddlewareChain { return middlewares } -// Append adds middleware to the end of the chain +// Append returns a new chain with the given middleware added to the end. +// The original chain is never modified, and the result has its own backing +// array, so multiple Appends on the same parent chain are safe. func (c MiddlewareChain) Append(middlewares ...Middleware) MiddlewareChain { - return append(c, middlewares...) + result := make(MiddlewareChain, 0, len(c)+len(middlewares)) + result = append(result, c...) + result = append(result, middlewares...) + return result } // Prepend adds middleware to the beginning of the chain diff --git a/pkg/common/middleware_chain_test.go b/pkg/common/middleware_chain_test.go index ab39798..36545b4 100644 --- a/pkg/common/middleware_chain_test.go +++ b/pkg/common/middleware_chain_test.go @@ -153,3 +153,39 @@ func TestEmptyMiddlewareChain(t *testing.T) { t.Errorf("Expected body %q, got %q", "OK", w.Body.String()) } } + +// TestMiddlewareChainAppendDoesNotAliasParent verifies that two Appends on the +// same parent chain don't share a backing array (regression: append(c, ...) +// could let the second Append clobber the first chain's middleware). +func TestMiddlewareChainAppendDoesNotAliasParent(t *testing.T) { + record := func(tag string, log *[]string) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + *log = append(*log, tag) + next.ServeHTTP(w, r) + }) + } + } + + var log []string + // Build a parent with spare capacity so naive append would share arrays. + parent := make(MiddlewareChain, 0, 4) + parent = parent.Append(record("base", &log)) + + chainA := parent.Append(record("a", &log)) + chainB := parent.Append(record("b", &log)) // must not overwrite chainA's "a" + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + log = nil + chainA.Then(handler).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil)) + if len(log) != 2 || log[0] != "base" || log[1] != "a" { + t.Fatalf("chainA executed %v, want [base a]", log) + } + + log = nil + chainB.Then(handler).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil)) + if len(log) != 2 || log[0] != "base" || log[1] != "b" { + t.Fatalf("chainB executed %v, want [base b]", log) + } +} diff --git a/pkg/common/types.go b/pkg/common/types.go index 75ce0a4..0d7a33b 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -60,7 +60,9 @@ type RateLimitConfig[T comparable, U any] struct { UserIDFromUser func(user U) T // UserIDToString converts the user ID (type T) to a string for use as a map key. - // Required only when Strategy is StrategyUser. + // Used only when Strategy is StrategyUser. Optional: if nil, a default + // conversion is used (handles string, int, int64, fmt.Stringer, and falls + // back to fmt.Sprint). UserIDToString func(userID T) string // KeyExtractor provides a custom function to generate the rate limit key from the request. diff --git a/pkg/metrics/handler_method.go b/pkg/metrics/handler_method.go index 419c5b3..4bf2fa2 100644 --- a/pkg/metrics/handler_method.go +++ b/pkg/metrics/handler_method.go @@ -4,14 +4,58 @@ import ( "bufio" "net" "net/http" + "strconv" "time" "github.com/Suhaibinator/SRouter/pkg/scontext" // Keep scontext import ) +// cachedHistogram returns the histogram for the given cache key, building it at +// most once. Building (and registering with the backend) per request would +// allocate and contend the registry lock on every request. +func (m *MetricsMiddlewareImpl[T, U]) cachedHistogram(key string, build func() Histogram) Histogram { + if v, ok := m.metricCache.Load(key); ok { + return v.(Histogram) + } + actual, _ := m.metricCache.LoadOrStore(key, build()) + return actual.(Histogram) +} + +// cachedCounter returns the counter for the given cache key, building it at most once. +func (m *MetricsMiddlewareImpl[T, U]) cachedCounter(key string, build func() Counter) Counter { + if v, ok := m.metricCache.Load(key); ok { + return v.(Counter) + } + actual, _ := m.metricCache.LoadOrStore(key, build()) + return actual.(Counter) +} + +// taggedHistogram starts a histogram builder with the configured default tags applied. +func (m *MetricsMiddlewareImpl[T, U]) taggedHistogram(name, desc string) HistogramBuilder { + b := m.registry.NewHistogram().Name(name).Description(desc) + for k, v := range m.config.DefaultTags { + if v != "" { + b = b.Tag(k, v) + } + } + return b +} + +// taggedCounter starts a counter builder with the configured default tags applied. +func (m *MetricsMiddlewareImpl[T, U]) taggedCounter(name, desc string) CounterBuilder { + b := m.registry.NewCounter().Name(name).Description(desc) + for k, v := range m.config.DefaultTags { + if v != "" { + b = b.Tag(k, v) + } + } + return b +} + // Handler wraps an HTTP handler with metrics collection. It captures metrics such as // request latency, throughput, QPS, and errors based on the middleware configuration. -// The metrics are collected using the registry provided to the middleware. +// The metrics are collected using the registry provided to the middleware, are tagged +// with the configured DefaultTags, and are built once per route and then cached. // The 'name' parameter can be used as a fallback identifier if route template information is not available. // This is now a method on the generic MetricsMiddlewareImpl[T, U]. func (m *MetricsMiddlewareImpl[T, U]) Handler(name string, handler http.Handler) http.Handler { @@ -49,89 +93,77 @@ func (m *MetricsMiddlewareImpl[T, U]) Handler(name string, handler http.Handler) // Collect metrics if m.config.EnableLatency { - // Create a route-specific histogram for request latency - latency := m.registry.NewHistogram(). - Name("request_latency_seconds"). - Description("Request latency in seconds"). - Tag("route", routeIdentifier). - Build() - - // Observe the request latency + // Route-specific histogram for request latency + latency := m.cachedHistogram("latency|"+routeIdentifier, func() Histogram { + return m.taggedHistogram("request_latency_seconds", "Request latency in seconds"). + Tag("route", routeIdentifier). + Build() + }) latency.Observe(duration.Seconds()) - // Create a global histogram for total request latency across all routes - totalLatency := m.registry.NewHistogram(). - Name("request_latency_seconds_total"). - Description("Total request latency in seconds across all routes"). - Build() - - // Observe the total latency + // Global histogram for request latency across all routes. + // (Named all_* rather than *_total: the _total suffix is reserved + // for counters in Prometheus naming conventions.) + totalLatency := m.cachedHistogram("latency|all", func() Histogram { + return m.taggedHistogram("all_request_latency_seconds", "Request latency in seconds across all routes"). + Build() + }) totalLatency.Observe(duration.Seconds()) } - if m.config.EnableThroughput { - // Create a route-specific counter for request throughput - throughput := m.registry.NewCounter(). - Name("request_throughput_bytes"). - Description("Request throughput in bytes"). - Tag("route", routeIdentifier). - Build() - - // Add the request size - if r.ContentLength > 0 { - throughput.Add(float64(r.ContentLength)) - - // Also track total throughput across all routes - totalThroughput := m.registry.NewCounter(). - Name("request_throughput_bytes_total"). - Description("Total request throughput in bytes across all routes"). + if m.config.EnableThroughput && r.ContentLength > 0 { + // Route-specific counter for request throughput + throughput := m.cachedCounter("throughput|"+routeIdentifier, func() Counter { + return m.taggedCounter("request_throughput_bytes", "Request throughput in bytes"). + Tag("route", routeIdentifier). Build() + }) + throughput.Add(float64(r.ContentLength)) - totalThroughput.Add(float64(r.ContentLength)) - } + // Also track total throughput across all routes + totalThroughput := m.cachedCounter("throughput|all", func() Counter { + return m.taggedCounter("request_throughput_bytes_total", "Total request throughput in bytes across all routes"). + Build() + }) + totalThroughput.Add(float64(r.ContentLength)) } if m.config.EnableQPS { - // Create a route-specific counter for requests per second - qps := m.registry.NewCounter(). - Name("requests_total"). - Description("Total number of requests"). - Tag("route", routeIdentifier). - Build() - - // Increment the counter + // Route-specific counter for requests + qps := m.cachedCounter("qps|"+routeIdentifier, func() Counter { + return m.taggedCounter("requests_total", "Total number of requests"). + Tag("route", routeIdentifier). + Build() + }) qps.Inc() - // Create a global counter for total requests across all routes - totalQps := m.registry.NewCounter(). - Name("all_requests_total"). - Description("Total number of requests across all routes"). - Build() - - // Increment the total counter + // Global counter for total requests across all routes + totalQps := m.cachedCounter("qps|all", func() Counter { + return m.taggedCounter("all_requests_total", "Total number of requests across all routes"). + Build() + }) totalQps.Inc() } if m.config.EnableErrors && rw.statusCode >= 400 { - // Create a route-specific counter for errors - errors := m.registry.NewCounter(). - Name("request_errors_total"). - Description("Total number of request errors"). - Tag("route", routeIdentifier). - Tag("status_code", http.StatusText(rw.statusCode)). - Build() - - // Increment the counter + // Record the numeric status code (e.g. "404"), not the status text. + statusCode := strconv.Itoa(rw.statusCode) + + // Route-specific counter for errors + errors := m.cachedCounter("errors|"+routeIdentifier+"|"+statusCode, func() Counter { + return m.taggedCounter("request_errors_total", "Total number of request errors"). + Tag("route", routeIdentifier). + Tag("status_code", statusCode). + Build() + }) errors.Inc() - // Create a global counter for errors across all routes - totalErrors := m.registry.NewCounter(). - Name("all_request_errors_total"). - Description("Total number of request errors across all routes"). - Tag("status_code", http.StatusText(rw.statusCode)). - Build() - - // Increment the total counter + // Global counter for errors across all routes + totalErrors := m.cachedCounter("errors|all|"+statusCode, func() Counter { + return m.taggedCounter("all_request_errors_total", "Total number of request errors across all routes"). + Tag("status_code", statusCode). + Build() + }) totalErrors.Inc() } }) diff --git a/pkg/metrics/handler_method_test.go b/pkg/metrics/handler_method_test.go new file mode 100644 index 0000000..5691935 --- /dev/null +++ b/pkg/metrics/handler_method_test.go @@ -0,0 +1,139 @@ +package metrics + +import ( + "net/http" + "strings" + "testing" +) + +// TestHandlerAppliesDefaultTags verifies that DefaultTags from the middleware +// config are applied to every metric the Handler emits, that tags with empty +// values are skipped, and that route/status tags are still added on top. +func TestHandlerAppliesDefaultTags(t *testing.T) { + registry := NewMockMetricsRegistry() + middleware := NewMetricsMiddleware[string, any](registry, MetricsMiddlewareConfig{ + EnableLatency: true, + EnableThroughput: true, + EnableQPS: true, + EnableErrors: true, + DefaultTags: Tags{ + "service": "api", + "region": "", // Empty values must not be emitted as tags. + }, + }) + + handler := middleware.Handler("fallback", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + + req, err := http.NewRequest("POST", "/missing", strings.NewReader("payload")) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + handler.ServeHTTP(NewMockResponseWriter(), req) + + checkTags := func(name string, tags Tags, extra Tags) { + t.Helper() + if tags["service"] != "api" { + t.Errorf("%s: expected default tag service=api, got tags %v", name, tags) + } + if _, present := tags["region"]; present { + t.Errorf("%s: empty-valued default tag %q must not be emitted, got tags %v", name, "region", tags) + } + for k, v := range extra { + if tags[k] != v { + t.Errorf("%s: expected tag %s=%s, got tags %v", name, k, v, tags) + } + } + } + + latency, ok := registry.histograms["request_latency_seconds"].(*MockHistogram) + if !ok { + t.Fatal("expected request_latency_seconds histogram to be built") + } + checkTags("request_latency_seconds", latency.Tags(), Tags{"route": "fallback"}) + + allLatency, ok := registry.histograms["all_request_latency_seconds"].(*MockHistogram) + if !ok { + t.Fatal("expected all_request_latency_seconds histogram to be built") + } + checkTags("all_request_latency_seconds", allLatency.Tags(), nil) + + qps, ok := registry.counters["requests_total"].(*MockCounter) + if !ok { + t.Fatal("expected requests_total counter to be built") + } + checkTags("requests_total", qps.Tags(), Tags{"route": "fallback"}) + + throughput, ok := registry.counters["request_throughput_bytes"].(*MockCounter) + if !ok { + t.Fatal("expected request_throughput_bytes counter to be built") + } + checkTags("request_throughput_bytes", throughput.Tags(), Tags{"route": "fallback"}) + + errors, ok := registry.counters["request_errors_total"].(*MockCounter) + if !ok { + t.Fatal("expected request_errors_total counter to be built") + } + checkTags("request_errors_total", errors.Tags(), Tags{"route": "fallback", "status_code": "404"}) +} + +// TestHandlerReusesCachedMetrics verifies that the per-route metric instances +// are built once and reused across requests: repeated requests must accumulate +// into the same counter/histogram rather than rebuilding fresh metrics. +func TestHandlerReusesCachedMetrics(t *testing.T) { + registry := NewMockMetricsRegistry() + middleware := NewMetricsMiddleware[string, any](registry, MetricsMiddlewareConfig{ + EnableLatency: true, + EnableThroughput: true, + EnableQPS: true, + }) + + handler := middleware.Handler("cached-route", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + const requests = 3 + const bodySize = 7 // len("payload") + for range requests { + req, err := http.NewRequest("POST", "/cached", strings.NewReader("payload")) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + handler.ServeHTTP(NewMockResponseWriter(), req) + } + + // If caching failed, each request would build (and overwrite) a fresh + // counter holding only that request's increment. + qps, ok := registry.counters["requests_total"].(*MockCounter) + if !ok { + t.Fatal("expected requests_total counter to be built") + } + if qps.Value() != requests { + t.Errorf("expected requests_total to accumulate %d increments on one instance, got %v", requests, qps.Value()) + } + + totalQPS, ok := registry.counters["all_requests_total"].(*MockCounter) + if !ok { + t.Fatal("expected all_requests_total counter to be built") + } + if totalQPS.Value() != requests { + t.Errorf("expected all_requests_total to accumulate %d increments on one instance, got %v", requests, totalQPS.Value()) + } + + throughput, ok := registry.counters["request_throughput_bytes"].(*MockCounter) + if !ok { + t.Fatal("expected request_throughput_bytes counter to be built") + } + if throughput.Value() != requests*bodySize { + t.Errorf("expected request_throughput_bytes to accumulate %d bytes on one instance, got %v", requests*bodySize, throughput.Value()) + } + + latency, ok := registry.histograms["request_latency_seconds"].(*MockHistogram) + if !ok { + t.Fatal("expected request_latency_seconds histogram to be built") + } + if len(latency.observations) != requests { + t.Errorf("expected %d latency observations on one histogram instance, got %d", requests, len(latency.observations)) + } +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index a929301..dafaa80 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -8,6 +8,7 @@ package metrics import ( "math/rand" "net/http" + "sync" "time" ) @@ -273,24 +274,40 @@ func (s *RandomSampler) Sample() bool { // MetricsMiddlewareImpl is a concrete generic implementation of the MetricsMiddleware interface. // T is the UserID type (comparable), U is the User object type (any). type MetricsMiddlewareImpl[T comparable, U any] struct { - registry MetricsRegistry - config MetricsMiddlewareConfig - filter MetricsFilter - sampler MetricsSampler + registry MetricsRegistry + config MetricsMiddlewareConfig + filter MetricsFilter + sampler MetricsSampler + metricCache sync.Map // cache key (string) -> built Counter/Histogram, one per route +} + +// samplerFromConfig returns a sampler implementing the configured SamplingRate, +// or nil when no sampling is needed (rate <= 0 means "not configured" and +// rate >= 1 means "always sample", both of which need no sampler). +func samplerFromConfig(config MetricsMiddlewareConfig) MetricsSampler { + if config.SamplingRate > 0 && config.SamplingRate < 1 { + return NewRandomSampler(config.SamplingRate) + } + return nil } // NewMetricsMiddleware creates a new generic MetricsMiddlewareImpl. +// If config.SamplingRate is in (0, 1), a RandomSampler is installed automatically; +// it can be replaced via WithSampler. // T is the UserID type (comparable), U is the User object type (any). func NewMetricsMiddleware[T comparable, U any](registry MetricsRegistry, config MetricsMiddlewareConfig) *MetricsMiddlewareImpl[T, U] { return &MetricsMiddlewareImpl[T, U]{ registry: registry, config: config, + sampler: samplerFromConfig(config), } } -// Configure configures the middleware. +// Configure configures the middleware. A sampler is derived from the new +// config's SamplingRate (replace it with WithSampler if custom behavior is needed). func (m *MetricsMiddlewareImpl[T, U]) Configure(config MetricsMiddlewareConfig) MetricsMiddleware[T, U] { m.config = config + m.sampler = samplerFromConfig(config) return m } diff --git a/pkg/metrics/metrics_test.go b/pkg/metrics/metrics_test.go index 6737fd0..4c84184 100644 --- a/pkg/metrics/metrics_test.go +++ b/pkg/metrics/metrics_test.go @@ -1061,10 +1061,11 @@ func TestMetricsMiddlewareImpl_Handler_WithRouteTemplate(t *testing.T) { t.Error("Route-specific latency metric not found") } - // Check for global latency metric + // Check for global latency metric (all_* prefix: the _total suffix is + // reserved for counters in Prometheus naming conventions) var foundGlobalLatency bool for _, histogram := range registry.histograms { - if histogram.Name() == "request_latency_seconds_total" { + if histogram.Name() == "all_request_latency_seconds" { foundGlobalLatency = true break } diff --git a/pkg/metrics/prometheus/adapter.go b/pkg/metrics/prometheus/adapter.go index 36561f0..dc11c03 100644 --- a/pkg/metrics/prometheus/adapter.go +++ b/pkg/metrics/prometheus/adapter.go @@ -108,9 +108,11 @@ func (b *PrometheusCounterBuilder) Build() srouter_metrics.Counter { // We'll re-fetch based on the Vec structure. counterVec = are.ExistingCollector.(*prometheus.CounterVec) } else { - // SRouter interface expects Build to return the metric directly. - // Handle fatal error - cannot proceed without the metric. - b.registry.registry.MustRegister(counterVec) // Re-attempt registration, will panic on non-AlreadyRegisteredError + // SRouter interface expects Build to return the metric directly, + // and Build can be called from the request path, so never panic. + // The metric still works locally; it just won't be exported. + b.registry.logger.Error("Failed to register Prometheus counter; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } // Convert prometheus.Labels to srouter_metrics.Tags @@ -124,7 +126,9 @@ func (b *PrometheusCounterBuilder) Build() srouter_metrics.Counter { if are, ok := err.(prometheus.AlreadyRegisteredError); ok { promCounter = are.ExistingCollector.(prometheus.Counter) } else { - b.registry.registry.MustRegister(promCounter) // Re-attempt registration, will panic on non-AlreadyRegisteredError + // Never panic in the request path; keep the unregistered metric. + b.registry.logger.Error("Failed to register Prometheus counter; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } // Convert prometheus.Labels to srouter_metrics.Tags @@ -179,7 +183,9 @@ func (b *PrometheusGaugeBuilder) Build() srouter_metrics.Gauge { if are, ok := err.(prometheus.AlreadyRegisteredError); ok { gaugeVec = are.ExistingCollector.(*prometheus.GaugeVec) } else { - b.registry.registry.MustRegister(gaugeVec) // Panic on error + // Never panic in the request path; keep the unregistered metric. + b.registry.logger.Error("Failed to register Prometheus gauge; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } tags := make(srouter_metrics.Tags, len(b.opts.ConstLabels)) @@ -191,7 +197,9 @@ func (b *PrometheusGaugeBuilder) Build() srouter_metrics.Gauge { if are, ok := err.(prometheus.AlreadyRegisteredError); ok { promGauge = are.ExistingCollector.(prometheus.Gauge) } else { - b.registry.registry.MustRegister(promGauge) // Panic on error + // Never panic in the request path; keep the unregistered metric. + b.registry.logger.Error("Failed to register Prometheus gauge; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } tags := make(srouter_metrics.Tags, len(b.opts.ConstLabels)) @@ -256,7 +264,9 @@ func (b *PrometheusHistogramBuilder) Build() srouter_metrics.Histogram { if are, ok := err.(prometheus.AlreadyRegisteredError); ok { histoVec = are.ExistingCollector.(*prometheus.HistogramVec) } else { - b.registry.registry.MustRegister(histoVec) // Panic on error + // Never panic in the request path; keep the unregistered metric. + b.registry.logger.Error("Failed to register Prometheus histogram; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } tags := make(srouter_metrics.Tags, len(b.opts.ConstLabels)) @@ -268,7 +278,9 @@ func (b *PrometheusHistogramBuilder) Build() srouter_metrics.Histogram { if are, ok := err.(prometheus.AlreadyRegisteredError); ok { promHisto = are.ExistingCollector.(prometheus.Histogram) } else { - b.registry.registry.MustRegister(promHisto) // Panic on error + // Never panic in the request path; keep the unregistered metric. + b.registry.logger.Error("Failed to register Prometheus histogram; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } tags := make(srouter_metrics.Tags, len(b.opts.ConstLabels)) @@ -362,7 +374,9 @@ func (b *PrometheusSummaryBuilder) Build() srouter_metrics.Summary { if are, ok := err.(prometheus.AlreadyRegisteredError); ok { summaryVec = are.ExistingCollector.(*prometheus.SummaryVec) } else { - b.registry.registry.MustRegister(summaryVec) // Panic on error + // Never panic in the request path; keep the unregistered metric. + b.registry.logger.Error("Failed to register Prometheus summary; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } tags := make(srouter_metrics.Tags, len(b.opts.ConstLabels)) @@ -374,7 +388,9 @@ func (b *PrometheusSummaryBuilder) Build() srouter_metrics.Summary { if are, ok := err.(prometheus.AlreadyRegisteredError); ok { promSummary = are.ExistingCollector.(prometheus.Summary) } else { - b.registry.registry.MustRegister(promSummary) // Panic on error + // Never panic in the request path; keep the unregistered metric. + b.registry.logger.Error("Failed to register Prometheus summary; metric will not be exported", + zap.String("metric_name", b.opts.Name), zap.Error(err)) } } tags := make(srouter_metrics.Tags, len(b.opts.ConstLabels)) diff --git a/pkg/metrics/prometheus/adapter_test.go b/pkg/metrics/prometheus/adapter_test.go index 663a3e7..b6d6170 100644 --- a/pkg/metrics/prometheus/adapter_test.go +++ b/pkg/metrics/prometheus/adapter_test.go @@ -850,71 +850,64 @@ func TestNegativeAgeBuckets(t *testing.T) { assert.True(t, found, "Summary with negative AgeBuckets should be registered successfully") } -// --- New Test for Panic Paths (using Mock Registerer) --- +// --- Test for registration error paths (using Mock Registerer) --- +// TestPrometheusBuilder_RegisterErrorPanic verifies that registration failures +// no longer panic: Build logs an error and returns a working (unexported) metric, +// since Build can be called from the request path. func TestPrometheusBuilder_RegisterErrorPanic(t *testing.T) { mockRegistry := newMockPrometheusRegisterer() genericError := errors.New("generic registration error") mockRegistry.registerError = genericError // Configure mock to return a generic error + core, observedLogs := observer.New(zap.ErrorLevel) + logger := zap.New(core) + // Create adapter instance using the mock registerer - promRegistry := NewPrometheusRegistry(mockRegistry, "test", "panic_test", zap.NewNop()) - - // Test Counter Panic - assert.PanicsWithError(t, genericError.Error(), func() { - promRegistry.NewCounter().Name("panic_counter").Build() - }, "Counter Build should panic with generic error") - - // Test CounterVec Panic - assert.PanicsWithError(t, genericError.Error(), func() { - builderInterface := promRegistry.NewCounter() - builder := builderInterface.(*PrometheusCounterBuilder) // Cast - builder.Name("panic_counter_vec") // Call interface methods first - builder.LabelNames("a") // Call specific method - builder.Build() - }, "CounterVec Build should panic with generic error") - - // Test Gauge Panic - assert.PanicsWithError(t, genericError.Error(), func() { - promRegistry.NewGauge().Name("panic_gauge").Build() - }, "Gauge Build should panic with generic error") - - // Test GaugeVec Panic - assert.PanicsWithError(t, genericError.Error(), func() { - builderInterface := promRegistry.NewGauge() - builder := builderInterface.(*PrometheusGaugeBuilder) // Cast - builder.Name("panic_gauge_vec") // Call interface methods first - builder.LabelNames("b") // Call specific method - builder.Build() - }, "GaugeVec Build should panic with generic error") - - // Test Histogram Panic - assert.PanicsWithError(t, genericError.Error(), func() { - promRegistry.NewHistogram().Name("panic_histogram").Build() - }, "Histogram Build should panic with generic error") - - // Test HistogramVec Panic - assert.PanicsWithError(t, genericError.Error(), func() { - builderInterface := promRegistry.NewHistogram() - builder := builderInterface.(*PrometheusHistogramBuilder) // Cast - builder.Name("panic_histogram_vec") // Call interface methods first - builder.LabelNames("c") // Call specific method - builder.Build() - }, "HistogramVec Build should panic with generic error") - - // Test Summary Panic - assert.PanicsWithError(t, genericError.Error(), func() { - promRegistry.NewSummary().Name("panic_summary").Build() - }, "Summary Build should panic with generic error") - - // Test SummaryVec Panic - assert.PanicsWithError(t, genericError.Error(), func() { - builderInterface := promRegistry.NewSummary() - builder := builderInterface.(*PrometheusSummaryBuilder) // Cast - builder.Name("panic_summary_vec") // Call interface methods first - builder.LabelNames("d") // Call specific method - builder.Build() - }, "SummaryVec Build should panic with generic error") + promRegistry := NewPrometheusRegistry(mockRegistry, "test", "panic_test", logger) + + builds := []struct { + name string + build func() any + }{ + {"counter", func() any { return promRegistry.NewCounter().Name("err_counter").Build() }}, + {"counter_vec", func() any { + builder := promRegistry.NewCounter().(*PrometheusCounterBuilder) + builder.Name("err_counter_vec") + builder.LabelNames("a") + return builder.Build() + }}, + {"gauge", func() any { return promRegistry.NewGauge().Name("err_gauge").Build() }}, + {"gauge_vec", func() any { + builder := promRegistry.NewGauge().(*PrometheusGaugeBuilder) + builder.Name("err_gauge_vec") + builder.LabelNames("b") + return builder.Build() + }}, + {"histogram", func() any { return promRegistry.NewHistogram().Name("err_histogram").Build() }}, + {"histogram_vec", func() any { + builder := promRegistry.NewHistogram().(*PrometheusHistogramBuilder) + builder.Name("err_histogram_vec") + builder.LabelNames("c") + return builder.Build() + }}, + {"summary", func() any { return promRegistry.NewSummary().Name("err_summary").Build() }}, + {"summary_vec", func() any { + builder := promRegistry.NewSummary().(*PrometheusSummaryBuilder) + builder.Name("err_summary_vec") + builder.LabelNames("d") + return builder.Build() + }}, + } + + for _, tc := range builds { + var metric any + assert.NotPanics(t, func() { metric = tc.build() }, "%s Build should not panic on registration error", tc.name) + assert.NotNil(t, metric, "%s Build should still return a usable metric", tc.name) + } + + // One error log per failed registration + assert.Equal(t, len(builds), observedLogs.Len(), "Expected one error log per failed registration") } // TestAgeBucketsOverflow tests the logging and clamping when AgeBuckets exceeds MaxUint32. diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index f2ce30d..0f86538 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -143,16 +143,16 @@ func AuthenticationWithProvider[T comparable, U any]( // Authentication is a middleware that checks if a request is authenticated using a simple auth function. // T is the User ID type (comparable), U is the User object type (any). // It allows for custom authentication logic to be provided as a simple function. +// +// Note: All methods, including OPTIONS, require authentication. CORS preflight +// requests are handled by the router's CORS support before middleware runs, so +// they never reach this middleware. (Earlier versions skipped authentication +// for OPTIONS, which left registered OPTIONS routes unauthenticated.) func Authentication[T comparable, U any]( authFunc func(*http.Request) (T, bool), ) common.Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions { - // Allow preflight requests without authentication - next.ServeHTTP(w, r) - return - } // Check if the request is authenticated userID, ok := authFunc(r) if !ok { @@ -173,17 +173,17 @@ func Authentication[T comparable, U any]( // It allows for custom authentication logic to be provided as a simple function that returns a boolean. // It adds a boolean flag to the SRouterContext if authentication is successful. // T is the User ID type (comparable), U is the User object type (any). +// +// Note: All methods, including OPTIONS, require authentication. CORS preflight +// requests are handled by the router's CORS support before middleware runs, so +// they never reach this middleware. (Earlier versions skipped authentication +// for OPTIONS, which left registered OPTIONS routes unauthenticated.) func AuthenticationBool[T comparable, U any]( authFunc func(*http.Request) bool, flagName string, // Flag name parameter ) common.Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions { - // Allow preflight requests without authentication - next.ServeHTTP(w, r) - return - } // Check if the request is authenticated if !authFunc(r) { http.Error(w, "Unauthorized", http.StatusUnauthorized) diff --git a/pkg/middleware/auth_test.go b/pkg/middleware/auth_test.go index 2bd3815..519ab17 100644 --- a/pkg/middleware/auth_test.go +++ b/pkg/middleware/auth_test.go @@ -65,22 +65,32 @@ func TestAuthenticationGeneric(t *testing.T) { t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, rec.Code) } - // Test with OPTIONS request (should bypass authentication) + // Test with OPTIONS request (no longer bypasses authentication) req = httptest.NewRequest("OPTIONS", "/test", nil) - req.Header.Set("X-Auth-Token", "invalid-token") // Even with invalid token + req.Header.Set("X-Auth-Token", "invalid-token") rec = httptest.NewRecorder() // Call the handler wrappedHandler.ServeHTTP(rec, req) - // Check that the response status code is 200 (OK) because OPTIONS should skip auth + // OPTIONS requests are authenticated like any other method; CORS preflight + // is handled by the router before middleware runs. + if rec.Code != http.StatusUnauthorized { + t.Errorf("Expected status code %d for unauthenticated OPTIONS request, got %d", http.StatusUnauthorized, rec.Code) + } + + // OPTIONS with valid credentials should succeed + req = httptest.NewRequest("OPTIONS", "/test", nil) + req.Header.Set("X-Auth-Token", "valid-token") + rec = httptest.NewRecorder() + wrappedHandler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Errorf("Expected status code %d for OPTIONS request, got %d", http.StatusOK, rec.Code) + t.Errorf("Expected status code %d for authenticated OPTIONS request, got %d", http.StatusOK, rec.Code) } } -// TestAuthenticationWithProvider_OptionsBypass tests the OPTIONS request bypass -// in the AuthenticationWithProvider middleware. +// TestAuthenticationWithProvider_OptionsBypass verifies OPTIONS requests are +// authenticated like any other method by the AuthenticationWithProvider middleware. func TestAuthenticationWithProvider_OptionsBypass(t *testing.T) { // Create a mock AuthProvider (using BearerTokenProvider for simplicity) provider := &BearerTokenProvider[string]{ @@ -204,16 +214,26 @@ func TestAuthentication(t *testing.T) { t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, rec.Code) } - // Test with OPTIONS request (should bypass authentication) + // Test with OPTIONS request (no longer bypasses authentication) req = httptest.NewRequest("OPTIONS", "/test", nil) - req.Header.Set("X-Auth-Token", "invalid-token") // Even with invalid token + req.Header.Set("X-Auth-Token", "invalid-token") rec = httptest.NewRecorder() // Call the handler wrappedHandler.ServeHTTP(rec, req) - // Check that the response status code is 200 (OK) because OPTIONS should skip auth + // OPTIONS requests are authenticated like any other method; CORS preflight + // is handled by the router before middleware runs. + if rec.Code != http.StatusUnauthorized { + t.Errorf("Expected status code %d for unauthenticated OPTIONS request, got %d", http.StatusUnauthorized, rec.Code) + } + + // OPTIONS with valid credentials should succeed + req = httptest.NewRequest("OPTIONS", "/test", nil) + req.Header.Set("X-Auth-Token", "valid-token") + rec = httptest.NewRecorder() + wrappedHandler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Errorf("Expected status code %d for OPTIONS request, got %d", http.StatusOK, rec.Code) + t.Errorf("Expected status code %d for authenticated OPTIONS request, got %d", http.StatusOK, rec.Code) } } diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index 38e04d7..b07a3d3 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -8,6 +8,7 @@ import ( "net/http" "runtime/debug" "sync" + "sync/atomic" "time" "go.uber.org/zap" @@ -77,7 +78,10 @@ func maxBodySize(maxSize int64) Middleware { // Timeout is a middleware that sets a timeout for the request processing. // If the handler takes longer than the specified timeout to respond, -// the middleware will cancel the request context and return a 408 Request Timeout response. +// the middleware will cancel the request context and return a 408 Request Timeout response, +// but only if the handler has not already started writing a response. +// Once the timeout fires, any further handler writes are rejected with +// http.ErrHandlerTimeout instead of racing with the timeout response. // This prevents long-running requests from blocking server resources indefinitely. func timeout(timeout time.Duration) Middleware { return func(next http.Handler) http.Handler { @@ -110,7 +114,14 @@ func timeout(timeout time.Duration) Middleware { // Handler finished normally return case <-ctx.Done(): - // Timeout occurred + // Timeout occurred. If the handler already started writing + // (or claims the response concurrently with the timeout), + // don't write a second response on top of it. + if wrappedW.wroteHeader.Load() || !wrappedW.markTimedOut() { + return + } + + // Serialize with any handler write currently in progress. wMutex.Lock() http.Error(w, "Request Timeout", http.StatusRequestTimeout) wMutex.Unlock() @@ -120,32 +131,62 @@ func timeout(timeout time.Duration) Middleware { } } -// mutexResponseWriter is a wrapper around http.ResponseWriter that uses a mutex to protect access. -// This ensures thread-safety when writing to the response from multiple goroutines. +// mutexResponseWriter is a wrapper around http.ResponseWriter that uses a mutex to protect access +// and tracks whether the response has been started. Once timedOut is set, all writes are rejected +// so a late handler can never touch the underlying writer after the timeout response was sent. type mutexResponseWriter struct { http.ResponseWriter - mu *sync.Mutex + mu *sync.Mutex + wroteHeader atomic.Bool // Tracks if WriteHeader or Write has been called + timedOut atomic.Bool // When true, reject all writes to the underlying writer +} + +// markTimedOut transitions the writer into the timed-out state, rejecting all +// further handler writes, and attempts to claim the response for the timeout +// handler. It returns false if a handler write claimed the response first (in +// the window between the caller's last check and this transition), in which +// case the timeout response must be suppressed. +func (rw *mutexResponseWriter) markTimedOut() bool { + rw.timedOut.Store(true) + return rw.wroteHeader.CompareAndSwap(false, true) } // WriteHeader acquires the mutex and calls the underlying ResponseWriter.WriteHeader. // This ensures thread-safety when setting the status code from multiple goroutines. func (rw *mutexResponseWriter) WriteHeader(statusCode int) { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() - rw.ResponseWriter.WriteHeader(statusCode) + if !rw.wroteHeader.Swap(true) { + rw.ResponseWriter.WriteHeader(statusCode) + } } // Write acquires the mutex and calls the underlying ResponseWriter.Write. // This ensures thread-safety when writing the response body from multiple goroutines. func (rw *mutexResponseWriter) Write(b []byte) (int, error) { + if rw.timedOut.Load() { + return 0, http.ErrHandlerTimeout + } rw.mu.Lock() defer rw.mu.Unlock() + // Re-check under the lock: the timeout response may have been written + // while this write was waiting for the mutex. + if rw.timedOut.Load() { + return 0, http.ErrHandlerTimeout + } + rw.wroteHeader.Store(true) return rw.ResponseWriter.Write(b) } // Flush acquires the mutex and calls the underlying ResponseWriter.Flush if it implements http.Flusher. // This ensures thread-safety when flushing the response from multiple goroutines. func (rw *mutexResponseWriter) Flush() { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() if f, ok := rw.ResponseWriter.(http.Flusher); ok { diff --git a/pkg/middleware/ratelimit.go b/pkg/middleware/ratelimit.go index 2461af0..9c064fa 100644 --- a/pkg/middleware/ratelimit.go +++ b/pkg/middleware/ratelimit.go @@ -2,63 +2,222 @@ package middleware import ( - "errors" // Added for error handling "fmt" "net/http" "strconv" "sync" + "sync/atomic" "time" "github.com/Suhaibinator/SRouter/pkg/common" // Import common for shared types "github.com/Suhaibinator/SRouter/pkg/scontext" // Use scontext for context functions - "go.uber.org/ratelimit" "go.uber.org/zap" ) // Note: RateLimitStrategy, RateLimiter, RateLimitConfig moved to pkg/common/types.go -// UberRateLimiter implements the common.RateLimiter interface using Uber's ratelimit library. -// It provides a leaky bucket rate limiting algorithm, which smooths out request rates -// by allowing a steady flow of requests while preventing bursts. -// The implementation maintains a map of rate limiters, one per unique key. +// limiterSweepInterval is the minimum time between sweeps of stale limiter +// entries. Sweeps are triggered by entry creation — the only operation that +// grows the map — so a stable key set never pays for sweeping, while key +// churn (the only way the map can grow without bound) triggers a sweep at +// most once per interval. Each sweep runs in its own short-lived goroutine. +const limiterSweepInterval = time.Minute + +// UberRateLimiter implements the common.RateLimiter interface using a +// non-blocking sliding-window counter. Over-limit requests are denied +// immediately (never queued or slept on), and the configured limit/window +// semantics are honored exactly — e.g. "2 per minute" allows 2 requests per +// minute, not 1 per second. +// +// Memory is bounded by eviction: entries idle for at least two of their own +// windows hold no usable history (the next request would reset their counters +// anyway), so they are deleted by an amortized sweep and lazily recreated on +// next use. Eviction therefore never changes a rate-limit decision. Sweeps +// are scheduled only when a new entry is created, so live entries are +// roughly those active within the last 2×window + limiterSweepInterval; +// stale entries beyond that linger only until the next new key arrives. +// +// The name is retained for backwards compatibility with earlier versions that +// were backed by Uber's ratelimit library. The implementation maintains one +// window counter per unique key. type UberRateLimiter struct { - limiters sync.Map // map[string]ratelimit.Limiter + limiters sync.Map // map[string]*slidingWindowLimiter + + // lastSweep is the unix-nano time at which a sweep was last started. + // Entry creation CASes it forward to claim the right to start the next + // sweep, guaranteeing at most one sweep per limiterSweepInterval. + lastSweep atomic.Int64 + + // nowFunc overrides the clock in tests. nil means time.Now. Must be set + // before the limiter is first used. + nowFunc func() time.Time +} + +// timeNow returns the current time from the configured clock. +func (u *UberRateLimiter) timeNow() time.Time { + if u.nowFunc != nil { + return u.nowFunc() + } + return time.Now() } // NewUberRateLimiter creates a new UberRateLimiter instance. -// The returned limiter uses the leaky bucket algorithm to enforce rate limits. -// It maintains separate rate limiters for different keys (e.g., different IPs or users). +// It maintains separate rate limit counters for different keys +// (e.g., different IPs or users). func NewUberRateLimiter() *UberRateLimiter { return &UberRateLimiter{} } -// getLimiter gets or creates a limiter for the given key and rate (requests per second). -// It uses a composite key including the RPS to handle different rate limits for the same base key. -func (u *UberRateLimiter) getLimiter(key string, rps int) ratelimit.Limiter { - compositeKey := fmt.Sprintf("%s-%d", key, rps) // Combine key and rps +// slidingWindowLimiter tracks request counts for the current and previous +// windows. The effective count is the current window's count plus the +// previous window's count weighted by how much of it still overlaps a +// sliding window ending now. This smooths bursts at window boundaries while +// keeping O(1) memory per key. +type slidingWindowLimiter struct { + mu sync.Mutex + windowStart time.Time + prevCount int + currCount int + + // window is fixed at creation (the map's composite key includes it) and + // is read by the sweeper to decide staleness. + window time.Duration + + // evicted is set under mu by the sweeper immediately before the entry is + // deleted from the map. Counting against an evicted entry would silently + // lose the count, so tryAllow refuses and the caller re-fetches. + evicted bool +} + +// allow locks the limiter and counts the request. Unlike tryAllow it ignores +// the evicted flag; use it only where the entry is known not to be shared +// with a sweeper (e.g. unit tests on a bare slidingWindowLimiter). +func (l *slidingWindowLimiter) allow(limit int, window time.Duration, now time.Time) (bool, int, time.Duration) { + l.mu.Lock() + defer l.mu.Unlock() + return l.allowLocked(limit, window, now) +} + +// tryAllow counts the request unless the entry has been evicted from the +// limiter map, in which case valid is false and the caller must re-fetch the +// entry and try again. +func (l *slidingWindowLimiter) tryAllow(limit int, window time.Duration, now time.Time) (allowed bool, remaining int, reset time.Duration, valid bool) { + l.mu.Lock() + defer l.mu.Unlock() + if l.evicted { + return false, 0, 0, false + } + allowed, remaining, reset = l.allowLocked(limit, window, now) + return allowed, remaining, reset, true +} + +// allowLocked implements the sliding-window decision. l.mu must be held. +func (l *slidingWindowLimiter) allowLocked(limit int, window time.Duration, now time.Time) (bool, int, time.Duration) { + if l.windowStart.IsZero() { + l.windowStart = now + } + + // now may have been read before another goroutine created or rolled this + // entry (each request reads the clock once, before the map lookup), so + // elapsed can be slightly negative. Clamp it so overlap stays in [0, 1] + // and reset never exceeds the window. + elapsed := max(now.Sub(l.windowStart), 0) + if elapsed >= window { + // Roll the window forward, keeping alignment to window boundaries. + if elapsed >= 2*window { + l.prevCount = 0 + } else { + l.prevCount = l.currCount + } + l.currCount = 0 + periods := elapsed / window + l.windowStart = l.windowStart.Add(window * periods) + elapsed = now.Sub(l.windowStart) + } + + // Weight the previous window by its remaining overlap with a sliding + // window ending now. + overlap := 1 - float64(elapsed)/float64(window) + estimated := int(float64(l.prevCount)*overlap) + l.currCount + + if estimated >= limit { + // Denied: report how long until the current window rolls over. + // The rollover above keeps elapsed within [0, window), so this is + // always positive. + return false, 0, window - elapsed + } + + l.currCount++ + // estimated < limit here, so the remaining count is never negative. + return true, limit - estimated - 1, 0 +} +// getLimiter gets or creates the window counter stored under compositeKey. +func (u *UberRateLimiter) getLimiter(compositeKey string, window time.Duration, now time.Time) *slidingWindowLimiter { // Fast path: Check if limiter already exists. if limiter, ok := u.limiters.Load(compositeKey); ok { - return limiter.(ratelimit.Limiter) + return limiter.(*slidingWindowLimiter) } - // Slow path: Limiter doesn't exist, create a new one. - newLimiter := ratelimit.New(rps) + // Slow path: atomically load or store a new counter. windowStart is + // initialized to now so a concurrent sweep can never see the brand-new + // entry as stale. + actualLimiter, loaded := u.limiters.LoadOrStore(compositeKey, &slidingWindowLimiter{window: window, windowStart: now}) + if !loaded { + // The map just grew — the only way it can accumulate garbage — so + // this is the only place that schedules a sweep. + u.maybeSweep(now) + } + return actualLimiter.(*slidingWindowLimiter) +} - // Atomically load or store. - // - If compositeKey already exists (due to concurrent creation), LoadOrStore loads and returns the existing value. - // - If compositeKey doesn't exist, LoadOrStore stores newLimiter and returns it. - actualLimiter, _ := u.limiters.LoadOrStore(compositeKey, newLimiter) +// maybeSweep starts a sweep of stale entries unless one already started within +// the last limiterSweepInterval. The CAS allows at most one winner per +// interval; losing callers return immediately. +func (u *UberRateLimiter) maybeSweep(now time.Time) { + last := u.lastSweep.Load() + if now.UnixNano()-last < int64(limiterSweepInterval) { + return + } + if !u.lastSweep.CompareAndSwap(last, now.UnixNano()) { + return + } + go u.sweep(now) +} - // Return the actual limiter stored in the map (either the existing one or the new one). - return actualLimiter.(ratelimit.Limiter) +// sweep deletes entries that can no longer influence any rate-limit decision. +// An entry is stale once it has been idle for at least two of its own windows: +// at that point the next request would zero both window counters anyway, so +// deleting the entry (and lazily recreating it on next use) is unobservable. +// +// The staleness check and the evicted flag are written under the same lock +// acquisition that tryAllow counts under, so no request can record a count +// between the decision to evict and the eviction itself: any request that +// counted before the check rolled windowStart forward (making the entry +// non-stale), and any request arriving after sees the evicted flag and +// re-fetches a fresh entry. +func (u *UberRateLimiter) sweep(now time.Time) { + u.limiters.Range(func(key, value any) bool { + l := value.(*slidingWindowLimiter) + l.mu.Lock() + stale := l.window > 0 && !l.windowStart.IsZero() && now.Sub(l.windowStart) >= 2*l.window + if stale { + l.evicted = true + } + l.mu.Unlock() + if stale { + u.limiters.Delete(key) + } + return true + }) } // Ensure UberRateLimiter implements the common.RateLimiter interface. var _ common.RateLimiter = (*UberRateLimiter)(nil) // Allow checks if a request is allowed based on the key and rate limit configuration. -// It implements the common.RateLimiter interface using the leaky bucket algorithm. +// It implements the common.RateLimiter interface using a sliding-window counter. +// The check never blocks: over-limit requests are denied immediately. // // Parameters: // - key: Unique identifier for the rate limit bucket (e.g., "api:IP:192.168.1.1") @@ -70,41 +229,33 @@ var _ common.RateLimiter = (*UberRateLimiter)(nil) // - remaining: Estimated number of remaining requests in the current window // - reset: Duration until the next request will be allowed (0 if allowed now) func (u *UberRateLimiter) Allow(key string, limit int, window time.Duration) (bool, int, time.Duration) { - // Convert limit and window to Requests Per Second (RPS) for Uber's limiter. - // Ensure RPS is at least 1. - rps := int(float64(limit) / window.Seconds()) - if rps < 1 { - rps = 1 + if limit <= 0 { + // A non-positive limit allows nothing. + return false, 0, window } - - limiter := u.getLimiter(key, rps) - - // Take() blocks until a token is available or returns immediately if available. - // It returns the time when the next token will be available. - now := time.Now() - nextAvailable := limiter.Take() - waitTime := nextAvailable.Sub(now) - - // Estimate remaining tokens based on the wait time relative to the window. - // This is an approximation for leaky bucket. - remaining := int(float64(limit) * (1 - waitTime.Seconds()/window.Seconds())) - if remaining < 0 { - remaining = 0 + if window <= 0 { + window = time.Second } - // If the wait time is significant (e.g., > 1ms, indicating actual rate limiting), deny. - // Uber's Take() might return a time slightly in the future even if not strictly limited. - // A small threshold helps distinguish actual limiting from minor clock differences. - // If waitTime is 0 or very small, the request is allowed. - allowed := waitTime <= time.Millisecond // Allow if wait time is negligible - - // Reset time is the duration until the next token is available. - resetDuration := waitTime - if resetDuration < 0 { - resetDuration = 0 // Cannot reset in the past + // The composite key includes limit and window so different rate limits + // for the same base key don't share counters. Built with strconv instead + // of fmt.Sprintf: this runs on every rate-limited request and Sprintf's + // reflection is measurably slower. + compositeKey := key + "|" + strconv.Itoa(limit) + "|" + strconv.FormatInt(int64(window), 10) + + now := u.timeNow() + for { + limiter := u.getLimiter(compositeKey, window, now) + allowed, remaining, reset, valid := limiter.tryAllow(limit, window, now) + if valid { + return allowed, remaining, reset + } + // The sweeper marked this entry evicted between our lookup and use. + // Its own Delete may not have landed yet, so remove the entry here — + // CompareAndDelete is a no-op if the sweeper already removed it or a + // fresh entry has replaced it — and retry with a fresh entry. + u.limiters.CompareAndDelete(compositeKey, limiter) } - - return allowed, remaining, resetDuration } // convertUserIDToString provides default conversions for common comparable types to string. @@ -128,33 +279,33 @@ func convertUserIDToString[T comparable](userID T) string { // extractUserKey extracts the user-based key (as a string) from the request context. // It prioritizes the user object if UserIDFromUser is provided, otherwise uses the user ID directly. -// Returns an empty string if no user information is found or conversion fails. -func extractUserKey[T comparable, U any](r *http.Request, config *common.RateLimitConfig[T, U]) (string, error) { // Use common.RateLimitConfig - if config.UserIDToString == nil { - return "", errors.New("UserIDToString function is required for StrategyUser") +// If config.UserIDToString is nil, a default conversion (convertUserIDToString) is used, +// so StrategyUser works out of the box for common ID types. +// Returns an empty string if no user information is found. +func extractUserKey[T comparable, U any](r *http.Request, config *common.RateLimitConfig[T, U]) string { // Use common.RateLimitConfig + userIDToString := config.UserIDToString + if userIDToString == nil { + userIDToString = convertUserIDToString[T] } - // Try getting the full user object first + // Try getting the full user object first. Extracting an ID from it + // requires the UserIDFromUser function; without it, fall through to the + // user ID from the context. user, userOk := scontext.GetUserFromRequest[T, U](r) // Use scontext - if userOk && user != nil { - if config.UserIDFromUser == nil { - // Cannot extract ID from user object without UserIDFromUser function - // Try falling back to UserID directly - } else { - userID := config.UserIDFromUser(*user) - return config.UserIDToString(userID), nil - } + if userOk && user != nil && config.UserIDFromUser != nil { + userID := config.UserIDFromUser(*user) + return userIDToString(userID) } // Fallback: Try getting the user ID directly from context userID, idOk := scontext.GetUserIDFromRequest[T, U](r) // Use scontext if idOk { - // Use the provided conversion function - return config.UserIDToString(userID), nil + // Use the conversion function + return userIDToString(userID) } // No user information found in context - return "", nil // Return empty key, let the caller decide how to handle (e.g., fallback to IP) + return "" // Return empty key, let the caller decide how to handle (e.g., fallback to IP) } // RateLimit creates a middleware that enforces rate limits based on the provided configuration. @@ -206,16 +357,7 @@ func RateLimit[T comparable, U any](config *common.RateLimitConfig[T, U], limite case common.StrategyUser: // Use common.StrategyUser strategyUsed = "User" - key, err = extractUserKey(r, config) - if err != nil { - logger.Error("Failed to extract user key for rate limiting", - zap.Error(err), - zap.String("method", r.Method), - zap.String("path", r.URL.Path), - ) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } + key = extractUserKey(r, config) // If no user key found, fall back to IP strategy as a safety measure if key == "" { strategyUsed = "User (fallback to IP)" diff --git a/pkg/middleware/ratelimit_codecov_test.go b/pkg/middleware/ratelimit_codecov_test.go index 476dd60..18828d4 100644 --- a/pkg/middleware/ratelimit_codecov_test.go +++ b/pkg/middleware/ratelimit_codecov_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/Suhaibinator/SRouter/pkg/scontext" "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -52,27 +53,27 @@ func TestUberRateLimiter_AllowedCondition(t *testing.T) { }) } -// TestExtractUserKey_NilUserIDToString specifically tests lines 126-127 in ratelimit.go -// where it checks if UserIDToString is nil and returns an error +// TestExtractUserKey_NilUserIDToString verifies that a nil UserIDToString +// falls back to the default conversion, so StrategyUser works when configured +// via [any, any] overrides. func TestExtractUserKey_NilUserIDToString_Codecov(t *testing.T) { - // Create a request + // Create a request with a user ID in the context req := httptest.NewRequest("GET", "/", nil) + ctx := scontext.WithUserID[string, string](req.Context(), "user-42") + req = req.WithContext(ctx) // Create a config with nil UserIDToString config := &common.RateLimitConfig[string, string]{ Strategy: common.StrategyUser, UserIDFromUser: func(u string) string { return u }, - UserIDToString: nil, // Explicitly set to nil to test the nil check + UserIDToString: nil, // Nil triggers the default conversion } // Call extractUserKey - key, err := extractUserKey(req, config) + key := extractUserKey(req, config) - // Verify that an error is returned - assert.Error(t, err, "extractUserKey should return an error when UserIDToString is nil") - assert.Equal(t, "", key, "Key should be empty when error is returned") - assert.Contains(t, err.Error(), "UserIDToString function is required", - "Error message should indicate UserIDToString is required") + // Verify the default conversion is used + assert.Equal(t, "user-42", key, "Key should come from the default user ID conversion") } // TestRateLimit_CustomStrategyEmptyKey_Codecov specifically tests lines 258-259 in ratelimit.go diff --git a/pkg/middleware/ratelimit_coverage_test.go b/pkg/middleware/ratelimit_coverage_test.go index ca51c8e..06fe2bb 100644 --- a/pkg/middleware/ratelimit_coverage_test.go +++ b/pkg/middleware/ratelimit_coverage_test.go @@ -13,7 +13,6 @@ import ( "github.com/Suhaibinator/SRouter/pkg/common" // Added import "github.com/Suhaibinator/SRouter/pkg/scontext" // Added import "github.com/stretchr/testify/assert" - "go.uber.org/ratelimit" // Added import for assert.Same "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" @@ -42,8 +41,8 @@ func TestGetLimiter_Coverage(t *testing.T) { key := "concurrent-get-limiter-key" limit := 1000 window := 10 * time.Millisecond - // Calculate rps the same way Allow does - rps := int(float64(limit) / window.Seconds()) + // Composite key format used by getLimiter + compositeKey := fmt.Sprintf("%s|%d|%d", key, limit, int64(window)) numGoroutines := 20 callsPerGoroutine := 5 @@ -52,7 +51,7 @@ func TestGetLimiter_Coverage(t *testing.T) { wg.Add(numGoroutines) // Use a map to store the first limiter instance seen by any goroutine - var firstLimiter ratelimit.Limiter + var firstLimiter *slidingWindowLimiter var once sync.Once for i := 0; i < numGoroutines; i++ { @@ -64,18 +63,14 @@ func TestGetLimiter_Coverage(t *testing.T) { assert.True(t, allowed, "Request should be allowed initially") // Check if the same limiter instance is returned across goroutines - // This indirectly tests that the double-check prevents creating multiple limiters - compositeKey := fmt.Sprintf("%s-%d", key, rps) + // This indirectly tests that LoadOrStore prevents creating multiple limiters val, ok := limiter.limiters.Load(compositeKey) // Explicitly check 'ok' before proceeding to prevent panic on nil interface conversion if !assert.True(t, ok, "Limiter should exist in the map for key %s", compositeKey) { - // If the assertion fails (ok is false), skip the rest of the checks for this iteration - // as val will be nil, causing a panic on type assertion. - // This indicates a potential timing issue or problem in limiter creation/storage. continue // Skip to the next iteration of the inner loop } // Only proceed if ok is true - currentLimiter := val.(ratelimit.Limiter) + currentLimiter := val.(*slidingWindowLimiter) once.Do(func() { firstLimiter = currentLimiter // Capture the first successfully retrieved limiter @@ -89,15 +84,15 @@ func TestGetLimiter_Coverage(t *testing.T) { wg.Wait() - // Final check that only one limiter was created for the key/rps combination + // Final check that only one limiter was created for the key/limit/window combination count := 0 limiter.limiters.Range(func(k, v interface{}) bool { - if k == fmt.Sprintf("%s-%d", key, rps) { + if k == compositeKey { count++ } return true }) - assert.Equal(t, 1, count, "Expected exactly one limiter instance for the key/rps") + assert.Equal(t, 1, count, "Expected exactly one limiter instance for the key/limit/window") } // TestRateLimitWithCustomKeyExtractor tests the RateLimit function with a custom key extractor @@ -401,8 +396,8 @@ func TestRateLimit_UserStrategyFallback(t *testing.T) { }) } -// TestExtractUserKey_NilUserIDToString specifically tests lines 126-127 in ratelimit.go -// where it checks if UserIDToString is nil and returns an error +// TestExtractUserKey_NilUserIDToString verifies that a nil UserIDToString +// falls back to the default user ID conversion instead of erroring. func TestExtractUserKey_NilUserIDToString(t *testing.T) { // Create a request with a user in context req := httptest.NewRequest("GET", "/", nil) @@ -414,17 +409,14 @@ func TestExtractUserKey_NilUserIDToString(t *testing.T) { config := &common.RateLimitConfig[string, string]{ Strategy: common.StrategyUser, UserIDFromUser: func(u string) string { return u }, - UserIDToString: nil, // Explicitly set to nil to test the nil check + UserIDToString: nil, // Nil triggers the default conversion } // Call extractUserKey - key, err := extractUserKey(req, config) + key := extractUserKey(req, config) - // Verify that an error is returned - assert.Error(t, err, "extractUserKey should return an error when UserIDToString is nil") - assert.Equal(t, "", key, "Key should be empty when error is returned") - assert.Contains(t, err.Error(), "UserIDToString function is required", - "Error message should indicate UserIDToString is required") + // Verify the default conversion is used + assert.Equal(t, "test-user", key, "Key should come from the default user ID conversion") } // TestRateLimit_CustomStrategyEmptyKey tests the case where the custom key extractor returns an empty string diff --git a/pkg/middleware/ratelimit_eviction_test.go b/pkg/middleware/ratelimit_eviction_test.go new file mode 100644 index 0000000..11c82cf --- /dev/null +++ b/pkg/middleware/ratelimit_eviction_test.go @@ -0,0 +1,492 @@ +package middleware + +import ( + "fmt" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Tests in this file cover eviction of stale rate-limiter entries: staleness +// detection, sweep scheduling, the evicted-entry retry path in Allow, and the +// guarantee that eviction never changes a rate-limit decision. + +// fakeClock is a manually advanced clock for driving the limiter's sweep and +// staleness logic deterministically. +type fakeClock struct { + nanos atomic.Int64 +} + +func newFakeClock(start time.Time) *fakeClock { + c := &fakeClock{} + c.nanos.Store(start.UnixNano()) + return c +} + +func (c *fakeClock) now() time.Time { return time.Unix(0, c.nanos.Load()) } +func (c *fakeClock) advance(d time.Duration) { c.nanos.Add(int64(d)) } + +// limiterEntryCount counts the entries currently stored in the limiter map. +func limiterEntryCount(u *UberRateLimiter) int { + n := 0 + u.limiters.Range(func(_, _ any) bool { + n++ + return true + }) + return n +} + +// limiterCompositeKey mirrors the composite key format Allow builds. +func limiterCompositeKey(key string, limit int, window time.Duration) string { + return key + "|" + strconv.Itoa(limit) + "|" + strconv.FormatInt(int64(window), 10) +} + +// TestSweepEvictsStaleEntries verifies that a sweep removes every entry that +// has been idle for at least two of its own windows, including entries idle +// for exactly two windows (the boundary is inclusive because at that point +// the next request would zero both counters anyway). +func TestSweepEvictsStaleEntries(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + + const keys = 100 + for i := range keys { + limiter.Allow(fmt.Sprintf("ip-%d", i), 5, window) + } + if got := limiterEntryCount(limiter); got != keys { + t.Fatalf("expected %d entries after creation, got %d", keys, got) + } + + // One window of idleness is not enough: prev/curr still carry history. + clock.advance(window) + limiter.sweep(clock.now()) + if got := limiterEntryCount(limiter); got != keys { + t.Fatalf("entries idle for one window must survive, got %d of %d", got, keys) + } + + // Exactly two windows idle: stale, everything goes. + clock.advance(window) + limiter.sweep(clock.now()) + if got := limiterEntryCount(limiter); got != 0 { + t.Fatalf("expected all entries evicted after 2 windows idle, got %d", got) + } +} + +// TestSweepKeepsActiveEntriesAndTheirCounts verifies that a sweep only removes +// stale entries, and that a surviving entry keeps its window counts intact. +func TestSweepKeepsActiveEntriesAndTheirCounts(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + const limit = 3 + + // keyOld is used once at t0 and then goes idle. + limiter.Allow("key-old", limit, window) + + // keyNew is used once at t0 + 1.5w. + clock.advance(3 * window / 2) + if allowed, remaining, _ := limiter.Allow("key-new", limit, window); !allowed || remaining != limit-1 { + t.Fatalf("key-new first request: got (allowed=%v, remaining=%d), want (true, %d)", allowed, remaining, limit-1) + } + + // At t0 + 2w: keyOld is 2 windows idle (stale), keyNew only 0.5w idle. + clock.advance(window / 2) + limiter.sweep(clock.now()) + + if _, ok := limiter.limiters.Load(limiterCompositeKey("key-old", limit, window)); ok { + t.Error("key-old should have been evicted") + } + if _, ok := limiter.limiters.Load(limiterCompositeKey("key-new", limit, window)); !ok { + t.Fatal("key-new must survive the sweep") + } + + // keyNew's earlier request must still count: same window, so remaining + // reflects one prior use (limit - 1 used - 1 for this request). + if allowed, remaining, _ := limiter.Allow("key-new", limit, window); !allowed || remaining != limit-2 { + t.Errorf("key-new after sweep: got (allowed=%v, remaining=%d), want (true, %d) — prior count was lost", + allowed, remaining, limit-2) + } + + // keyOld was legitimately idle for 2 windows, so recreation grants the + // full limit — the same answer a non-evicted entry would have given. + if allowed, remaining, _ := limiter.Allow("key-old", limit, window); !allowed || remaining != limit-1 { + t.Errorf("key-old after eviction: got (allowed=%v, remaining=%d), want (true, %d)", allowed, remaining, limit-1) + } +} + +// TestEvictionIsLossless runs the same scripted request sequence against two +// limiters sharing one clock — one swept aggressively, one never swept — and +// requires identical results (allowed, remaining, and reset) at every step. +// This is the core guarantee: eviction must never change a decision. +func TestEvictionIsLossless(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + swept := &UberRateLimiter{nowFunc: clock.now} + control := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + const limit = 3 + + step := func(label string) { + t.Helper() + a1, r1, reset1 := swept.Allow("k", limit, window) + a2, r2, reset2 := control.Allow("k", limit, window) + if a1 != a2 || r1 != r2 || reset1 != reset2 { + t.Fatalf("%s: swept limiter diverged: got (%v, %d, %v), control says (%v, %d, %v)", + label, a1, r1, reset1, a2, r2, reset2) + } + } + + // Exhaust the limit, then one denial. + for i := range limit + 1 { + step(fmt.Sprintf("initial request %d", i+1)) + } + + // Advance exactly two windows (a multiple of the window keeps the two + // limiters' window phases aligned) and evict on the swept limiter only. + clock.advance(2 * window) + swept.sweep(clock.now()) + if got := limiterEntryCount(swept); got != 0 { + t.Fatalf("expected swept limiter to be empty, got %d entries", got) + } + + // Both must now agree on a fresh burst: allow, allow, allow, deny. + for i := range limit + 1 { + step(fmt.Sprintf("post-eviction request %d", i+1)) + } + + // And on partial-window behavior afterwards. + clock.advance(window / 2) + step("half-window later") + clock.advance(window) + swept.sweep(clock.now()) // not stale (recent use): must be a no-op + step("one more window later") +} + +// TestAllowRetriesWhenEntryEvictedMidFlight pins the race where a request +// loads an entry just before the sweeper marks it evicted: Allow must detect +// the tombstone, replace the entry, and count against a fresh one rather than +// losing the count or hanging. +func TestAllowRetriesWhenEntryEvictedMidFlight(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + const limit = 5 + + limiter.Allow("victim", limit, window) + compositeKey := limiterCompositeKey("victim", limit, window) + + v, ok := limiter.limiters.Load(compositeKey) + if !ok { + t.Fatal("expected entry to exist") + } + old := v.(*slidingWindowLimiter) + + // Simulate a sweeper paused between marking the entry and deleting it + // from the map: the tombstone is set but the entry is still loadable. + old.mu.Lock() + old.evicted = true + old.mu.Unlock() + + done := make(chan struct{}) + var allowed bool + var remaining int + go func() { + defer close(done) + allowed, remaining, _ = limiter.Allow("victim", limit, window) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Allow hung on an evicted-but-not-deleted entry") + } + + // The retry must have created a fresh entry; the count lands there. + if !allowed || remaining != limit-1 { + t.Errorf("got (allowed=%v, remaining=%d), want (true, %d) against a fresh entry", allowed, remaining, limit-1) + } + v, ok = limiter.limiters.Load(compositeKey) + if !ok { + t.Fatal("expected a replacement entry in the map") + } + if v.(*slidingWindowLimiter) == old { + t.Error("evicted entry is still in the map; Allow must replace it") + } +} + +// TestMaybeSweepThrottlesToOnePerInterval verifies the sweep-scheduling gate: +// entry creations within limiterSweepInterval of the last sweep do not start +// another one, and a creation after the interval does. +func TestMaybeSweepThrottlesToOnePerInterval(t *testing.T) { + start := time.Unix(1_000_000, 0) + clock := newFakeClock(start) + limiter := &UberRateLimiter{nowFunc: clock.now} + // 2*window must exceed every clock advance below so async sweeps from + // maybeSweep can never evict anything while the test runs. + window := 24 * time.Hour + + // First creation claims the first sweep slot. + limiter.Allow("a", 1, window) + if got := limiter.lastSweep.Load(); got != start.UnixNano() { + t.Fatalf("expected first creation to claim a sweep at %d, got %d", start.UnixNano(), got) + } + + // Creations inside the interval must not claim another sweep. + clock.advance(limiterSweepInterval - time.Second) + limiter.Allow("b", 1, window) + limiter.Allow("c", 1, window) + if got := limiter.lastSweep.Load(); got != start.UnixNano() { + t.Errorf("creation inside the interval rescheduled a sweep: lastSweep moved to %d", got) + } + + // A creation after the interval claims the next sweep, even with many + // concurrent creators racing for it. + clock.advance(2 * time.Second) + want := clock.now().UnixNano() + var wg sync.WaitGroup + for i := range 32 { + wg.Go(func() { + limiter.Allow(fmt.Sprintf("racer-%d", i), 1, window) + }) + } + wg.Wait() + if got := limiter.lastSweep.Load(); got != want { + t.Errorf("expected one sweep claimed at %d after the interval, got %d", want, got) + } + + // Repeat Allows on existing keys never touch the gate (fast path). + clock.advance(2 * limiterSweepInterval) + limiter.Allow("a", 1, window) + if got := limiter.lastSweep.Load(); got != want { + t.Errorf("fast-path Allow scheduled a sweep: lastSweep moved to %d", got) + } +} + +// TestEvictionEndToEndThroughPublicAPI exercises the whole pipeline with no +// internal calls except clock injection: a stale key is evicted by the +// background sweep that a new key's creation triggers. +func TestEvictionEndToEndThroughPublicAPI(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + const limit = 2 + + // Exhaust the victim's limit. + limiter.Allow("victim", limit, window) + limiter.Allow("victim", limit, window) + if allowed, _, _ := limiter.Allow("victim", limit, window); allowed { + t.Fatal("victim should be rate limited") + } + + // Go idle long enough to be stale (2w) and to reopen the sweep gate + // (limiterSweepInterval); then a brand-new key triggers the async sweep. + advance := 2 * window + if advance <= limiterSweepInterval { + advance = limiterSweepInterval + 2*window + } + clock.advance(advance) + limiter.Allow("trigger", limit, window) + + // The sweep runs in a goroutine; wait for the victim entry to vanish. + victimKey := limiterCompositeKey("victim", limit, window) + deadline := time.Now().Add(2 * time.Second) + for { + if _, ok := limiter.limiters.Load(victimKey); !ok { + break + } + if time.Now().After(deadline) { + t.Fatal("victim entry was not evicted by the background sweep") + } + time.Sleep(time.Millisecond) + } + + // After legitimate 2w idleness the full limit is available again. + if allowed, remaining, _ := limiter.Allow("victim", limit, window); !allowed || remaining != limit-1 { + t.Errorf("victim after eviction: got (allowed=%v, remaining=%d), want (true, %d)", allowed, remaining, limit-1) + } +} + +// TestMapSizeStaysBoundedUnderKeyChurn simulates the attack the eviction +// exists for — a flood of never-repeating keys — and asserts the map returns +// to baseline once the flood ages out, rather than growing monotonically. +func TestMapSizeStaysBoundedUnderKeyChurn(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + + const perWave = 1000 + for wave := range 3 { + for i := range perWave { + limiter.Allow(fmt.Sprintf("wave%d-ip%d", wave, i), 10, window) + } + // Each wave then ages past staleness before the next arrives. + clock.advance(2 * window) + limiter.sweep(clock.now()) + + if got := limiterEntryCount(limiter); got != 0 { + t.Fatalf("after wave %d: expected 0 surviving entries, got %d (map is leaking)", wave, got) + } + } +} + +// TestSweepDoesNotEvictConcurrentlyCreatedEntry verifies that an entry created +// at the sweep's own timestamp is never considered stale (its windowStart is +// initialized at creation, so a racing sweep sees zero idle time). +func TestSweepDoesNotEvictConcurrentlyCreatedEntry(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + + limiter.Allow("fresh", 5, window) + limiter.sweep(clock.now()) + + if _, ok := limiter.limiters.Load(limiterCompositeKey("fresh", 5, window)); !ok { + t.Fatal("a just-created entry must survive a sweep at the same instant") + } +} + +// TestSweepIgnoresZeroValueEntries verifies the sweeper's defensive guards: an +// entry with no window or no windowStart (impossible via Allow, possible via +// direct construction) is left alone rather than evicted by arithmetic on +// zero values. +func TestSweepIgnoresZeroValueEntries(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + + limiter.limiters.Store("bare", &slidingWindowLimiter{}) + limiter.limiters.Store("no-start", &slidingWindowLimiter{window: time.Minute}) + limiter.limiters.Store("no-window", &slidingWindowLimiter{windowStart: clock.now().Add(-time.Hour)}) + + limiter.sweep(clock.now()) + + for _, key := range []string{"bare", "no-start", "no-window"} { + if _, ok := limiter.limiters.Load(key); !ok { + t.Errorf("zero-value entry %q must not be evicted", key) + } + } +} + +// TestNoOverAdmissionWhileSweepsRun hammers one key from many goroutines with +// the clock frozen — so nothing is ever stale and every count must be kept — +// while sweeps run continuously in the background. The total number of allowed +// requests must be exactly the limit; one extra admission would mean a sweep +// raced a count away. +func TestNoOverAdmissionWhileSweepsRun(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := time.Minute + const limit = 50 + const workers = 8 + const requestsPerWorker = 100 + + stopSweeps := make(chan struct{}) + var sweepWG sync.WaitGroup + sweepWG.Go(func() { + for { + select { + case <-stopSweeps: + return + default: + limiter.sweep(clock.now()) + } + } + }) + + var allowedTotal atomic.Int64 + var wg sync.WaitGroup + for range workers { + wg.Go(func() { + for range requestsPerWorker { + allowed, remaining, _ := limiter.Allow("contended", limit, window) + if allowed { + allowedTotal.Add(1) + } + if remaining < 0 || remaining >= limit { + t.Errorf("remaining out of range: %d", remaining) + } + } + }) + } + wg.Wait() + close(stopSweeps) + sweepWG.Wait() + + if got := allowedTotal.Load(); got != limit { + t.Errorf("expected exactly %d admissions with a frozen clock, got %d", limit, got) + } +} + +// TestEvictionUnderConcurrentChaos races Allow calls on a churning key set +// against continuous clock advancement and sweeping. It asserts liveness (no +// hang on the retry path), sane return values, and a race-free run under +// -race; the eviction/retry machinery is being exercised constantly because +// every clock jump makes existing entries stale. +func TestEvictionUnderConcurrentChaos(t *testing.T) { + clock := newFakeClock(time.Unix(1_000_000, 0)) + limiter := &UberRateLimiter{nowFunc: clock.now} + window := 10 * time.Millisecond + const limit = 5 + const workers = 8 + const iterations = 500 + + stopChaos := make(chan struct{}) + var chaosWG sync.WaitGroup + chaosWG.Go(func() { + for { + select { + case <-stopChaos: + return + default: + // Every jump staleness-es all existing entries, then the + // sweep evicts them out from under in-flight Allows. + clock.advance(2 * window) + limiter.sweep(clock.now()) + } + } + }) + + var wg sync.WaitGroup + for w := range workers { + wg.Go(func() { + for i := range iterations { + key := fmt.Sprintf("chaos-%d", (w*iterations+i)%16) + allowed, remaining, reset := limiter.Allow(key, limit, window) + if remaining < 0 || remaining >= limit { + t.Errorf("remaining out of range: %d", remaining) + } + if !allowed && (reset <= 0 || reset > window) { + t.Errorf("denied request got reset %v outside (0, %v]", reset, window) + } + } + }) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(30 * time.Second): + t.Fatal("workers hung — likely a livelock in the eviction retry path") + } + close(stopChaos) + chaosWG.Wait() +} + +// BenchmarkUberRateLimiterAllowHotPath measures the steady-state path (entry +// already exists). Eviction must not add work here: the sweep gate is only +// consulted on entry creation. +func BenchmarkUberRateLimiterAllowHotPath(b *testing.B) { + limiter := NewUberRateLimiter() + limiter.Allow("bench", 1_000_000_000, time.Minute) + + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + limiter.Allow("bench", 1_000_000_000, time.Minute) + } +} diff --git a/pkg/middleware/ratelimit_generic_additional_test.go b/pkg/middleware/ratelimit_generic_additional_test.go index ee68594..49f7492 100644 --- a/pkg/middleware/ratelimit_generic_additional_test.go +++ b/pkg/middleware/ratelimit_generic_additional_test.go @@ -9,7 +9,6 @@ import ( "github.com/Suhaibinator/SRouter/pkg/common" // Added import "github.com/Suhaibinator/SRouter/pkg/scontext" // Import scontext - "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -46,10 +45,7 @@ func TestExtractUser(t *testing.T) { } // Extract the user key - userKey, err := extractUserKey(req, config) // Use extractUserKey - if err != nil { - t.Fatalf("extractUserKey failed: %v", err) - } + userKey := extractUserKey(req, config) if userKey != "testUser" { t.Errorf("Expected user key 'testUser', got '%s'", userKey) } @@ -71,12 +67,10 @@ func TestExtractUser(t *testing.T) { }, } - // Extract the user key - expect error because UserIDToString is missing - _, err := extractUserKey(req, config) // Use extractUserKey - if err == nil { - t.Error("Expected error when UserIDToString is missing, got nil") - } else { - assert.Contains(t, err.Error(), "UserIDToString function is required") + // Extract the user key - the default conversion is used when UserIDToString is nil + userKey := extractUserKey(req, config) + if userKey != "testUser-modified" { + t.Errorf("Expected user key 'testUser-modified', got '%s'", userKey) } }) @@ -99,10 +93,7 @@ func TestExtractUser(t *testing.T) { // Extract the user key // Add UserIDToString for int config.UserIDToString = func(id int) string { return strconv.Itoa(id) } - userKey, err := extractUserKey(req, config) // Use extractUserKey - if err != nil { - t.Fatalf("extractUserKey failed: %v", err) - } + userKey := extractUserKey(req, config) if userKey != "42" { t.Errorf("Expected user key '42', got '%s'", userKey) } @@ -130,10 +121,7 @@ func TestExtractUser(t *testing.T) { // Extract the user key // Add UserIDToString for CustomID config.UserIDToString = func(id CustomID) string { return id.String() } // Use String() method - userKey, err := extractUserKey(req, config) // Use extractUserKey - if err != nil { - t.Fatalf("extractUserKey failed: %v", err) - } + userKey := extractUserKey(req, config) if userKey != "custom-id" { // Expect the result of String() t.Errorf("Expected user key 'custom-id', got '%s'", userKey) } @@ -155,10 +143,7 @@ func TestExtractUser(t *testing.T) { } // Extract the user key - extractedKey, err := extractUserKey(req, config) // Use extractUserKey - if err != nil { - t.Fatalf("extractUserKey failed: %v", err) - } + extractedKey := extractUserKey(req, config) if extractedKey != "user123-suffix" { t.Errorf("Expected user key 'user123-suffix', got '%s'", extractedKey) } @@ -172,17 +157,15 @@ func TestExtractUser(t *testing.T) { ctx := scontext.WithUserID[string, string](req.Context(), userID) // Use scontext req = req.WithContext(ctx) - // Create a config without UserIDToString function - this should error now + // Create a config without UserIDToString function - the default conversion is used config := &common.RateLimitConfig[string, string]{ // Use common.RateLimitConfig // Missing UserIDToString } - // Extract the user key - expect error - _, err := extractUserKey(req, config) // Use extractUserKey - if err == nil { - t.Error("Expected error when UserIDToString is missing, got nil") - } else { - assert.Contains(t, err.Error(), "UserIDToString function is required") + // Extract the user key - the default string conversion applies + extractedKey := extractUserKey(req, config) + if extractedKey != "user123" { + t.Errorf("Expected user key 'user123', got '%s'", extractedKey) } }) @@ -200,10 +183,7 @@ func TestExtractUser(t *testing.T) { } // Extract the user key - extractedKey, err := extractUserKey(req, config) // Use extractUserKey - if err != nil { - t.Fatalf("extractUserKey failed: %v", err) - } + extractedKey := extractUserKey(req, config) if extractedKey != "42" { t.Errorf("Expected user key '42', got '%s'", extractedKey) } @@ -223,10 +203,7 @@ func TestExtractUser(t *testing.T) { } // Extract the user key - extractedKey, err := extractUserKey(req, config) // Use extractUserKey - if err != nil { - t.Fatalf("extractUserKey failed: %v", err) - } + extractedKey := extractUserKey(req, config) if extractedKey != "true" { t.Errorf("Expected user key 'true', got '%s'", extractedKey) } @@ -243,10 +220,7 @@ func TestExtractUser(t *testing.T) { } // Extract the user key - extractedKey, err := extractUserKey(req, config) // Use extractUserKey - if err != nil { - t.Fatalf("extractUserKey failed: %v", err) - } + extractedKey := extractUserKey(req, config) if extractedKey != "" { t.Errorf("Expected empty user key, got '%s'", extractedKey) } @@ -361,7 +335,8 @@ func TestRateLimit(t *testing.T) { Limit: 10, Window: time.Duration(1000) * time.Millisecond, Strategy: common.StrategyUser, // Use common.StrategyUser - // Missing UserIDToString - this should cause an error in extractUserKey + // Missing UserIDToString - the default conversion is used, and with + // no user in context the middleware falls back to the IP strategy. } // Create the middleware @@ -385,12 +360,13 @@ func TestRateLimit(t *testing.T) { // Call the handler handler.ServeHTTP(rr, req) - // Check that the handler was NOT called and status is 500 - if handlerCalled { - t.Error("Handler was called unexpectedly") + // With no user in context, the middleware falls back to IP-based limiting + // and the request proceeds normally. + if !handlerCalled { + t.Error("Handler should have been called via IP fallback") } - if rr.Code != http.StatusInternalServerError { - t.Errorf("Expected status code %d due to missing UserIDToString, got %d", http.StatusInternalServerError, rr.Code) + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d via IP fallback, got %d", http.StatusOK, rr.Code) } }) diff --git a/pkg/middleware/ratelimit_slidingwindow_test.go b/pkg/middleware/ratelimit_slidingwindow_test.go new file mode 100644 index 0000000..c482a46 --- /dev/null +++ b/pkg/middleware/ratelimit_slidingwindow_test.go @@ -0,0 +1,97 @@ +package middleware + +import ( + "testing" + "time" +) + +// TestSlidingWindowRolloverWeightsPreviousWindow verifies the window rollover +// behavior of the sliding-window limiter: the previous window's count keeps +// counting against the limit in proportion to how much it still overlaps the +// sliding window, smoothing bursts at window boundaries. +func TestSlidingWindowRolloverWeightsPreviousWindow(t *testing.T) { + l := &slidingWindowLimiter{} + const limit = 2 + window := time.Minute + t0 := time.Unix(1_000_000, 0) + + // Use up the full limit inside the first window. + if allowed, remaining, _ := l.allow(limit, window, t0); !allowed || remaining != 1 { + t.Fatalf("first request: got (allowed=%v, remaining=%d), want (true, 1)", allowed, remaining) + } + if allowed, remaining, _ := l.allow(limit, window, t0.Add(time.Second)); !allowed || remaining != 0 { + t.Fatalf("second request: got (allowed=%v, remaining=%d), want (true, 0)", allowed, remaining) + } + + // Over the limit: denied, and reset reports the time left in the window. + allowed, _, reset := l.allow(limit, window, t0.Add(10*time.Second)) + if allowed { + t.Fatal("third request inside the window should be denied") + } + if reset != 50*time.Second { + t.Errorf("expected reset of 50s (window remainder), got %v", reset) + } + + // Exactly at the window boundary the previous window still fully overlaps + // the sliding window, so a burst right at the boundary is still denied. + allowed, _, reset = l.allow(limit, window, t0.Add(window)) + if allowed { + t.Fatal("request at the window boundary should still be denied (previous window fully weighted)") + } + if reset != window { + t.Errorf("expected reset of one full window, got %v", reset) + } + + // Halfway through the second window only half of the previous window's + // 2 requests count (estimated usage 1 of 2), so one request is allowed. + if allowed, remaining, _ := l.allow(limit, window, t0.Add(90*time.Second)); !allowed || remaining != 0 { + t.Fatalf("mid-second-window request: got (allowed=%v, remaining=%d), want (true, 0)", allowed, remaining) + } + + // After a gap of two or more full windows all history is discarded and the + // full limit is available again. + if allowed, remaining, _ := l.allow(limit, window, t0.Add(4*window)); !allowed || remaining != 1 { + t.Fatalf("request after long idle gap: got (allowed=%v, remaining=%d), want (true, 1)", allowed, remaining) + } +} + +// TestUberRateLimiterAllowNonPositiveLimit verifies that a zero or negative +// limit denies every request immediately and reports the window as the reset. +func TestUberRateLimiterAllowNonPositiveLimit(t *testing.T) { + limiter := NewUberRateLimiter() + + for _, limit := range []int{0, -1} { + allowed, remaining, reset := limiter.Allow("nonpositive", limit, time.Minute) + if allowed { + t.Errorf("limit %d: request should be denied", limit) + } + if remaining != 0 { + t.Errorf("limit %d: expected 0 remaining, got %d", limit, remaining) + } + if reset != time.Minute { + t.Errorf("limit %d: expected reset equal to window, got %v", limit, reset) + } + } +} + +// TestUberRateLimiterAllowNonPositiveWindowDefaultsToOneSecond verifies that a +// zero (or negative) window falls back to a one-second window instead of +// producing a degenerate limiter. +func TestUberRateLimiterAllowNonPositiveWindowDefaultsToOneSecond(t *testing.T) { + limiter := NewUberRateLimiter() + + allowed, remaining, _ := limiter.Allow("zero-window", 1, 0) + if !allowed || remaining != 0 { + t.Fatalf("first request: got (allowed=%v, remaining=%d), want (true, 0)", allowed, remaining) + } + + // The second request must be denied, and the reset must reflect the + // defaulted one-second window. + allowed, _, reset := limiter.Allow("zero-window", 1, 0) + if allowed { + t.Fatal("second request should be denied within the defaulted window") + } + if reset <= 0 || reset > time.Second { + t.Errorf("expected reset within (0, 1s] for the defaulted window, got %v", reset) + } +} diff --git a/pkg/middleware/ratelimit_test.go b/pkg/middleware/ratelimit_test.go index e1de4f6..02fab04 100644 --- a/pkg/middleware/ratelimit_test.go +++ b/pkg/middleware/ratelimit_test.go @@ -33,14 +33,9 @@ func TestUberRateLimiter(t *testing.T) { t.Errorf("Expected remaining to be positive, got %d", remaining) } - // Test that the limiter is reusing the same limiter for the same key and rps + // Test that the limiter is reusing the same limiter for the same key/limit/window limiter.Allow(key, limit, window) - // Calculate expected RPS and composite key - rps1 := int(float64(limit) / window.Seconds()) - if rps1 < 1 { - rps1 = 1 - } - compositeKey1 := fmt.Sprintf("%s-%d", key, rps1) + compositeKey1 := fmt.Sprintf("%s|%d|%d", key, limit, int64(window)) _, exists := limiter.limiters.Load(compositeKey1) if !exists { t.Errorf("Expected limiter to be stored for composite key %s", compositeKey1) @@ -56,13 +51,8 @@ func TestUberRateLimiter(t *testing.T) { t.Errorf("Expected remaining to be positive, got %d", remaining) } - // Test that the limiter is storing different limiters for different keys (with the same rps) - // Calculate expected RPS and composite key for the other key - rps2 := int(float64(limit) / window.Seconds()) // Same limit/window as first test - if rps2 < 1 { - rps2 = 1 - } - compositeKey2 := fmt.Sprintf("%s-%d", otherKey, rps2) + // Test that the limiter is storing different limiters for different keys (with the same limit/window) + compositeKey2 := fmt.Sprintf("%s|%d|%d", otherKey, limit, int64(window)) _, exists = limiter.limiters.Load(compositeKey2) if !exists { t.Errorf("Expected limiter to be stored for composite key %s", compositeKey2) @@ -79,12 +69,8 @@ func TestUberRateLimiter(t *testing.T) { t.Errorf("Expected remaining to be positive, got %d", remaining) } - // Also test that the new limiter with different rps is stored - rps3 := int(float64(differentLimit) / differentWindow.Seconds()) - if rps3 < 1 { - rps3 = 1 - } - compositeKey3 := fmt.Sprintf("%s-%d", key, rps3) + // Also test that the new limiter with a different limit/window is stored + compositeKey3 := fmt.Sprintf("%s|%d|%d", key, differentLimit, int64(differentWindow)) _, exists = limiter.limiters.Load(compositeKey3) if !exists { t.Errorf("Expected limiter to be stored for composite key %s (different limit/window)", compositeKey3) @@ -187,7 +173,13 @@ func TestConvertUserIDToString(t *testing.T) { t.Errorf("Expected string to be true, got %s", str) } - // Test with a custom type that implements String() + // Test with a custom type that implements fmt.Stringer + str = convertUserIDToString(CustomID{id: "custom-id"}) + if str != "custom-id" { + t.Errorf("Expected string to be custom-id, got %s", str) + } + + // Test with a custom type without a String method (fmt.Sprint fallback) type CustomType struct{} str = convertUserIDToString(CustomType{}) if str != "{}" { @@ -828,3 +820,93 @@ func TestRateLimitMiddlewareDefaultStrategy(t *testing.T) { t.Errorf("Expected Retry-After header to be set") } } + +// TestUberRateLimiter_WindowSemanticsAndNonBlocking verifies the limiter +// honors the configured limit/window exactly (e.g. "2 per minute" is not +// inflated to 1/sec) and denies over-limit requests immediately instead of +// sleeping (regression for the blocking leaky-bucket implementation). +func TestUberRateLimiter_WindowSemanticsAndNonBlocking(t *testing.T) { + limiter := NewUberRateLimiter() + key := "window-semantics" + limit := 2 + window := time.Minute + + start := time.Now() + + // The first `limit` requests are allowed. + for i := 0; i < limit; i++ { + allowed, _, _ := limiter.Allow(key, limit, window) + if !allowed { + t.Fatalf("request %d should be allowed (limit %d per %v)", i+1, limit, window) + } + } + + // Subsequent requests within the window are denied immediately. + for i := 0; i < 5; i++ { + allowed, remaining, reset := limiter.Allow(key, limit, window) + if allowed { + t.Fatalf("request beyond limit should be denied (window semantics: %d per %v)", limit, window) + } + if remaining != 0 { + t.Errorf("expected 0 remaining when denied, got %d", remaining) + } + if reset <= 0 || reset > window { + t.Errorf("expected reset in (0, %v], got %v", window, reset) + } + } + + // The whole sequence must not have blocked (old implementation slept until + // the next leaky-bucket slot before denying). + if elapsed := time.Since(start); elapsed > 100*time.Millisecond { + t.Fatalf("Allow calls blocked for %v; rate limit checks must be non-blocking", elapsed) + } +} + +// TestRateLimitDefaultStrategyUsesContextIP verifies that an unknown strategy +// falls back to IP-based limiting using the client IP from the context when +// it is available (rather than RemoteAddr). +func TestRateLimitDefaultStrategyUsesContextIP(t *testing.T) { + limiter := &captureLimiter{} + config := &common.RateLimitConfig[string, any]{ + BucketName: "test-bucket", + Limit: 1, + Window: time.Second, + Strategy: common.RateLimitStrategy(99), // Unknown strategy triggers the default case + } + + middleware := RateLimit(config, limiter, zap.NewNop()) + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.RemoteAddr = "10.0.0.1:1234" // ensure remote addr differs from context IP + ctx := scontext.WithClientIP[string, any](req.Context(), "203.0.113.9") + + handler.ServeHTTP(httptest.NewRecorder(), req.WithContext(ctx)) + + expectedKey := "test-bucket:203.0.113.9" + if limiter.lastKey != expectedKey { + t.Fatalf("expected limiter key %s, got %s", expectedKey, limiter.lastKey) + } +} + +// TestExtractUserKeyUserObjectWithoutUserIDFromUser verifies that when a user +// object is in the context but no UserIDFromUser function is configured, the +// key falls back to the user ID from the context. +func TestExtractUserKeyUserObjectWithoutUserIDFromUser(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + user := "userObject" + ctx := scontext.WithUser[string](req.Context(), &user) + ctx = scontext.WithUserID[string, string](ctx, "id-from-context") + req = req.WithContext(ctx) + + config := &common.RateLimitConfig[string, string]{ + Strategy: common.StrategyUser, + // No UserIDFromUser: the user object alone cannot produce a key. + } + + if key := extractUserKey(req, config); key != "id-from-context" { + t.Fatalf("expected fallback to user ID from context, got %q", key) + } +} diff --git a/pkg/middleware/timeout_writer_test.go b/pkg/middleware/timeout_writer_test.go new file mode 100644 index 0000000..7487113 --- /dev/null +++ b/pkg/middleware/timeout_writer_test.go @@ -0,0 +1,142 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +// TestTimeoutDoesNotOverwriteStartedResponse verifies that when the handler +// has already started writing a response before the timeout fires, the +// middleware does not write a 408 on top of it: the client receives the +// handler's partial response untouched. +func TestTimeoutDoesNotOverwriteStartedResponse(t *testing.T) { + wrote := make(chan struct{}) + release := make(chan struct{}) + handlerDone := make(chan struct{}) + + handler := timeout(30 * time.Millisecond)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(handlerDone) + w.WriteHeader(http.StatusAccepted) + _, _ = w.Write([]byte("partial")) + close(wrote) + <-release // Keep running past the timeout. + })) + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil)) + + // The middleware returned at the timeout. Synchronize with the handler's + // write, after which it stays blocked, so the recorder is safe to inspect. + <-wrote + if rec.Code != http.StatusAccepted { + t.Errorf("status = %d, want %d (handler's status must not be replaced by a timeout response)", rec.Code, http.StatusAccepted) + } + if got := rec.Body.String(); got != "partial" { + t.Errorf("body = %q, want %q (no timeout error may be appended)", got, "partial") + } + + close(release) + <-handlerDone +} + +// TestMutexResponseWriterRejectsAllWritesAfterTimeout verifies that once the +// timeout has fired, a late handler can no longer touch the underlying +// response writer: WriteHeader is dropped, Write fails with +// http.ErrHandlerTimeout, and Flush is a no-op. +func TestMutexResponseWriterRejectsAllWritesAfterTimeout(t *testing.T) { + rec := httptest.NewRecorder() + var mu sync.Mutex + rw := &mutexResponseWriter{ResponseWriter: rec, mu: &mu} + rw.timedOut.Store(true) + + rw.WriteHeader(http.StatusTeapot) + if rec.Code != http.StatusOK { + t.Errorf("WriteHeader after timeout reached the underlying writer: code = %d", rec.Code) + } + if rw.wroteHeader.Load() { + t.Error("WriteHeader after timeout must not mark the header as written") + } + + n, err := rw.Write([]byte("late")) + if n != 0 || err != http.ErrHandlerTimeout { + t.Errorf("Write after timeout = (%d, %v), want (0, http.ErrHandlerTimeout)", n, err) + } + if rec.Body.Len() != 0 { + t.Errorf("Write after timeout reached the underlying writer: body = %q", rec.Body.String()) + } + + rw.Flush() + if rec.Flushed { + t.Error("Flush after timeout must not flush the underlying writer") + } +} + +// TestMutexResponseWriterWriteRecheckUnderLock verifies the race window where +// a handler Write passes the initial timeout check but the timeout response is +// written while the handler is waiting for the mutex: the write must be +// rejected by the re-check under the lock instead of corrupting the response. +func TestMutexResponseWriterWriteRecheckUnderLock(t *testing.T) { + rec := httptest.NewRecorder() + var mu sync.Mutex + rw := &mutexResponseWriter{ResponseWriter: rec, mu: &mu} + + // Hold the lock as the timeout path does while writing its response. + mu.Lock() + writeErr := make(chan error) + go func() { + _, err := rw.Write([]byte("late")) + writeErr <- err + }() + + // Let the handler write pass the initial check and block on the mutex, + // then mark the timeout before releasing the lock. + time.Sleep(50 * time.Millisecond) + rw.timedOut.Store(true) + mu.Unlock() + + if err := <-writeErr; err != http.ErrHandlerTimeout { + t.Errorf("late Write = %v, want http.ErrHandlerTimeout", err) + } + if rec.Body.Len() != 0 { + t.Errorf("late Write reached the underlying writer: body = %q", rec.Body.String()) + } +} + +// TestMutexResponseWriterMarkTimedOut verifies the timed-out transition the +// timeout middleware performs when the deadline fires: it must reject further +// handler writes, and it may only claim the response if no handler write got +// there first. +func TestMutexResponseWriterMarkTimedOut(t *testing.T) { + t.Run("claims an unwritten response", func(t *testing.T) { + var mu sync.Mutex + rw := &mutexResponseWriter{ResponseWriter: httptest.NewRecorder(), mu: &mu} + + if !rw.markTimedOut() { + t.Error("markTimedOut must claim the response when nothing was written") + } + if !rw.timedOut.Load() { + t.Error("markTimedOut must put the writer into the timed-out state") + } + if n, err := rw.Write([]byte("late")); n != 0 || err != http.ErrHandlerTimeout { + t.Errorf("Write after markTimedOut = (%d, %v), want (0, http.ErrHandlerTimeout)", n, err) + } + }) + + t.Run("does not claim a response a handler write got to first", func(t *testing.T) { + var mu sync.Mutex + rw := &mutexResponseWriter{ResponseWriter: httptest.NewRecorder(), mu: &mu} + // A handler write claimed the response between the middleware's last + // wroteHeader check and the timed-out transition. + rw.wroteHeader.Store(true) + + if rw.markTimedOut() { + t.Error("markTimedOut must not claim a response the handler already started") + } + if !rw.timedOut.Load() { + t.Error("markTimedOut must put the writer into the timed-out state even when it cannot claim the response") + } + }) +} diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index d19c6c2..a03a445 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "net/http" "sync" - "time" "github.com/Suhaibinator/SRouter/pkg/common" "github.com/Suhaibinator/SRouter/pkg/scontext" // Added import @@ -46,71 +45,17 @@ func (g *IDGenerator) init() { g.idChan <- generateUUID() } - // Then start the background worker to keep it filled + // Then start the background worker to keep it filled. A blocking send + // parks the goroutine for free while the buffer is full (no polling) + // and resumes the instant a consumer takes an ID. go func() { - // Pre-allocate a batch of UUIDs to insert quickly when needed - const batchSize = 1000 - batchUUIDs := make([]string, 0, batchSize) - - // Used to determine if we need to batch-fill when channel is getting empty - lastChannelLen := g.size - emptyThreshold := g.size / 10 // 10% capacity threshold to trigger batch fill - for { select { case <-g.stop: return - default: - } - - // Get current channel capacity - currentLen := len(g.idChan) - - // If the channel is getting depleted quickly (below threshold), - // batch-fill it immediately with multiple UUIDs - if currentLen < emptyThreshold && lastChannelLen > currentLen { - // Channel is being consumed quickly, pre-generate a batch - if len(batchUUIDs) == 0 { - // Refill our batch - batchUUIDs = batchUUIDs[:0] // Clear without deallocating - for range batchSize { - batchUUIDs = append(batchUUIDs, generateUUID()) - } - } - - // Add from our batch as many as we can without blocking - for len(batchUUIDs) > 0 { - select { - case <-g.stop: - return - case g.idChan <- batchUUIDs[0]: - // Successfully added one from batch - batchUUIDs = batchUUIDs[1:] - default: - // Channel is now full, stop adding - } - if len(g.idChan) == g.size { - break - } - } - - // Very short sleep to prevent CPU thrashing but still be responsive - time.Sleep(100 * time.Microsecond) // 100μs instead of 10ms - } else { - // Normal case: channel has plenty of capacity, add one at a time - select { - case <-g.stop: - return - case g.idChan <- generateUUID(): - // Successfully added a new UUID - default: - // Channel is full, sleep longer to save CPU - time.Sleep(1 * time.Millisecond) // 1ms instead of 10ms - } + case g.idChan <- generateUUID(): + // Successfully added a new UUID; loop to top up again. } - - // Update our last seen channel length - lastChannelLen = currentLen } }() }) @@ -160,26 +105,60 @@ func (g *IDGenerator) GetIDNonBlocking() string { // Note: WithTraceID, GetTraceIDFromContext, GetTraceID, AddTraceIDToRequest were moved to pkg/scontext/context.go +// maxTraceIDLength bounds inbound X-Trace-ID values; generated IDs are 32 hex +// characters, and common formats (UUID, W3C trace IDs) fit comfortably below this. +const maxTraceIDLength = 64 + +// isValidTraceID reports whether an inbound X-Trace-ID header value is safe to +// propagate into logs and response headers. Only ASCII alphanumerics, '-', and +// '_' are allowed, with a bounded length, so clients can't inject log content +// or oversized values. +func isValidTraceID(id string) bool { + if id == "" || len(id) > maxTraceIDLength { + return false + } + for i := 0; i < len(id); i++ { + c := id[i] + switch { + case c >= '0' && c <= '9': + case c >= 'a' && c <= 'z': + case c >= 'A' && c <= 'Z': + case c == '-' || c == '_': + default: + return false + } + } + return true +} + // CreateTraceMiddleware creates a trace middleware with the provided ID generator. // This is the core implementation used by both traceMiddleware and traceMiddlewareWithConfig. // It checks for an existing trace ID in the request headers before generating a new one, -// which allows for trace ID propagation across service calls. +// which allows for trace ID propagation across service calls. Client-supplied +// trace IDs are validated (bounded length, [A-Za-z0-9_-] only) before being +// accepted; invalid values are replaced with a generated ID. // It's now generic to accept the UserID (T) and User (U) types from the router. func CreateTraceMiddleware[T comparable, U any](generator *IDGenerator) common.Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var traceID string - // Check if there's already a trace ID in the request headers + // Check if there's already a valid trace ID in the request headers existingTraceID := r.Header.Get("X-Trace-ID") - if existingTraceID != "" { + if isValidTraceID(existingTraceID) { // Use the existing trace ID for propagation traceID = existingTraceID } else { - // Generate a new trace ID if none exists + // Generate a new trace ID if none exists (or it was invalid) traceID = generator.GetIDNonBlocking() } + // When an SRouterContext already exists (always the case inside the + // router, which installs it before dispatch), WithTraceID mutates it + // in place and returns the same context, so cloning the request is + // unnecessary. + _, hadSRouterCtx := scontext.GetSRouterContext[T, U](r.Context()) + // Add the trace ID to the request context using the correct generic types ctx := scontext.WithTraceID[T, U](r.Context(), traceID) // Use scontext with router's T and U @@ -187,7 +166,10 @@ func CreateTraceMiddleware[T comparable, U any](generator *IDGenerator) common.M w.Header().Set("X-Trace-ID", traceID) // Call the next handler with the request containing the updated context - next.ServeHTTP(w, r.WithContext(ctx)) + if !hadSRouterCtx { + r = r.WithContext(ctx) + } + next.ServeHTTP(w, r) }) } } diff --git a/pkg/middleware/trace_validation_test.go b/pkg/middleware/trace_validation_test.go new file mode 100644 index 0000000..8b6d531 --- /dev/null +++ b/pkg/middleware/trace_validation_test.go @@ -0,0 +1,113 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Suhaibinator/SRouter/pkg/scontext" +) + +// TestIsValidTraceID exercises the character-class and length validation for +// inbound X-Trace-ID values. +func TestIsValidTraceID(t *testing.T) { + tests := []struct { + name string + id string + want bool + }{ + {"empty", "", false}, + {"digits only", "0123456789", true}, + {"lowercase letters", "abcxyz", true}, + {"uppercase letters", "ABCXYZ", true}, + {"dash and underscore", "trace-id_1", true}, + {"mixed valid", "REQ-123_abcXYZ", true}, + {"max length (64)", strings.Repeat("a", 64), true}, + {"too long (65)", strings.Repeat("a", 65), false}, + {"disallowed punctuation", "abc!123", false}, + {"embedded newline", "abc\n123", false}, + {"embedded space", "abc 123", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := isValidTraceID(tc.id); got != tc.want { + t.Errorf("isValidTraceID(%q) = %v, want %v", tc.id, got, tc.want) + } + }) + } +} + +// TestTraceMiddlewarePropagatesValidInboundTraceID verifies that a valid +// client-supplied X-Trace-ID is propagated unchanged to both the request +// context and the response header. +func TestTraceMiddlewarePropagatesValidInboundTraceID(t *testing.T) { + generator := NewIDGenerator(4) + defer generator.Stop() + + const inbound = "REQ-123_abcXYZ" + var handlerTraceID string + handler := CreateTraceMiddleware[string, any](generator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerTraceID = scontext.GetTraceIDFromRequest[string, any](r) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Trace-ID", inbound) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if handlerTraceID != inbound { + t.Errorf("handler saw trace ID %q, want propagated inbound ID %q", handlerTraceID, inbound) + } + if got := rec.Header().Get("X-Trace-ID"); got != inbound { + t.Errorf("response X-Trace-ID = %q, want propagated inbound ID %q", got, inbound) + } +} + +// TestTraceMiddlewareReplacesInvalidInboundTraceID verifies that invalid +// client-supplied trace IDs (oversized or containing unsafe characters) are +// not propagated: the middleware substitutes a freshly generated, valid ID. +func TestTraceMiddlewareReplacesInvalidInboundTraceID(t *testing.T) { + generator := NewIDGenerator(4) + defer generator.Stop() + + tests := []struct { + name string + inbound string + }{ + {"oversized", strings.Repeat("a", 65)}, + {"log injection attempt", "abc\nINJECTED"}, + {"disallowed characters", "abc!123"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var handlerTraceID string + handler := CreateTraceMiddleware[string, any](generator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerTraceID = scontext.GetTraceIDFromRequest[string, any](r) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Trace-ID", tc.inbound) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + got := rec.Header().Get("X-Trace-ID") + if got == tc.inbound { + t.Fatalf("invalid inbound trace ID %q was propagated; it should have been replaced", tc.inbound) + } + if got == "" { + t.Fatal("expected a generated trace ID in the response header, got empty") + } + if !isValidTraceID(got) { + t.Errorf("generated replacement trace ID %q is not itself valid", got) + } + if handlerTraceID != got { + t.Errorf("handler saw trace ID %q but response header has %q; they must match", handlerTraceID, got) + } + }) + } +} diff --git a/pkg/router/benchmark_test.go b/pkg/router/benchmark_test.go index 460513d..6f58de7 100644 --- a/pkg/router/benchmark_test.go +++ b/pkg/router/benchmark_test.go @@ -382,3 +382,36 @@ func BenchmarkMemoryUsage(b *testing.B) { // b.Logf("Final Sys = %v MiB", bToMb(mEnd.Sys)) // b.Logf("GC Runs = %v", mEnd.NumGC - mStart.NumGC) } + +// BenchmarkInstrumentedAuthRoute measures the fully-instrumented hot path: +// trace ID injection, request summary logging (status/bytes capture via the +// pooled metrics writer), required authentication, and X-Forwarded-For client +// IP extraction. A no-op logger is used so the benchmark measures router +// overhead rather than log encoding. +func BenchmarkInstrumentedAuthRoute(b *testing.B) { + authLevel := AuthRequired + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + TraceIDBufferSize: 1000, + IPConfig: DefaultIPConfig(), + }, nopAuthFunc, userIDFromString) + r.RegisterRoute(RouteConfigBase{ + Path: "/secure", + Methods: []HttpMethod{MethodGet}, + AuthLevel: &authLevel, + Handler: simpleHandler, + }) + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req, _ := http.NewRequest(http.MethodGet, "/secure", nil) + req.Header.Set("Authorization", "Bearer token") + req.Header.Set("X-Forwarded-For", "203.0.113.7, 10.0.0.1") + req.RemoteAddr = "10.0.0.1:1234" + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + } + }) +} diff --git a/pkg/router/bug_regression_test.go b/pkg/router/bug_regression_test.go new file mode 100644 index 0000000..a7374d3 --- /dev/null +++ b/pkg/router/bug_regression_test.go @@ -0,0 +1,214 @@ +package router + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/codec" + "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "go.uber.org/zap" +) + +// countingMiddleware returns a middleware that increments counter on every request. +func countingMiddleware(counter *atomic.Int64) common.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + counter.Add(1) + next.ServeHTTP(w, req) + }) + } +} + +// Regression test for BUGS.md #1: StrategyUser rate limits configured via +// RouteOverrides must work (previously every request returned 500 because the +// user ID conversion functions were dropped in the [any, any] -> [T, U] +// conversion). +func TestStrategyUserRateLimitViaOverrides(t *testing.T) { + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + SubRouters: []SubRouterConfig{{ + PathPrefix: "/api", + Routes: []RouteDefinition{ + RouteConfigBase{ + Path: "/limited", + Methods: []HttpMethod{MethodGet}, + Overrides: common.RouteOverrides{ + RateLimit: &common.RateLimitConfig[any, any]{ + BucketName: "user-bucket", + Limit: 100, + Window: time.Minute, + Strategy: common.StrategyUser, + }, + }, + Handler: func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }, + }, + }, + }}, + }, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "/api/limited", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d for user-strategy rate limited route, got %d (body: %s)", + http.StatusOK, rr.Code, rr.Body.String()) + } +} + +// Regression test for BUGS.md #2: RegisterGenericRouteOnSubRouter must not +// apply global middlewares twice. +func TestRegisterGenericRouteOnSubRouterAppliesGlobalsOnce(t *testing.T) { + var globalCount atomic.Int64 + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + Middlewares: []common.Middleware{countingMiddleware(&globalCount)}, + SubRouters: []SubRouterConfig{{PathPrefix: "/api"}}, + }, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + err := RegisterGenericRouteOnSubRouter(r, "/api", RouteConfig[map[string]string, map[string]string]{ + Path: "/echo", + Methods: []HttpMethod{MethodGet}, + Codec: codec.NewJSONCodec[map[string]string, map[string]string](), + Handler: func(req *http.Request, data map[string]string) (map[string]string, error) { + return map[string]string{"ok": "true"}, nil + }, + SourceType: Empty, + Sanitizer: func(d map[string]string) (map[string]string, error) { return d, nil }, + }) + if err != nil { + t.Fatalf("RegisterGenericRouteOnSubRouter failed: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/echo", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d (body: %s)", http.StatusOK, rr.Code, rr.Body.String()) + } + if got := globalCount.Load(); got != 1 { + t.Fatalf("expected global middleware to run exactly once per request, ran %d times", got) + } +} + +// Regression test for BUGS.md #3: nested sub-routers must inherit the parent +// sub-router's middlewares (and AuthLevel) as documented. +func TestNestedSubRouterInheritsParentMiddlewares(t *testing.T) { + var parentCount, childCount atomic.Int64 + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + SubRouters: []SubRouterConfig{{ + PathPrefix: "/api", + Middlewares: []common.Middleware{countingMiddleware(&parentCount)}, + SubRouters: []SubRouterConfig{{ + PathPrefix: "/v1", + Middlewares: []common.Middleware{countingMiddleware(&childCount)}, + Routes: []RouteDefinition{ + RouteConfigBase{ + Path: "/ping", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }, + }, + }, + }}, + }}, + }, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/ping", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) + } + if got := parentCount.Load(); got != 1 { + t.Fatalf("expected parent sub-router middleware to run once for nested route, ran %d times", got) + } + if got := childCount.Load(); got != 1 { + t.Fatalf("expected nested sub-router middleware to run once, ran %d times", got) + } +} + +// Regression test for BUGS.md #3 (AuthLevel part): a nested sub-router without +// its own AuthLevel inherits the parent's. +func TestNestedSubRouterInheritsParentAuthLevel(t *testing.T) { + authRequired := AuthRequired + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + SubRouters: []SubRouterConfig{{ + PathPrefix: "/api", + AuthLevel: &authRequired, + SubRouters: []SubRouterConfig{{ + PathPrefix: "/v1", + Routes: []RouteDefinition{ + RouteConfigBase{ + Path: "/secret", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }, + }, + }, + }}, + }}, + }, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + // Without credentials: must be rejected because AuthRequired is inherited. + req := httptest.NewRequest(http.MethodGet, "/api/v1/secret", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d for unauthenticated request to nested auth-required route, got %d", + http.StatusUnauthorized, rr.Code) + } + + // With valid credentials: allowed. + req = httptest.NewRequest(http.MethodGet, "/api/v1/secret", nil) + req.Header.Set("Authorization", "Bearer valid-token") + rr = httptest.NewRecorder() + r.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d for authenticated request, got %d", http.StatusOK, rr.Code) + } +} + +// Regression test for BUGS.md #15: sub-routers added via RegisterSubRouter +// after router creation must be discoverable by RegisterGenericRouteOnSubRouter. +func TestRegisterGenericRouteOnDynamicallyAddedSubRouter(t *testing.T) { + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + r.RegisterSubRouter(SubRouterConfig{PathPrefix: "/late"}) + + err := RegisterGenericRouteOnSubRouter(r, "/late", RouteConfig[map[string]string, map[string]string]{ + Path: "/route", + Methods: []HttpMethod{MethodGet}, + Codec: codec.NewJSONCodec[map[string]string, map[string]string](), + Handler: func(req *http.Request, data map[string]string) (map[string]string, error) { + return map[string]string{"ok": "true"}, nil + }, + SourceType: Empty, + Sanitizer: func(d map[string]string) (map[string]string, error) { return d, nil }, + }) + if err != nil { + t.Fatalf("expected dynamically added sub-router to be found, got error: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/late/route", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) + } +} diff --git a/pkg/router/config.go b/pkg/router/config.go index 2aed700..c34cefd 100644 --- a/pkg/router/config.go +++ b/pkg/router/config.go @@ -99,16 +99,18 @@ type MetricsConfig struct { // not implement metrics.MetricsRegistry), no metrics middleware is installed. Collector any // metrics.MetricsRegistry - // MiddlewareFactory is reserved for supplying a custom metrics middleware - // factory. It is currently not consumed by the router: when Collector - // implements metrics.MetricsRegistry, the router builds its own metrics - // middleware from it. + // MiddlewareFactory optionally supplies a custom metrics middleware. If it + // implements metrics.MetricsMiddleware[T, U] (with the router's type + // parameters), it takes precedence over Collector and is used to wrap all + // requests. Otherwise the router builds its own middleware from Collector. MiddlewareFactory any // metrics.MetricsMiddleware - // Namespace for metrics. + // Namespace for metrics. Applied as the "service" tag on all metrics + // emitted by the built-in metrics middleware. Namespace string - // Subsystem for metrics. + // Subsystem for metrics. Applied as the "subsystem" tag on all metrics + // emitted by the built-in metrics middleware. Subsystem string // EnableLatency enables latency metrics. @@ -134,7 +136,7 @@ type RouterConfig struct { GlobalRateLimit *common.RateLimitConfig[any, any] // Use common.RateLimitConfig // Default rate limit for all routes GlobalAuthToken *common.AuthTokenConfig // Default auth token source for built-in auth middleware IPConfig *IPConfig // Configuration for client IP extraction - EnableTraceLogging bool // Enable trace logging + EnableTraceLogging bool // Enable per-request summary logging even when TraceIDBufferSize is 0 TraceLoggingUseInfo bool // Use Info level for trace logging TraceIDBufferSize int // Buffer size for trace ID generator (0 disables trace ID) MetricsConfig *MetricsConfig // Metrics configuration (optional) diff --git a/pkg/router/e2e_test.go b/pkg/router/e2e_test.go new file mode 100644 index 0000000..7650f29 --- /dev/null +++ b/pkg/router/e2e_test.go @@ -0,0 +1,863 @@ +package router + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/codec" + "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/Suhaibinator/SRouter/pkg/scontext" + "go.uber.org/zap" +) + +// End-to-end tests in this file exercise the full stack over a real TCP +// connection: a router served by httptest.NewServer and driven by a real +// http.Client, rather than calling ServeHTTP with a recorder directly. + +type e2eUser struct { + ID string + Name string +} + +type e2eCreateUserRequest struct { + Name string `json:"name"` + Email string `json:"email"` +} + +type e2eCreateUserResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Email string `json:"email"` +} + +// newE2EAuthFunctions returns auth functions that accept the tokens in the +// given map (token -> user name) and reject everything else. +func newE2EAuthFunctions(tokens map[string]string) (func(context.Context, string) (*e2eUser, bool), func(*e2eUser) string) { + authFunc := func(ctx context.Context, token string) (*e2eUser, bool) { + name, ok := tokens[token] + if !ok { + return nil, false + } + return &e2eUser{ID: "id-" + name, Name: name}, true + } + userIDFunc := func(u *e2eUser) string { + if u == nil { + return "" + } + return u.ID + } + return authFunc, userIDFunc +} + +// TestE2EFullStackAPI runs a complete API server over a real HTTP connection: +// trace IDs, global and sub-router middleware, a declarative generic JSON +// route, bearer-token authentication, and standard routing behavior +// (404/405) all working together. +func TestE2EFullStackAPI(t *testing.T) { + authFunc, userIDFunc := newE2EAuthFunctions(map[string]string{"token-alice": "alice"}) + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + GlobalTimeout: 2 * time.Second, + GlobalMaxBodySize: 1 << 20, + TraceIDBufferSize: 10, + AddUserObjectToCtx: true, + Middlewares: []common.Middleware{ + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("X-Global-Middleware", "applied") + next.ServeHTTP(w, req) + }) + }, + }, + SubRouters: []SubRouterConfig{ + { + PathPrefix: "/api/v1", + Middlewares: []common.Middleware{ + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("X-API-Version", "v1") + next.ServeHTTP(w, req) + }) + }, + }, + Routes: []RouteDefinition{ + RouteConfigBase{ + Path: "/health", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"ok"}`)) + }, + }, + RouteConfigBase{ + Path: "/me", + Methods: []HttpMethod{MethodGet}, + AuthLevel: new(AuthRequired), + Handler: func(w http.ResponseWriter, req *http.Request) { + userID, ok := scontext.GetUserIDFromRequest[string, e2eUser](req) + if !ok { + http.Error(w, "no user in context", http.StatusInternalServerError) + return + } + user, ok := scontext.GetUserFromRequest[string, e2eUser](req) + if !ok || user == nil { + http.Error(w, "no user object in context", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{"id":%q,"name":%q}`, userID, user.Name) + }, + }, + NewGenericRouteDefinition[e2eCreateUserRequest, e2eCreateUserResponse, string, e2eUser]( + RouteConfig[e2eCreateUserRequest, e2eCreateUserResponse]{ + Path: "/users", + Methods: []HttpMethod{MethodPost}, + Codec: codec.NewJSONCodec[e2eCreateUserRequest, e2eCreateUserResponse](), + Sanitizer: func(req e2eCreateUserRequest) (e2eCreateUserRequest, error) { + req.Name = strings.TrimSpace(req.Name) + return req, nil + }, + Handler: func(req *http.Request, data e2eCreateUserRequest) (e2eCreateUserResponse, error) { + return e2eCreateUserResponse{ + ID: "user-1", + Name: data.Name, + Email: data.Email, + }, nil + }, + }, + ), + }, + }, + }, + }, authFunc, userIDFunc) + + server := httptest.NewServer(r) + defer server.Close() + client := server.Client() + + t.Run("public route with middleware and trace ID", func(t *testing.T) { + resp, err := client.Get(server.URL + "/api/v1/health") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != `{"status":"ok"}` { + t.Errorf("unexpected body: %s", body) + } + if resp.Header.Get("X-Global-Middleware") != "applied" { + t.Errorf("expected global middleware header, got %q", resp.Header.Get("X-Global-Middleware")) + } + if resp.Header.Get("X-API-Version") != "v1" { + t.Errorf("expected sub-router middleware header, got %q", resp.Header.Get("X-API-Version")) + } + if resp.Header.Get("X-Trace-ID") == "" { + t.Error("expected X-Trace-ID response header to be set") + } + }) + + t.Run("generic JSON route round-trip", func(t *testing.T) { + reqBody := `{"name":" Bob ","email":"bob@example.com"}` + resp, err := client.Post(server.URL+"/api/v1/users", "application/json", strings.NewReader(reqBody)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + var created e2eCreateUserResponse + if err := json.NewDecoder(resp.Body).Decode(&created); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if created.ID != "user-1" { + t.Errorf("expected ID %q, got %q", "user-1", created.ID) + } + if created.Name != "Bob" { + t.Errorf("expected sanitized name %q, got %q", "Bob", created.Name) + } + if created.Email != "bob@example.com" { + t.Errorf("expected email %q, got %q", "bob@example.com", created.Email) + } + }) + + t.Run("generic route rejects malformed JSON", func(t *testing.T) { + resp, err := client.Post(server.URL+"/api/v1/users", "application/json", strings.NewReader("{not json")) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + }) + + t.Run("auth required without token", func(t *testing.T) { + resp, err := client.Get(server.URL + "/api/v1/me") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + }) + + t.Run("auth required with invalid token", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, server.URL+"/api/v1/me", nil) + req.Header.Set("Authorization", "Bearer wrong-token") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + }) + + t.Run("auth required with valid token", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, server.URL+"/api/v1/me", nil) + req.Header.Set("Authorization", "Bearer token-alice") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + expected := `{"id":"id-alice","name":"alice"}` + if string(body) != expected { + t.Errorf("expected body %q, got %q", expected, body) + } + }) + + t.Run("unknown path returns 404", func(t *testing.T) { + resp, err := client.Get(server.URL + "/api/v1/nope") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("expected status %d, got %d", http.StatusNotFound, resp.StatusCode) + } + }) + + t.Run("wrong method returns 405", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodDelete, server.URL+"/api/v1/health", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, resp.StatusCode) + } + }) +} + +// TestE2ERateLimiting verifies that IP-based rate limiting is enforced for +// real HTTP clients, including the X-RateLimit-* and Retry-After headers. +func TestE2ERateLimiting(t *testing.T) { + authFunc, userIDFunc := newE2EAuthFunctions(nil) + + const limit = 3 + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + }, authFunc, userIDFunc) + + r.RegisterRoute(RouteConfigBase{ + Path: "/limited", + Methods: []HttpMethod{MethodGet}, + Overrides: common.RouteOverrides{ + RateLimit: &common.RateLimitConfig[any, any]{ + BucketName: "e2e-limited", + Limit: limit, + Window: time.Minute, + Strategy: common.StrategyIP, + }, + }, + Handler: func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }, + }) + + server := httptest.NewServer(r) + defer server.Close() + client := server.Client() + + for i := 1; i <= limit; i++ { + resp, err := client.Get(server.URL + "/limited") + if err != nil { + t.Fatalf("request %d failed: %v", i, err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("request %d: expected status %d, got %d", i, http.StatusOK, resp.StatusCode) + } + if got := resp.Header.Get("X-RateLimit-Limit"); got != fmt.Sprint(limit) { + t.Errorf("request %d: expected X-RateLimit-Limit %d, got %q", i, limit, got) + } + if got := resp.Header.Get("X-RateLimit-Remaining"); got != fmt.Sprint(limit-i) { + t.Errorf("request %d: expected X-RateLimit-Remaining %d, got %q", i, limit-i, got) + } + } + + // The next request from the same IP must be rejected. + resp, err := client.Get(server.URL + "/limited") + if err != nil { + t.Fatalf("over-limit request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusTooManyRequests { + t.Fatalf("expected status %d, got %d", http.StatusTooManyRequests, resp.StatusCode) + } + if resp.Header.Get("Retry-After") == "" { + t.Error("expected Retry-After header on 429 response") + } +} + +// TestE2ECORS simulates how a real browser drives CORS against the server, +// over a real HTTP connection. Each subtest mirrors a step a browser takes: +// same-origin requests carry no Origin header, non-simple cross-origin +// requests are preceded by a credential-less OPTIONS preflight, and the +// browser blocks based purely on the presence or absence of the +// Access-Control-Allow-* response headers (the server replies 204 to every +// preflight regardless). +func TestE2ECORS(t *testing.T) { + const allowedOrigin = "http://example.com" + + authFunc, userIDFunc := newE2EAuthFunctions(map[string]string{"token-alice": "alice"}) + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + CORSConfig: &CORSConfig{ + Origins: []string{allowedOrigin}, + Methods: []string{"GET", "POST"}, + Headers: []string{"Content-Type", "Authorization"}, + AllowCredentials: true, + MaxAge: time.Hour, + }, + }, authFunc, userIDFunc) + + r.RegisterRoute(RouteConfigBase{ + Path: "/resource", + Methods: []HttpMethod{MethodGet, MethodPost}, + Handler: func(w http.ResponseWriter, req *http.Request) { + _, _ = w.Write([]byte("resource")) + }, + }) + // A protected endpoint, as an SPA calling an authenticated API would use. + r.RegisterRoute(RouteConfigBase{ + Path: "/protected", + Methods: []HttpMethod{MethodPost}, + AuthLevel: new(AuthRequired), + Handler: func(w http.ResponseWriter, req *http.Request) { + _, _ = w.Write([]byte("protected data")) + }, + }) + + server := httptest.NewServer(r) + defer server.Close() + client := server.Client() + + t.Run("same-origin request gets no CORS headers", func(t *testing.T) { + // Browsers do not send an Origin header on same-origin GETs; the + // response must work normally and carry no CORS headers. + resp, err := client.Get(server.URL + "/resource") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("expected no Access-Control-Allow-Origin header, got %q", got) + } + }) + + t.Run("preflight request", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodOptions, server.URL+"/resource", nil) + req.Header.Set("Origin", allowedOrigin) + req.Header.Set("Access-Control-Request-Method", "POST") + req.Header.Set("Access-Control-Request-Headers", "Content-Type") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("preflight request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("expected preflight status %d, got %d", http.StatusNoContent, resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != allowedOrigin { + t.Errorf("expected Access-Control-Allow-Origin %q, got %q", allowedOrigin, got) + } + if got := resp.Header.Get("Access-Control-Allow-Methods"); !strings.Contains(got, "POST") { + t.Errorf("expected Access-Control-Allow-Methods to contain POST, got %q", got) + } + if got := resp.Header.Get("Access-Control-Allow-Headers"); !strings.Contains(got, "Content-Type") { + t.Errorf("expected Access-Control-Allow-Headers to contain Content-Type, got %q", got) + } + if got := resp.Header.Get("Access-Control-Allow-Credentials"); got != "true" { + t.Errorf("expected Access-Control-Allow-Credentials true, got %q", got) + } + if got := resp.Header.Get("Access-Control-Max-Age"); got == "" { + t.Error("expected Access-Control-Max-Age to be set so browsers can cache the preflight") + } + }) + + t.Run("actual cross-origin request", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, server.URL+"/resource", nil) + req.Header.Set("Origin", allowedOrigin) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "resource" { + t.Errorf("expected body %q, got %q", "resource", body) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != allowedOrigin { + t.Errorf("expected Access-Control-Allow-Origin %q, got %q", allowedOrigin, got) + } + // Responses that depend on the Origin header must say so, or shared + // caches could serve one origin's CORS headers to another. + if got := resp.Header.Values("Vary"); !slices.Contains(got, "Origin") { + t.Errorf("expected Vary to contain Origin, got %v", got) + } + }) + + t.Run("browser flow against protected API", func(t *testing.T) { + // Step 1: the browser preflights the credentialed POST. Preflights + // never carry credentials, so this must succeed without a token — + // the CORS layer answers before authentication runs. + preflight, _ := http.NewRequest(http.MethodOptions, server.URL+"/protected", nil) + preflight.Header.Set("Origin", allowedOrigin) + preflight.Header.Set("Access-Control-Request-Method", "POST") + preflight.Header.Set("Access-Control-Request-Headers", "Authorization, Content-Type") + + resp, err := client.Do(preflight) + if err != nil { + t.Fatalf("preflight request failed: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("expected credential-less preflight to succeed with %d, got %d", + http.StatusNoContent, resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != allowedOrigin { + t.Fatalf("expected Access-Control-Allow-Origin %q, got %q (browser would block the request)", + allowedOrigin, got) + } + if got := resp.Header.Get("Access-Control-Allow-Headers"); !strings.Contains(got, "Authorization") { + t.Errorf("expected Access-Control-Allow-Headers to contain Authorization, got %q", got) + } + + // Step 2: the preflight passed, so the browser sends the real request + // with credentials attached. + actual, _ := http.NewRequest(http.MethodPost, server.URL+"/protected", strings.NewReader("{}")) + actual.Header.Set("Origin", allowedOrigin) + actual.Header.Set("Authorization", "Bearer token-alice") + actual.Header.Set("Content-Type", "application/json") + + resp, err = client.Do(actual) + if err != nil { + t.Fatalf("actual request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "protected data" { + t.Errorf("expected body %q, got %q", "protected data", body) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != allowedOrigin { + t.Errorf("expected Access-Control-Allow-Origin %q, got %q", allowedOrigin, got) + } + if got := resp.Header.Get("Access-Control-Allow-Credentials"); got != "true" { + t.Errorf("expected Access-Control-Allow-Credentials true, got %q", got) + } + + // A request without credentials (e.g. an expired session) is still + // rejected by authentication; CORS does not bypass it. + unauthed, _ := http.NewRequest(http.MethodPost, server.URL+"/protected", strings.NewReader("{}")) + unauthed.Header.Set("Origin", allowedOrigin) + unauthed.Header.Set("Content-Type", "application/json") + + resp2, err := client.Do(unauthed) + if err != nil { + t.Fatalf("unauthenticated request failed: %v", err) + } + defer func() { _ = resp2.Body.Close() }() + + if resp2.StatusCode != http.StatusUnauthorized { + t.Errorf("expected status %d for missing credentials, got %d", + http.StatusUnauthorized, resp2.StatusCode) + } + }) + + t.Run("preflight for disallowed method is refused", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodOptions, server.URL+"/resource", nil) + req.Header.Set("Origin", allowedOrigin) + req.Header.Set("Access-Control-Request-Method", "DELETE") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("preflight request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + // The server still answers 204; the missing Allow-Methods header is + // what makes the browser block the DELETE. + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("expected preflight status %d, got %d", http.StatusNoContent, resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Methods"); got != "" { + t.Errorf("expected no Access-Control-Allow-Methods for disallowed method, got %q", got) + } + }) + + t.Run("preflight for disallowed header is refused", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodOptions, server.URL+"/resource", nil) + req.Header.Set("Origin", allowedOrigin) + req.Header.Set("Access-Control-Request-Method", "POST") + req.Header.Set("Access-Control-Request-Headers", "X-Custom-Secret") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("preflight request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("expected preflight status %d, got %d", http.StatusNoContent, resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Headers"); got != "" { + t.Errorf("expected no Access-Control-Allow-Headers for disallowed header, got %q", got) + } + }) + + t.Run("preflight from disallowed origin is refused", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodOptions, server.URL+"/resource", nil) + req.Header.Set("Origin", "http://evil.example.org") + req.Header.Set("Access-Control-Request-Method", "POST") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("preflight request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("expected preflight status %d, got %d", http.StatusNoContent, resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("expected no Access-Control-Allow-Origin header, got %q", got) + } + }) + + t.Run("disallowed origin gets no CORS headers", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, server.URL+"/resource", nil) + req.Header.Set("Origin", "http://evil.example.org") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("expected no Access-Control-Allow-Origin header, got %q", got) + } + // Still varies by Origin: an allowed origin would have gotten CORS + // headers, so caches must not reuse this response across origins. + if got := resp.Header.Values("Vary"); !slices.Contains(got, "Origin") { + t.Errorf("expected Vary to contain Origin, got %v", got) + } + }) +} + +// TestE2ETimeoutAndPanicRecovery verifies that timeouts and panic recovery +// produce proper HTTP error responses over a real connection, and that the +// server keeps serving normally afterwards. +func TestE2ETimeoutAndPanicRecovery(t *testing.T) { + authFunc, userIDFunc := newE2EAuthFunctions(nil) + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + GlobalTimeout: 100 * time.Millisecond, + }, authFunc, userIDFunc) + + r.RegisterRoute(RouteConfigBase{ + Path: "/slow", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, req *http.Request) { + select { + case <-time.After(2 * time.Second): + w.WriteHeader(http.StatusOK) + case <-req.Context().Done(): + } + }, + }) + r.RegisterRoute(RouteConfigBase{ + Path: "/panic", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, req *http.Request) { + panic("e2e test panic") + }, + }) + r.RegisterRoute(RouteConfigBase{ + Path: "/ok", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, req *http.Request) { + _, _ = w.Write([]byte("still alive")) + }, + }) + + server := httptest.NewServer(r) + defer server.Close() + client := server.Client() + + t.Run("slow handler times out", func(t *testing.T) { + resp, err := client.Get(server.URL + "/slow") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("expected status %d, got %d", http.StatusRequestTimeout, resp.StatusCode) + } + }) + + t.Run("panicking handler returns 500", func(t *testing.T) { + resp, err := client.Get(server.URL + "/panic") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("expected status %d, got %d", http.StatusInternalServerError, resp.StatusCode) + } + }) + + t.Run("server still healthy afterwards", func(t *testing.T) { + resp, err := client.Get(server.URL + "/ok") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "still alive" { + t.Errorf("expected body %q, got %q", "still alive", body) + } + }) +} + +// TestE2EConcurrentRequests fires many concurrent real HTTP requests at a +// generic route and verifies every response succeeds with a unique trace ID. +func TestE2EConcurrentRequests(t *testing.T) { + authFunc, userIDFunc := newE2EAuthFunctions(nil) + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + GlobalTimeout: 5 * time.Second, + TraceIDBufferSize: 100, + }, authFunc, userIDFunc) + + RegisterGenericRoute(r, RouteConfig[e2eCreateUserRequest, e2eCreateUserResponse]{ + Path: "/echo", + Methods: []HttpMethod{MethodPost}, + Codec: codec.NewJSONCodec[e2eCreateUserRequest, e2eCreateUserResponse](), + Handler: func(req *http.Request, data e2eCreateUserRequest) (e2eCreateUserResponse, error) { + return e2eCreateUserResponse{Name: data.Name, Email: data.Email}, nil + }, + }, time.Duration(0), int64(0), nil) + + server := httptest.NewServer(r) + defer server.Close() + client := server.Client() + + const numRequests = 50 + var wg sync.WaitGroup + var mu sync.Mutex + traceIDs := make(map[string]bool, numRequests) + errs := make(chan error, numRequests) + + for i := range numRequests { + wg.Add(1) + go func(i int) { + defer wg.Done() + + name := fmt.Sprintf("user-%d", i) + body, _ := json.Marshal(e2eCreateUserRequest{Name: name}) + resp, err := client.Post(server.URL+"/echo", "application/json", bytes.NewReader(body)) + if err != nil { + errs <- fmt.Errorf("request %d failed: %w", i, err) + return + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + errs <- fmt.Errorf("request %d: expected status 200, got %d", i, resp.StatusCode) + return + } + var echoed e2eCreateUserResponse + if err := json.NewDecoder(resp.Body).Decode(&echoed); err != nil { + errs <- fmt.Errorf("request %d: failed to decode response: %w", i, err) + return + } + if echoed.Name != name { + errs <- fmt.Errorf("request %d: expected name %q, got %q", i, name, echoed.Name) + return + } + traceID := resp.Header.Get("X-Trace-ID") + if traceID == "" { + errs <- fmt.Errorf("request %d: missing X-Trace-ID header", i) + return + } + mu.Lock() + defer mu.Unlock() + if traceIDs[traceID] { + errs <- fmt.Errorf("request %d: duplicate trace ID %q", i, traceID) + return + } + traceIDs[traceID] = true + }(i) + } + + wg.Wait() + close(errs) + for err := range errs { + t.Error(err) + } + if len(traceIDs) != numRequests { + t.Errorf("expected %d unique trace IDs, got %d", numRequests, len(traceIDs)) + } +} + +// TestE2EGracefulShutdown verifies that an in-flight request completes during +// shutdown and that subsequent requests are rejected with 503. +func TestE2EGracefulShutdown(t *testing.T) { + authFunc, userIDFunc := newE2EAuthFunctions(nil) + + r := NewRouter(RouterConfig{ + Logger: zap.NewNop(), + GlobalTimeout: 2 * time.Second, + }, authFunc, userIDFunc) + + started := make(chan struct{}) + r.RegisterRoute(RouteConfigBase{ + Path: "/slow", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, req *http.Request) { + close(started) + time.Sleep(100 * time.Millisecond) + _, _ = w.Write([]byte("completed")) + }, + }) + + server := httptest.NewServer(r) + defer server.Close() + client := server.Client() + + type result struct { + status int + body string + err error + } + inFlight := make(chan result, 1) + go func() { + resp, err := client.Get(server.URL + "/slow") + if err != nil { + inFlight <- result{err: err} + return + } + defer func() { _ = resp.Body.Close() }() + body, _ := io.ReadAll(resp.Body) + inFlight <- result{status: resp.StatusCode, body: string(body)} + }() + + // Wait until the handler is actually executing, then shut down. + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("in-flight request never reached the handler") + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := r.Shutdown(ctx); err != nil { + t.Fatalf("Shutdown returned error: %v", err) + } + + // The in-flight request must have completed successfully. + select { + case res := <-inFlight: + if res.err != nil { + t.Fatalf("in-flight request failed: %v", res.err) + } + if res.status != http.StatusOK { + t.Errorf("expected in-flight status %d, got %d", http.StatusOK, res.status) + } + if res.body != "completed" { + t.Errorf("expected in-flight body %q, got %q", "completed", res.body) + } + case <-time.After(2 * time.Second): + t.Fatal("in-flight request did not complete") + } + + // New requests after shutdown must be rejected. + resp, err := client.Get(server.URL + "/slow") + if err != nil { + t.Fatalf("post-shutdown request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("expected status %d after shutdown, got %d", http.StatusServiceUnavailable, resp.StatusCode) + } +} diff --git a/pkg/router/ip.go b/pkg/router/ip.go index 0ad6fe2..38d3e6e 100644 --- a/pkg/router/ip.go +++ b/pkg/router/ip.go @@ -39,6 +39,11 @@ type IPConfig struct { } // DefaultIPConfig returns the default IP configuration. +// The default uses the rightmost X-Forwarded-For entry (the value appended by +// the proxy nearest this server) rather than the client-controlled leftmost +// entry. If the service is exposed directly to the internet (no trusted proxy), +// set Source to IPSourceRemoteAddr or TrustProxy to false so client-supplied +// headers are ignored entirely. func DefaultIPConfig() *IPConfig { return &IPConfig{ Source: IPSourceXForwardedFor, // Default to checking X-Forwarded-For @@ -61,11 +66,19 @@ func ClientIPMiddleware[T comparable, U any](config *IPConfig) func(http.Handler // Extract the client IP based on the configuration clientIP := extractClientIP(r, config) + // When an SRouterContext already exists, WithClientIP mutates it in + // place and returns the same context, so cloning the request is + // unnecessary. + _, hadSRouterCtx := scontext.GetSRouterContext[T, U](r.Context()) + // Add the client IP to the SRouterContext ctx := scontext.WithClientIP[T, U](r.Context(), clientIP) // Use scontext // Call the next handler with the updated context - next.ServeHTTP(w, r.WithContext(ctx)) + if !hadSRouterCtx { + r = r.WithContext(ctx) + } + next.ServeHTTP(w, r) }) } } @@ -99,21 +112,30 @@ func extractClientIP(r *http.Request, config *IPConfig) string { return cleanIP(ip) } -// extractIPFromXForwardedFor extracts the client IP from the X-Forwarded-For header -// The X-Forwarded-For header contains a comma-separated list of IPs, with the leftmost being the original client +// extractIPFromXForwardedFor extracts the client IP from the X-Forwarded-For header. +// The header contains a comma-separated list of IPs. Earlier (leftmost) entries are +// supplied by the client and are trivially spoofable; the rightmost entry was appended +// by the proxy closest to this server and is the only value the deployment's own +// infrastructure vouches for. Using it prevents clients from rotating fake IPs to +// bypass IP-based rate limiting. func extractIPFromXForwardedFor(r *http.Request) string { xff := r.Header.Get("X-Forwarded-For") if xff == "" { return "" } - // The leftmost IP is the original client - ips := strings.Split(xff, ",") - if len(ips) > 0 { - return strings.TrimSpace(ips[0]) + // Use the rightmost (most recently appended, least spoofable) entry. + // Scan from the end without splitting so no intermediate slice is allocated. + for { + comma := strings.LastIndexByte(xff, ',') + if ip := strings.TrimSpace(xff[comma+1:]); ip != "" { + return ip + } + if comma < 0 { + return "" + } + xff = xff[:comma] } - - return "" } // cleanIP removes the port from an IP address if present diff --git a/pkg/router/ip_test.go b/pkg/router/ip_test.go index 9236e98..bfc6d8c 100644 --- a/pkg/router/ip_test.go +++ b/pkg/router/ip_test.go @@ -10,12 +10,22 @@ import ( // TestExtractIPFromXForwardedFor tests the extractIPFromXForwardedFor function func TestExtractIPFromXForwardedFor(t *testing.T) { - // Test with valid X-Forwarded-For header containing multiple IPs + // Test with valid X-Forwarded-For header containing multiple IPs. + // The rightmost entry (appended by the nearest proxy) is used because the + // leftmost entries are client-controlled and spoofable. req := httptest.NewRequest("GET", "/test", nil) req.Header.Set("X-Forwarded-For", "203.0.113.1, 198.51.100.1") ip := extractIPFromXForwardedFor(req) - if ip != "203.0.113.1" { - t.Errorf("Expected IP '203.0.113.1', got '%s'", ip) + if ip != "198.51.100.1" { + t.Errorf("Expected IP '198.51.100.1', got '%s'", ip) + } + + // Trailing empty entries are skipped + req = httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "203.0.113.1, 198.51.100.1, ") + ip = extractIPFromXForwardedFor(req) + if ip != "198.51.100.1" { + t.Errorf("Expected IP '198.51.100.1', got '%s'", ip) } // Test with valid X-Forwarded-For header containing a single IP @@ -59,7 +69,7 @@ func TestExtractClientIP(t *testing.T) { config: DefaultIPConfig(), headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 198.51.100.1"}, remoteAddr: "192.0.2.1:1234", - expectedIP: "203.0.113.1", + expectedIP: "198.51.100.1", // Rightmost (proxy-appended) entry wins }, { name: "X-Real-IP Config", @@ -99,9 +109,9 @@ func TestExtractClientIP(t *testing.T) { { name: "IPv6 X-Forwarded-For", config: DefaultIPConfig(), - headers: map[string]string{"X-Forwarded-For": "[2001:db8::1]:54321, 198.51.100.1"}, + headers: map[string]string{"X-Forwarded-For": "198.51.100.1, [2001:db8::1]:54321"}, remoteAddr: "192.0.2.1:1234", - expectedIP: "[2001:db8::1]", // Expects cleaned IPv6 + expectedIP: "[2001:db8::1]", // Rightmost entry, cleaned IPv6 }, { name: "IPv6 RemoteAddr", diff --git a/pkg/router/logging_trace_test.go b/pkg/router/logging_trace_test.go index 5b1d6dc..77e0bd0 100644 --- a/pkg/router/logging_trace_test.go +++ b/pkg/router/logging_trace_test.go @@ -70,17 +70,19 @@ func TestTraceIDLoggingDisabled(t *testing.T) { rr := httptest.NewRecorder() r.ServeHTTP(rr, req) - // With TraceIDBufferSize = 0, the entire deferred logging block in ServeHTTP is skipped. - // Therefore, we expect NO log entries from this specific logger setup. + // With EnableTraceLogging = true the summary logging block runs even when + // TraceIDBufferSize is 0, but no trace_id field may be attached. logEntries := logs.AllUntimed() // Use AllUntimed() for consistency - if len(logEntries) != 0 { - t.Errorf("Expected 0 log entries when TraceIDBufferSize is 0, but got %d", len(logEntries)) - // Log the unexpected entries for debugging - for i, entry := range logEntries { - t.Logf("Unexpected log entry %d: %v", i, entry) + if len(logEntries) == 0 { + t.Errorf("Expected request summary log entries when EnableTraceLogging is true") + } + for _, entry := range logEntries { + for _, field := range entry.Context { + if field.Key == "trace_id" { + t.Errorf("Expected no trace_id field when TraceIDBufferSize is 0, found %q", field.String) + } } } - // The previous loop checking for the absence of 'trace_id' is no longer needed. } // TestHandleErrorWithTraceID tests that handleError includes trace IDs in log entries when TraceIDBufferSize > 0 diff --git a/pkg/router/metrics_test.go b/pkg/router/metrics_test.go index 8473dea..0eb88e0 100644 --- a/pkg/router/metrics_test.go +++ b/pkg/router/metrics_test.go @@ -193,7 +193,7 @@ func TestMetricsResponseWriterFlush(t *testing.T) { // Create a metrics response writer with string as both the user ID and user type mrw := &metricsResponseWriter[string, string]{ - baseResponseWriter: &baseResponseWriter{ResponseWriter: rr}, + baseResponseWriter: baseResponseWriter{ResponseWriter: rr}, statusCode: http.StatusOK, } @@ -270,7 +270,7 @@ func TestMetricsResponseWriter(t *testing.T) { // Create a metrics response writer with string as both the user ID and user type mrw := &metricsResponseWriter[string, string]{ - baseResponseWriter: &baseResponseWriter{ResponseWriter: rr}, + baseResponseWriter: baseResponseWriter{ResponseWriter: rr}, statusCode: http.StatusOK, } diff --git a/pkg/router/patch_coverage_test.go b/pkg/router/patch_coverage_test.go new file mode 100644 index 0000000..bae49d4 --- /dev/null +++ b/pkg/router/patch_coverage_test.go @@ -0,0 +1,205 @@ +package router + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/Suhaibinator/SRouter/pkg/metrics" + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" +) + +// fakeMetricsMiddlewareFactory implements metrics.MetricsMiddleware[string, string] +// so tests can verify that a user-supplied MiddlewareFactory takes precedence +// over building middleware from the Collector. +type fakeMetricsMiddlewareFactory struct { + mu sync.Mutex + handlerNames []string + requests int +} + +func (f *fakeMetricsMiddlewareFactory) Handler(name string, handler http.Handler) http.Handler { + f.mu.Lock() + f.handlerNames = append(f.handlerNames, name) + f.mu.Unlock() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + f.mu.Lock() + f.requests++ + f.mu.Unlock() + w.Header().Set("X-Metrics-Factory", "invoked") + handler.ServeHTTP(w, r) + }) +} + +func (f *fakeMetricsMiddlewareFactory) Configure(config metrics.MetricsMiddlewareConfig) metrics.MetricsMiddleware[string, string] { + return f +} + +func (f *fakeMetricsMiddlewareFactory) WithFilter(filter metrics.MetricsFilter) metrics.MetricsMiddleware[string, string] { + return f +} + +func (f *fakeMetricsMiddlewareFactory) WithSampler(sampler metrics.MetricsSampler) metrics.MetricsMiddleware[string, string] { + return f +} + +// TestMetricsConfigMiddlewareFactory verifies that when MetricsConfig supplies +// a MiddlewareFactory of the router's generic type, the router wraps handlers +// with it (passing the configured ServiceName) and requests flow through it. +func TestMetricsConfigMiddlewareFactory(t *testing.T) { + factory := &fakeMetricsMiddlewareFactory{} + + r := NewRouter(RouterConfig{ + ServiceName: "test-service", + MetricsConfig: &MetricsConfig{ + MiddlewareFactory: factory, + }, + }, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + r.RegisterRoute(RouteConfigBase{ + Path: "/factory", + Methods: []HttpMethod{MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, + }) + + rec := httptest.NewRecorder() + r.ServeHTTP(rec, httptest.NewRequest("GET", "/factory", nil)) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if rec.Header().Get("X-Metrics-Factory") != "invoked" { + t.Error("request did not pass through the factory-provided metrics middleware") + } + + factory.mu.Lock() + defer factory.mu.Unlock() + if factory.requests != 1 { + t.Errorf("factory middleware handled %d requests, want 1", factory.requests) + } + if len(factory.handlerNames) == 0 { + t.Fatal("factory Handler was never called when wrapping routes") + } + for _, name := range factory.handlerNames { + if name != "test-service" { + t.Errorf("factory Handler called with name %q, want configured ServiceName %q", name, "test-service") + } + } +} + +// TestGetEffectiveRateLimitConvertsUserIDFunctions verifies that converting a +// RateLimitConfig[any, any] override to the router's concrete types adapts +// UserIDFromUser and UserIDToString so user-based rate limiting keeps working. +func TestGetEffectiveRateLimitConvertsUserIDFunctions(t *testing.T) { + r := NewRouter(RouterConfig{}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + src := &common.RateLimitConfig[any, any]{ + BucketName: "user-bucket", + Limit: 7, + Window: time.Minute, + Strategy: common.StrategyUser, + UserIDFromUser: func(user any) any { + return "id-" + user.(string) + }, + UserIDToString: func(userID any) string { + return "key:" + userID.(string) + }, + } + + got := r.getEffectiveRateLimit(src, nil) + if got == nil { + t.Fatal("expected a converted rate limit config, got nil") + } + if got.BucketName != "user-bucket" || got.Limit != 7 || got.Window != time.Minute || got.Strategy != common.StrategyUser { + t.Errorf("converted config lost fields: %+v", got) + } + if got.UserIDFromUser == nil { + t.Fatal("UserIDFromUser was not adapted across the type conversion") + } + if id := got.UserIDFromUser("alice"); id != "id-alice" { + t.Errorf("UserIDFromUser(\"alice\") = %q, want %q", id, "id-alice") + } + if got.UserIDToString == nil { + t.Fatal("UserIDToString was not adapted across the type conversion") + } + if key := got.UserIDToString("bob"); key != "key:bob" { + t.Errorf("UserIDToString(\"bob\") = %q, want %q", key, "key:bob") + } + + // A UserIDFromUser returning a value of the wrong type must degrade to the + // zero user ID instead of panicking. Passed as the sub-router override to + // exercise that precedence level too. + wrongType := &common.RateLimitConfig[any, any]{ + Limit: 1, + Window: time.Second, + UserIDFromUser: func(user any) any { + return 42 // not a string + }, + } + converted := r.getEffectiveRateLimit(nil, wrongType) + if converted == nil || converted.UserIDFromUser == nil { + t.Fatal("expected converted sub-router config with adapted UserIDFromUser") + } + if id := converted.UserIDFromUser("alice"); id != "" { + t.Errorf("UserIDFromUser with mismatched return type = %q, want zero value", id) + } +} + +// TestExtractIPFromXForwardedForBlankEntries verifies that an X-Forwarded-For +// header containing only blank entries (commas and whitespace) yields no IP, +// so the caller falls back to RemoteAddr instead of using an empty key. +func TestExtractIPFromXForwardedForBlankEntries(t *testing.T) { + for _, xff := range []string{" ", ",", " , ", ",, ,"} { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-For", xff) + if ip := extractIPFromXForwardedFor(req); ip != "" { + t.Errorf("X-Forwarded-For %q: got %q, want empty string", xff, ip) + } + } + + // End to end: with a blank XFF the extracted client IP must fall back to + // RemoteAddr even when proxy headers are trusted. + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "203.0.113.9:1234" + req.Header.Set("X-Forwarded-For", " , ") + ip := extractClientIP(req, &IPConfig{Source: IPSourceXForwardedFor, TrustProxy: true}) + if ip != "203.0.113.9" { + t.Errorf("extractClientIP with blank XFF = %q, want RemoteAddr fallback %q", ip, "203.0.113.9") + } +} + +// TestRouterMutexResponseWriterWriteRecheckUnderLock verifies the router's +// timeout response writer rejects a handler write that passed the initial +// timeout check but lost the race to the timeout response: the re-check under +// the lock must fail the write instead of corrupting the response. +func TestRouterMutexResponseWriterWriteRecheckUnderLock(t *testing.T) { + rec := httptest.NewRecorder() + var mu sync.Mutex + rw := &mutexResponseWriter{ResponseWriter: rec, mu: &mu} + + // Hold the lock as the timeout path does while writing its response. + mu.Lock() + writeErr := make(chan error) + go func() { + _, err := rw.Write([]byte("late")) + writeErr <- err + }() + + // Let the handler write pass the initial check and block on the mutex, + // then mark the timeout before releasing the lock. + time.Sleep(50 * time.Millisecond) + rw.timedOut.Store(true) + mu.Unlock() + + if err := <-writeErr; err != http.ErrHandlerTimeout { + t.Errorf("late Write = %v, want http.ErrHandlerTimeout", err) + } + if rec.Body.Len() != 0 { + t.Errorf("late Write reached the underlying writer: body = %q", rec.Body.String()) + } +} diff --git a/pkg/router/route.go b/pkg/router/route.go index dcd3387..4f69e14 100644 --- a/pkg/router/route.go +++ b/pkg/router/route.go @@ -109,11 +109,9 @@ func registerGenericRouteWithAuthTokenResolution[Req any, Resp any, UserID compa // Use the codec's Decode method to read directly from the request body data, err = route.Codec.Decode(req) if err != nil { - // Check if this is a MaxBytesReader error (applied in wrapHandler) - // Note: io.ReadAll is no longer called here, the codec handles reading. - // We need to check for the specific error string potentially returned by http.MaxBytesReader - // or similar errors from the codec's Decode implementation. - if err.Error() == "http: request body too large" { // Keep this check + // Check if this is a MaxBytesReader error (applied in wrapHandler). + // errors.As unwraps, so this works even when a codec wraps the error. + if isMaxBytesError(err) { r.handleError(w, req, err, http.StatusRequestEntityTooLarge, "Request entity too large") return } diff --git a/pkg/router/router.go b/pkg/router/router.go index bbd1ef8..4e6d277 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -128,6 +128,14 @@ func NewRouter[T comparable, U any](config RouterConfig, authFunction func(conte // Precompute CORS headers if configured if config.CORSConfig != nil { + // Warn about a contradictory configuration: the CORS spec forbids + // credentials with a wildcard origin, so the credentials header will + // never be emitted for wildcard matches. + if config.CORSConfig.AllowCredentials && slices.Contains(config.CORSConfig.Origins, "*") { + r.logger.Warn("CORS config combines wildcard origin with AllowCredentials; " + + "credentials are never allowed for wildcard origins per the CORS spec. " + + "List explicit origins to enable credentials.") + } if len(config.CORSConfig.Methods) > 0 { r.corsAllowMethods = strings.Join(config.CORSConfig.Methods, ", ") } @@ -152,26 +160,35 @@ func NewRouter[T comparable, U any](config RouterConfig, authFunction func(conte if config.MetricsConfig != nil { var metricsMiddleware common.Middleware - // Use the MetricsConfig - if config.MetricsConfig != nil { - // Check if the collector is a metrics registry - if registry, ok := config.MetricsConfig.Collector.(metrics.MetricsRegistry); ok { - // Create the generic middleware implementation using the router's T and U types - metricsMiddlewareImpl := metrics.NewMetricsMiddleware[T, U](registry, metrics.MetricsMiddlewareConfig{ - EnableLatency: config.MetricsConfig.EnableLatency, - EnableThroughput: config.MetricsConfig.EnableThroughput, - EnableQPS: config.MetricsConfig.EnableQPS, - EnableErrors: config.MetricsConfig.EnableErrors, - DefaultTags: metrics.Tags{ - "service": config.MetricsConfig.Namespace, // Assuming Namespace is intended for service tag - }, - }) - // The middleware instance itself is now generic, but its Handler method - // returns a standard http.Handler, so the adapter function remains the same. - metricsMiddleware = func(next http.Handler) http.Handler { - // Use the ServiceName from the config for the application name - return metricsMiddlewareImpl.Handler(config.ServiceName, next) - } + // A user-supplied middleware factory takes precedence over building one + // from the Collector. + if factory, ok := config.MetricsConfig.MiddlewareFactory.(metrics.MetricsMiddleware[T, U]); ok { + metricsMiddleware = func(next http.Handler) http.Handler { + return factory.Handler(config.ServiceName, next) + } + } else if registry, ok := config.MetricsConfig.Collector.(metrics.MetricsRegistry); ok { + // Tags applied to every metric emitted by the middleware. + defaultTags := metrics.Tags{} + if config.MetricsConfig.Namespace != "" { + defaultTags["service"] = config.MetricsConfig.Namespace // Assuming Namespace is intended for service tag + } + if config.MetricsConfig.Subsystem != "" { + defaultTags["subsystem"] = config.MetricsConfig.Subsystem + } + + // Create the generic middleware implementation using the router's T and U types + metricsMiddlewareImpl := metrics.NewMetricsMiddleware[T, U](registry, metrics.MetricsMiddlewareConfig{ + EnableLatency: config.MetricsConfig.EnableLatency, + EnableThroughput: config.MetricsConfig.EnableThroughput, + EnableQPS: config.MetricsConfig.EnableQPS, + EnableErrors: config.MetricsConfig.EnableErrors, + DefaultTags: defaultTags, + }) + // The middleware instance itself is now generic, but its Handler method + // returns a standard http.Handler, so the adapter function remains the same. + metricsMiddleware = func(next http.Handler) http.Handler { + // Use the ServiceName from the config for the application name + return metricsMiddlewareImpl.Handler(config.ServiceName, next) } } @@ -199,6 +216,9 @@ func NewRouter[T comparable, U any](config RouterConfig, authFunction func(conte // // This is useful for conditionally adding routes or building routes programmatically. func (r *Router[T, U]) RegisterSubRouter(sr SubRouterConfig) { + // Record the sub-router so later lookups (e.g. RegisterGenericRouteOnSubRouter) + // can resolve it just like sub-routers provided in the initial config. + r.config.SubRouters = append(r.config.SubRouters, sr) r.registerSubRouter(sr) } @@ -267,9 +287,15 @@ func (r *Router[T, U]) registerSubRouterWithResolvedOverrides(sr SubRouterConfig // Register nested sub-routers recursively for _, nestedSR := range sr.SubRouters { - // Create a new sub-router with the combined path prefix + // Create a new sub-router with the combined path prefix, inherited + // middlewares (additive: parent's run before the nested sub-router's), + // and inherited AuthLevel (unless the nested sub-router sets its own). nestedSRWithPrefix := nestedSR nestedSRWithPrefix.PathPrefix = sr.PathPrefix + nestedSR.PathPrefix + nestedSRWithPrefix.Middlewares = combineMiddlewares(sr.Middlewares, nestedSR.Middlewares) + if nestedSRWithPrefix.AuthLevel == nil { + nestedSRWithPrefix.AuthLevel = sr.AuthLevel + } // Register the nested sub-router r.registerSubRouterWithResolvedOverrides(nestedSRWithPrefix, resolvedOverrides) @@ -295,30 +321,25 @@ func (r *Router[T, U]) convertToHTTPRouterHandle(handler http.Handler, routeTemp } // wrapHandler wraps a handler with all the necessary middleware. -// It creates a complete request processing pipeline with the following middleware order: -// 1. Recovery (innermost, catches panics) -// 2. Authentication (if authLevel is set) -// 3. Rate limiting (if rateLimit is set) -// 4. Route-specific middlewares (from the middlewares parameter) -// 5. Global middlewares (from RouterConfig, includes trace and metrics if enabled) -// 6. Timeout (if timeout > 0) -// 7. Shutdown check and body size limit (in the base handler) +// It creates a complete request processing pipeline with the following middleware order, +// from outermost (first to see the request) to innermost (closest to the handler): +// 1. Recovery (outermost, catches panics from everything below it) +// 2. Trace ID injection (if enabled) +// 3. Authentication (if authLevel is set) +// 4. Rate limiting (if rateLimit is set) +// 5. Route-specific middlewares (from the middlewares parameter) +// 6. Global middlewares (from RouterConfig, includes metrics if enabled) +// 7. Timeout (innermost, if timeout > 0) +// 8. Body size limit (in the base handler) // // Middlewares are combined additively, not replaced. func (r *Router[T, U]) wrapHandler(handler http.HandlerFunc, authLevel *AuthLevel, authTokenConfig common.AuthTokenConfig, timeout time.Duration, maxBodySize int64, rateLimit *common.RateLimitConfig[T, U], middlewares []common.Middleware) http.Handler { // Use common.RateLimitConfig // Create a base handler that only handles shutdown check and body size limit directly // Timeout is now handled by timeoutMiddleware setting the context. h := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // Shutdown Check - r.wg.Add(1) - defer r.wg.Done() - r.shutdownMu.RLock() - isShutdown := r.shutdown - r.shutdownMu.RUnlock() - if isShutdown { - http.Error(w, "Service Unavailable", http.StatusServiceUnavailable) - return - } + // Note: shutdown check and request tracking happen at the top of + // ServeHTTP so the whole middleware chain is covered, not just the + // base handler. // Apply body size limit if maxBodySize > 0 { @@ -334,7 +355,7 @@ func (r *Router[T, U]) wrapHandler(handler http.HandlerFunc, authLevel *AuthLeve // Append middleware in order of execution (outermost first) - // 1. Recovery (Innermost before handler) + // 1. Recovery (outermost, catches panics from the whole chain) chain = chain.Append(r.recoveryMiddleware) // 2. Trace middleware (if enabled) - positioned early so all middlewares have access to trace ID @@ -491,23 +512,30 @@ type subRouterConfigPrefixMatch struct { } func resolveSubRouterConfigForPrefix(subRouters []SubRouterConfig, targetPrefix string, parentPrefix string, parentOverrides common.RouteOverrides) (SubRouterConfig, bool) { - match := resolveSubRouterConfigForPrefixMatch(subRouters, targetPrefix, parentPrefix, parentOverrides) + match := resolveSubRouterConfigForPrefixMatch(subRouters, targetPrefix, parentPrefix, parentOverrides, nil, nil) if !match.found { return SubRouterConfig{}, false } return match.config, true } -func resolveSubRouterConfigForPrefixMatch(subRouters []SubRouterConfig, targetPrefix string, parentPrefix string, parentOverrides common.RouteOverrides) subRouterConfigPrefixMatch { +func resolveSubRouterConfigForPrefixMatch(subRouters []SubRouterConfig, targetPrefix string, parentPrefix string, parentOverrides common.RouteOverrides, parentMiddlewares []common.Middleware, parentAuthLevel *AuthLevel) subRouterConfigPrefixMatch { var fallback SubRouterConfig fallbackFound := false for _, sr := range subRouters { fullPathPrefix := parentPrefix + sr.PathPrefix resolvedOverrides := resolveSubRouterOverrides(parentOverrides, sr) + resolvedMiddlewares := combineMiddlewares(parentMiddlewares, sr.Middlewares) + resolvedAuthLevel := sr.AuthLevel + if resolvedAuthLevel == nil { + resolvedAuthLevel = parentAuthLevel + } resolvedSR := sr resolvedSR.PathPrefix = fullPathPrefix resolvedSR.Overrides = resolvedOverrides + resolvedSR.Middlewares = resolvedMiddlewares + resolvedSR.AuthLevel = resolvedAuthLevel if fullPathPrefix == targetPrefix { return subRouterConfigPrefixMatch{config: resolvedSR, found: true, exact: true} @@ -516,7 +544,7 @@ func resolveSubRouterConfigForPrefixMatch(subRouters []SubRouterConfig, targetPr fallback = resolvedSR fallbackFound = true } - childMatch := resolveSubRouterConfigForPrefixMatch(sr.SubRouters, targetPrefix, fullPathPrefix, resolvedOverrides) + childMatch := resolveSubRouterConfigForPrefixMatch(sr.SubRouters, targetPrefix, fullPathPrefix, resolvedOverrides, resolvedMiddlewares, resolvedAuthLevel) if childMatch.found { if childMatch.exact { return childMatch @@ -533,6 +561,19 @@ func resolveSubRouterConfigForPrefixMatch(subRouters []SubRouterConfig, targetPr return subRouterConfigPrefixMatch{} } +// combineMiddlewares returns a new slice containing parent middlewares followed +// by child middlewares. The inputs are never modified and the result has its +// own backing array. +func combineMiddlewares(parent, child []common.Middleware) []common.Middleware { + if len(parent) == 0 && len(child) == 0 { + return nil + } + combined := make([]common.Middleware, 0, len(parent)+len(child)) + combined = append(combined, parent...) + combined = append(combined, child...) + return combined +} + // RegisterGenericRouteOnSubRouter registers a generic route on a specific sub-router after router creation. // This function is primarily used for dynamic route registration after the router has been initialized. // For static route configuration, prefer using NewGenericRouteDefinition within SubRouterConfig.Routes. @@ -583,10 +624,11 @@ func RegisterGenericRouteOnSubRouter[Req any, Resp any, UserID comparable, User // Prefix the path finalRouteConfig.Path = sr.PathPrefix + route.Path - // Combine middleware: global + sub-router + route-specific - allMiddlewares := make([]common.Middleware, 0, len(r.middlewares)+len(subRouterMiddlewares)+len(route.Middlewares)) - allMiddlewares = append(allMiddlewares, r.middlewares...) // Global first - allMiddlewares = append(allMiddlewares, subRouterMiddlewares...) // Then sub-router + // Combine middleware: sub-router + route-specific. + // Note: Global middlewares are added later by wrapHandler; including them + // here would apply them twice. + allMiddlewares := make([]common.Middleware, 0, len(subRouterMiddlewares)+len(route.Middlewares)) + allMiddlewares = append(allMiddlewares, subRouterMiddlewares...) // Sub-router first allMiddlewares = append(allMiddlewares, route.Middlewares...) // Then route-specific finalRouteConfig.Middlewares = allMiddlewares // Overwrite middlewares in the config passed down @@ -607,6 +649,20 @@ func RegisterGenericRouteOnSubRouter[Req any, Resp any, UserID comparable, User // It handles HTTP requests by applying CORS, client IP extraction, metrics, tracing, // and then delegating to the underlying httprouter. func (r *Router[T, U]) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Track the in-flight request for graceful shutdown. The Add must happen + // under the shutdown lock so it can never race with Shutdown's wg.Wait: + // Shutdown takes the write lock before waiting, so either this request is + // rejected below, or it is registered before Wait can observe the counter. + r.shutdownMu.RLock() + if r.shutdown { + r.shutdownMu.RUnlock() + http.Error(w, "Service Unavailable", http.StatusServiceUnavailable) + return + } + r.wg.Add(1) + r.shutdownMu.RUnlock() + defer r.wg.Done() + // Handle CORS first var corsHandled bool req, corsHandled = r.handleCORS(w, req) @@ -623,13 +679,15 @@ func (r *Router[T, U]) ServeHTTP(w http.ResponseWriter, req *http.Request) { ctx = scontext.WithUserAgent[T, U](ctx, req.UserAgent()) req = req.WithContext(ctx) - // Apply metrics and tracing if enabled - if r.config.TraceIDBufferSize > 0 { + // Apply request summary logging and status/bytes capture if enabled. + // This is independent of trace IDs: EnableTraceLogging turns it on even + // when TraceIDBufferSize is 0 (trace_id fields are simply absent then). + if r.config.TraceIDBufferSize > 0 || r.config.EnableTraceLogging { // Get a metricsResponseWriter from the pool mrw := r.metricsWriterPool.Get().(*metricsResponseWriter[T, U]) // Initialize the writer with the current request data - mrw.baseResponseWriter = &baseResponseWriter{ResponseWriter: w} + mrw.baseResponseWriter = baseResponseWriter{ResponseWriter: w} mrw.statusCode = http.StatusOK mrw.startTime = time.Now() mrw.request = req @@ -645,8 +703,13 @@ func (r *Router[T, U]) ServeHTTP(w http.ResponseWriter, req *http.Request) { ip, _ := scontext.GetClientIPFromRequest[T, U](req) ua, _ := scontext.GetUserAgentFromRequest[T, U](req) - // 2) Build unified fields - the UNION of all previously separate log fields - fields := append(r.baseFields(req), + // 2) Build unified fields - the UNION of all previously separate log + // fields. Sized for all fields (including the optional trace ID) up + // front so this per-request path allocates the slice exactly once. + fields := make([]zap.Field, 0, 8) + fields = append(fields, + zap.String("method", req.Method), + zap.String("path", req.URL.Path), zap.Int("status", mrw.statusCode), zap.Duration("duration", duration), zap.Int64("bytes", mrw.bytesWritten), @@ -672,7 +735,7 @@ func (r *Router[T, U]) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.logger.Log(lvl, "Request summary statistics", fields...) // Reset fields that might hold references to prevent memory leaks - mrw.baseResponseWriter = nil + mrw.baseResponseWriter = baseResponseWriter{} mrw.request = nil mrw.router = nil @@ -724,9 +787,11 @@ func (r *Router[T, U]) handleCORS(w http.ResponseWriter, req *http.Request) (*ht // For actual requests, we *could* block here, but it's often better to let the request proceed // and let the browser enforce the lack of Allow-Origin header. // However, we MUST NOT set the Allow-Origin header. - // We also need to store the *lack* of allowance in the context. - ctx = scontext.WithCORSInfo[T, U](ctx, "", false) // Store empty origin, false credentials - req = req.WithContext(ctx) + // The *lack* of allowance is stored in the context by the + // unconditional WithCORSInfo call below (correctAllowOrigin stays ""). + // The response still varies by Origin (an allowed origin would get + // CORS headers), so set Vary to keep shared caches correct. + w.Header().Add("Vary", "Origin") // If it's a preflight, handle it below (it will fail the checks). // If it's not preflight, let it continue, but CORS headers won't be set. } @@ -906,8 +971,10 @@ func (bw *baseResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { // metricsResponseWriter is a wrapper around http.ResponseWriter that captures metrics. // It tracks the status code, bytes written, and timing information for each response. +// baseResponseWriter is embedded by value so pooled writers are reinitialized +// without allocating a fresh wrapper per request. type metricsResponseWriter[T comparable, U any] struct { - *baseResponseWriter + baseResponseWriter statusCode int bytesWritten int64 startTime time.Time @@ -1158,14 +1225,30 @@ func (r *Router[T, U]) getEffectiveRateLimit(routeRateLimit, subRouterRateLimit return nil } + // Adapt the user ID extraction functions across the type conversion so + // StrategyUser keeps working when configured via [any, any] overrides. + var userIDFromUser func(U) T + if fromUser := config.UserIDFromUser; fromUser != nil { + userIDFromUser = func(user U) T { + id, _ := fromUser(user).(T) + return id + } + } + var userIDToString func(T) string + if toString := config.UserIDToString; toString != nil { + userIDToString = func(userID T) string { + return toString(userID) + } + } + // Create a new config with the correct type parameters return &common.RateLimitConfig[T, U]{ // Use common.RateLimitConfig BucketName: config.BucketName, Limit: config.Limit, Window: config.Window, Strategy: config.Strategy, - UserIDFromUser: nil, // These will need to be set by the caller if needed - UserIDToString: nil, // These will need to be set by the caller if needed + UserIDFromUser: userIDFromUser, + UserIDToString: userIDToString, KeyExtractor: config.KeyExtractor, ExceededHandler: config.ExceededHandler, } @@ -1198,6 +1281,13 @@ func (r *Router[T, U]) addTrace(fields []zap.Field, req *http.Request) []zap.Fie return fields } +// isMaxBytesError reports whether err was caused by http.MaxBytesReader +// rejecting a request body, even if a codec has wrapped the error. +func isMaxBytesError(err error) bool { + var maxBytesErr *http.MaxBytesError + return errors.As(err, &maxBytesErr) +} + // handleError handles an error by logging it and returning an appropriate HTTP response. // It checks if the error is a specific HTTPError and uses its status code and message if available. // It also checks for context deadline exceeded errors. @@ -1218,7 +1308,7 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err statusCode = httpErr.StatusCode message = httpErr.Message r.logger.Error(message, fields...) // Log with the custom message - } else if err != nil && err.Error() == "http: request body too large" { + } else if isMaxBytesError(err) { // Specifically handle MaxBytesReader error statusCode = http.StatusRequestEntityTooLarge message = "Request Entity Too Large" @@ -1374,29 +1464,58 @@ func NewHTTPError(statusCode int, message string) *HTTPError { } // recoveryMiddleware is a middleware that recovers from panics in handlers. -// It logs the panic and returns a 500 Internal Server Error response. +// It logs the panic and returns a 500 Internal Server Error response if the +// response has not been started yet. If the panic occurred after the handler +// began writing, no second response is written (the partial response cannot +// be repaired) and the panic is only logged. // This prevents the server from crashing when a handler panics. func (r *Router[T, U]) recoveryMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + rw := &recoveryResponseWriter{baseResponseWriter: baseResponseWriter{ResponseWriter: w}} defer func() { if rec := recover(); rec != nil { fields := append([]zap.Field{zap.Any("panic", rec)}, r.baseFields(req)...) fields = r.addTrace(fields, req) r.logger.Error("Panic recovered", fields...) - // Return a 500 Internal Server Error + if rw.wrote { + // The handler already started writing; appending a JSON + // error would corrupt the response and trigger a + // superfluous WriteHeader. Log only. + return + } + // Return a 500 Internal Server Error as JSON - // We attempt to write the JSON error. If headers were already written, - // writeJSONError might log an error, but we can't do much more here. traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(w, req, http.StatusInternalServerError, "Internal Server Error", traceID) + r.writeJSONError(rw, req, http.StatusInternalServerError, "Internal Server Error", traceID) } }() - next.ServeHTTP(w, req) + next.ServeHTTP(rw, req) }) } +// recoveryResponseWriter tracks whether the response has been started so the +// recovery middleware can avoid writing a second response after a panic that +// occurred mid-write. baseResponseWriter is embedded by value so the +// per-request wrapper costs a single allocation. +type recoveryResponseWriter struct { + baseResponseWriter + wrote bool +} + +// WriteHeader marks the response as started and delegates to the underlying writer. +func (rw *recoveryResponseWriter) WriteHeader(statusCode int) { + rw.wrote = true + rw.baseResponseWriter.WriteHeader(statusCode) +} + +// Write marks the response as started and delegates to the underlying writer. +func (rw *recoveryResponseWriter) Write(b []byte) (int, error) { + rw.wrote = true + return rw.baseResponseWriter.Write(b) +} + // authenticateRequest attempts to authenticate the request and, if successful, // returns a new request with user information stored in the context. // It does not perform any logging; callers handle logging based on the result. @@ -1408,14 +1527,19 @@ func (r *Router[T, U]) authenticateRequest(req *http.Request, extractToken authT if user, valid := r.authFunction(req.Context(), token); valid { id := r.getUserIdFromUser(user) + // When an SRouterContext already exists (always the case for requests + // routed through ServeHTTP, which installs it before dispatch), the + // With* helpers mutate it in place and return the same context. Any + // trace ID already on that shared context is preserved automatically, + // and cloning the request is only needed when a context was created. + _, hadSRouterCtx := scontext.GetSRouterContext[T, U](req.Context()) ctx := scontext.WithUserID[T, U](req.Context(), id) if r.config.AddUserObjectToCtx { ctx = scontext.WithUser[T](ctx, user) } - if traceID := scontext.GetTraceIDFromRequest[T, U](req); traceID != "" { - ctx = scontext.WithTraceID[T, U](ctx, traceID) + if !hadSRouterCtx { + req = req.WithContext(ctx) } - req = req.WithContext(ctx) return req, true, "" } return req, false, "invalid token" @@ -1530,6 +1654,11 @@ func (rw *mutexResponseWriter) Write(b []byte) (int, error) { } rw.mu.Lock() defer rw.mu.Unlock() + // Re-check under the lock: the timeout response may have been written + // while this write was waiting for the mutex. + if rw.timedOut.Load() { + return 0, http.ErrHandlerTimeout + } rw.wroteHeader.Store(true) // Mark as written (headers might be implicitly written here) return rw.ResponseWriter.Write(b) } diff --git a/pkg/router/timeout_middleware_race_test.go b/pkg/router/timeout_middleware_race_test.go index b8ffa20..b53732c 100644 --- a/pkg/router/timeout_middleware_race_test.go +++ b/pkg/router/timeout_middleware_race_test.go @@ -102,8 +102,10 @@ func TestTimeoutMiddleware_WhenHandlerPanicsInCASFailurePath_RethrowsToRecovery( mrw := <-mrwCh if rr.Code == http.StatusAccepted && mrw.timedOut.Load() { - if msg := parseJSONErrorMessage(t, rr.Body.Bytes()); msg != "Internal Server Error" { - t.Fatalf("expected internal server error payload, got %q", msg) + // The handler won the write race (202 was sent) before panicking, + // so recovery must not append a JSON error to the started response. + if body := rr.Body.String(); body != "" { + t.Fatalf("expected no additional body after mid-response panic, got %q", body) } return } diff --git a/pkg/router/timeout_middleware_test.go b/pkg/router/timeout_middleware_test.go index cfca017..e48215d 100644 --- a/pkg/router/timeout_middleware_test.go +++ b/pkg/router/timeout_middleware_test.go @@ -2,7 +2,6 @@ package router import ( "context" - "encoding/json" "net/http" "net/http/httptest" "testing" @@ -12,20 +11,6 @@ import ( "go.uber.org/zap" ) -func parseJSONErrorMessage(t *testing.T, body []byte) string { - t.Helper() - - var payload struct { - Error struct { - Message string `json:"message"` - } `json:"error"` - } - if err := json.Unmarshal(body, &payload); err != nil { - t.Fatalf("expected JSON error payload, got %q: %v", string(body), err) - } - return payload.Error.Message -} - func TestTimeoutMiddleware_WhenHandlerStartedWriting_DoesNotOverrideResponse(t *testing.T) { r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) @@ -95,7 +80,9 @@ func TestTimeoutMiddleware_WhenHandlerPanicsAfterTimeoutAndStartedWrite_Rethrows if rr.Code != http.StatusTeapot { t.Fatalf("expected status %d, got %d", http.StatusTeapot, rr.Code) } - if msg := parseJSONErrorMessage(t, rr.Body.Bytes()); msg != "Internal Server Error" { - t.Fatalf("expected internal server error payload, got %q", msg) + // The handler already started writing before it panicked, so recovery must + // not append a second (JSON error) response onto the partial one. + if body := rr.Body.String(); body != "" { + t.Fatalf("expected no additional body after mid-response panic, got %q", body) } } diff --git a/pkg/scontext/context.go b/pkg/scontext/context.go index b06e8fc..452c9a4 100644 --- a/pkg/scontext/context.go +++ b/pkg/scontext/context.go @@ -9,6 +9,7 @@ import ( "context" "maps" "net/http" + "sync" "github.com/julienschmidt/httprouter" // Import for Params type "gorm.io/gorm" // Needed for DatabaseTransaction @@ -33,7 +34,17 @@ type DatabaseTransaction interface { // It provides a centralized storage for all request-scoped data, avoiding // the need for multiple context.WithValue calls and deep context nesting. // T is the User ID type (comparable), U is the User object type (any). +// +// The struct is shared by pointer across the request's middleware chain, and +// a timed-out request's handler goroutine may still be mutating it while the +// router goroutine reads it. All access through this package's With*/Get* +// helpers is therefore synchronized by an internal lock; prefer the helpers +// over touching fields directly. type SRouterContext[T comparable, U any] struct { + // mu guards all fields below. The With*/Get* helper functions in this + // package take it automatically. + mu sync.RWMutex + UserID T User *U @@ -74,13 +85,12 @@ type SRouterContext[T comparable, U any] struct { Flags map[string]bool } -// NewSRouterContext creates a new SRouterContext instance with initialized fields. -// It returns a pointer to a new context with an empty Flags map ready for use. +// NewSRouterContext creates a new SRouterContext instance. +// The Flags map is allocated lazily by WithFlag on first use, so contexts on +// requests that never set a flag (the common case) avoid the map allocation. // T is the User ID type (comparable), U is the User object type (any). func NewSRouterContext[T comparable, U any]() *SRouterContext[T, U] { - return &SRouterContext[T, U]{ - Flags: make(map[string]bool), - } + return &SRouterContext[T, U]{} } // GetSRouterContext retrieves the SRouterContext from a standard context.Context. @@ -118,8 +128,10 @@ func EnsureSRouterContext[T comparable, U any](ctx context.Context) (*SRouterCon // T is the User ID type (comparable), U is the User object type (any). func WithUserID[T comparable, U any](ctx context.Context, userID T) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.UserID = userID rc.UserIDSet = true + rc.mu.Unlock() return ctx } @@ -130,7 +142,12 @@ func WithUserID[T comparable, U any](ctx context.Context, userID T) context.Cont func GetUserID[T comparable, U any](ctx context.Context) (T, bool) { var zero T rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.UserIDSet { + if !ok { + return zero, false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.UserIDSet { return zero, false } return rc.UserID, true @@ -148,8 +165,10 @@ func GetUserIDFromRequest[T comparable, U any](r *http.Request) (T, bool) { // T is the User ID type (comparable), U is the User object type (any). func WithUser[T comparable, U any](ctx context.Context, user *U) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.User = user rc.UserSet = true + rc.mu.Unlock() return ctx } @@ -159,7 +178,12 @@ func WithUser[T comparable, U any](ctx context.Context, user *U) context.Context // T is the User ID type (comparable), U is the User object type (any). func GetUser[T comparable, U any](ctx context.Context) (*U, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.UserSet { + if !ok { + return nil, false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.UserSet { return nil, false } return rc.User, true @@ -178,10 +202,12 @@ func GetUserFromRequest[T comparable, U any](r *http.Request) (*U, bool) { // T is the User ID type (comparable), U is the User object type (any). func WithFlag[T comparable, U any](ctx context.Context, name string, value bool) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() if rc.Flags == nil { rc.Flags = make(map[string]bool) } rc.Flags[name] = value + rc.mu.Unlock() return ctx } @@ -191,7 +217,12 @@ func WithFlag[T comparable, U any](ctx context.Context, name string, value bool) // T is the User ID type (comparable), U is the User object type (any). func GetFlag[T comparable, U any](ctx context.Context, name string) (bool, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || rc.Flags == nil { + if !ok { + return false, false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if rc.Flags == nil { return false, false } value, exists := rc.Flags[name] @@ -211,8 +242,10 @@ func GetFlagFromRequest[T comparable, U any](r *http.Request, name string) (bool // T is the User ID type (comparable), U is the User object type (any). func WithClientIP[T comparable, U any](ctx context.Context, ip string) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.ClientIP = ip rc.ClientIPSet = true + rc.mu.Unlock() return ctx } @@ -222,7 +255,12 @@ func WithClientIP[T comparable, U any](ctx context.Context, ip string) context.C // T is the User ID type (comparable), U is the User object type (any). func GetClientIP[T comparable, U any](ctx context.Context) (string, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.ClientIPSet { + if !ok { + return "", false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.ClientIPSet { return "", false } return rc.ClientIP, true @@ -240,8 +278,10 @@ func GetClientIPFromRequest[T comparable, U any](r *http.Request) (string, bool) // T is the User ID type (comparable), U is the User object type (any). func WithUserAgent[T comparable, U any](ctx context.Context, ua string) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.UserAgent = ua rc.UserAgentSet = true + rc.mu.Unlock() return ctx } @@ -251,7 +291,12 @@ func WithUserAgent[T comparable, U any](ctx context.Context, ua string) context. // T is the User ID type (comparable), U is the User object type (any). func GetUserAgent[T comparable, U any](ctx context.Context) (string, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.UserAgentSet { + if !ok { + return "", false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.UserAgentSet { return "", false } return rc.UserAgent, true @@ -271,8 +316,10 @@ func GetUserAgentFromRequest[T comparable, U any](r *http.Request) (string, bool // T is the User ID type (comparable), U is the User object type (any). func WithTransaction[T comparable, U any](ctx context.Context, tx DatabaseTransaction) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.Transaction = tx rc.TransactionSet = true + rc.mu.Unlock() return ctx } @@ -282,7 +329,12 @@ func WithTransaction[T comparable, U any](ctx context.Context, tx DatabaseTransa // T is the User ID type (comparable), U is the User object type (any). func GetTransaction[T comparable, U any](ctx context.Context) (DatabaseTransaction, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.TransactionSet { + if !ok { + return nil, false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.TransactionSet { return nil, false } return rc.Transaction, true @@ -302,6 +354,8 @@ func GetTransactionFromRequest[T comparable, U any](r *http.Request) (DatabaseTr // T is the User ID type (comparable), U is the User object type (any). func WithTraceID[T comparable, U any](ctx context.Context, traceID string) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() + defer rc.mu.Unlock() // If TraceID is already set, do not overwrite it. if rc.TraceIDSet { return ctx @@ -318,7 +372,12 @@ func WithTraceID[T comparable, U any](ctx context.Context, traceID string) conte // T is the User ID type (comparable), U is the User object type (any). func GetTraceIDFromContext[T comparable, U any](ctx context.Context) string { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.TraceIDSet { + if !ok { + return "" + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.TraceIDSet { return "" } return rc.TraceID @@ -339,9 +398,11 @@ func GetTraceIDFromRequest[T comparable, U any](r *http.Request) string { // T is the User ID type (comparable), U is the User object type (any). func WithRouteInfo[T comparable, U any](ctx context.Context, params httprouter.Params, routeTemplate string) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.PathParams = params rc.RouteTemplate = routeTemplate rc.RouteTemplateSet = true + rc.mu.Unlock() return ctx } @@ -352,7 +413,12 @@ func WithRouteInfo[T comparable, U any](ctx context.Context, params httprouter.P // T is the User ID type (comparable), U is the User object type (any). func GetRouteTemplateFromContext[T comparable, U any](ctx context.Context) (string, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.RouteTemplateSet { + if !ok { + return "", false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.RouteTemplateSet { return "", false } return rc.RouteTemplate, true @@ -371,7 +437,12 @@ func GetRouteTemplateFromRequest[T comparable, U any](r *http.Request) (string, // T is the User ID type (comparable), U is the User object type (any). func GetPathParamsFromContext[T comparable, U any](ctx context.Context) (httprouter.Params, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.RouteTemplateSet { // Use RouteTemplateSet as indicator that params are also set + if !ok { + return nil, false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.RouteTemplateSet { // Use RouteTemplateSet as indicator that params are also set return nil, false } return rc.PathParams, true @@ -391,10 +462,12 @@ func GetPathParamsFromRequest[T comparable, U any](r *http.Request) (httprouter. // T is the User ID type (comparable), U is the User object type (any). func WithCORSInfo[T comparable, U any](ctx context.Context, allowedOrigin string, credentialsAllowed bool) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.AllowedOrigin = allowedOrigin rc.CredentialsAllowed = credentialsAllowed rc.AllowedOriginSet = true rc.CredentialsAllowedSet = true // Set both flags when info is added + rc.mu.Unlock() return ctx } @@ -406,7 +479,12 @@ func WithCORSInfo[T comparable, U any](ctx context.Context, allowedOrigin string // T is the User ID type (comparable), U is the User object type (any). func GetCORSInfo[T comparable, U any](ctx context.Context) (allowedOrigin string, credentialsAllowed bool, ok bool) { rc, found := GetSRouterContext[T, U](ctx) - if !found || !rc.AllowedOriginSet { // Check if origin was set as the primary indicator + if !found { + return "", false, false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.AllowedOriginSet { // Check if origin was set as the primary indicator return "", false, false } // Return the stored values. CredentialsAllowedSet is implicitly true if AllowedOriginSet is true based on WithCORSInfo logic. @@ -426,8 +504,10 @@ func GetCORSInfoFromRequest[T comparable, U any](r *http.Request) (allowedOrigin // T is the User ID type (comparable), U is the User object type (any). func WithCORSRequestedHeaders[T comparable, U any](ctx context.Context, requestedHeaders string) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.RequestedHeaders = requestedHeaders rc.RequestedHeadersSet = true + rc.mu.Unlock() return ctx } @@ -438,7 +518,12 @@ func WithCORSRequestedHeaders[T comparable, U any](ctx context.Context, requeste // T is the User ID type (comparable), U is the User object type (any). func GetCORSRequestedHeaders[T comparable, U any](ctx context.Context) (string, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.RequestedHeadersSet { + if !ok { + return "", false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.RequestedHeadersSet { return "", false } return rc.RequestedHeaders, true @@ -456,8 +541,10 @@ func GetCORSRequestedHeadersFromRequest[T comparable, U any](r *http.Request) (s // T is the User ID type (comparable), U is the User object type (any). func WithHandlerError[T comparable, U any](ctx context.Context, err error) context.Context { rc, ctx := EnsureSRouterContext[T, U](ctx) + rc.mu.Lock() rc.HandlerError = err rc.HandlerErrorSet = true + rc.mu.Unlock() return ctx } @@ -467,7 +554,12 @@ func WithHandlerError[T comparable, U any](ctx context.Context, err error) conte // T is the User ID type (comparable), U is the User object type (any). func GetHandlerError[T comparable, U any](ctx context.Context) (error, bool) { rc, ok := GetSRouterContext[T, U](ctx) - if !ok || !rc.HandlerErrorSet { + if !ok { + return nil, false + } + rc.mu.RLock() + defer rc.mu.RUnlock() + if !rc.HandlerErrorSet { return nil, false } return rc.HandlerError, true @@ -501,6 +593,8 @@ func GetHandlerErrorFromRequest[T comparable, U any](r *http.Request) (error, bo // This is an internal helper function used by the various copy functions. // T is the User ID type (comparable), U is the User object type (any). func cloneSRouterContext[T comparable, U any](src *SRouterContext[T, U]) *SRouterContext[T, U] { + src.mu.RLock() + defer src.mu.RUnlock() dst := &SRouterContext[T, U]{ UserID: src.UserID, User: src.User, diff --git a/pkg/scontext/missing_context_test.go b/pkg/scontext/missing_context_test.go new file mode 100644 index 0000000..0c19c34 --- /dev/null +++ b/pkg/scontext/missing_context_test.go @@ -0,0 +1,72 @@ +package scontext + +import ( + "context" + "testing" +) + +// TestGettersWithoutSRouterContext verifies that every getter degrades +// gracefully when called on a context that has no SRouterContext — e.g. a +// helper invoked outside the router's middleware chain. Each getter must +// return its zero value and report not-found instead of panicking or +// returning stale data. +func TestGettersWithoutSRouterContext(t *testing.T) { + ctx := context.Background() + + if id, ok := GetUserID[string, any](ctx); ok || id != "" { + t.Errorf("GetUserID = (%q, %v), want (\"\", false)", id, ok) + } + if user, ok := GetUser[string, any](ctx); ok || user != nil { + t.Errorf("GetUser = (%v, %v), want (nil, false)", user, ok) + } + if value, exists := GetFlag[string, any](ctx, "feature"); exists || value { + t.Errorf("GetFlag = (%v, %v), want (false, false)", value, exists) + } + if ip, ok := GetClientIP[string, any](ctx); ok || ip != "" { + t.Errorf("GetClientIP = (%q, %v), want (\"\", false)", ip, ok) + } + if ua, ok := GetUserAgent[string, any](ctx); ok || ua != "" { + t.Errorf("GetUserAgent = (%q, %v), want (\"\", false)", ua, ok) + } + if tx, ok := GetTransaction[string, any](ctx); ok || tx != nil { + t.Errorf("GetTransaction = (%v, %v), want (nil, false)", tx, ok) + } + if traceID := GetTraceIDFromContext[string, any](ctx); traceID != "" { + t.Errorf("GetTraceIDFromContext = %q, want \"\"", traceID) + } + if tmpl, ok := GetRouteTemplateFromContext[string, any](ctx); ok || tmpl != "" { + t.Errorf("GetRouteTemplateFromContext = (%q, %v), want (\"\", false)", tmpl, ok) + } + if params, ok := GetPathParamsFromContext[string, any](ctx); ok || params != nil { + t.Errorf("GetPathParamsFromContext = (%v, %v), want (nil, false)", params, ok) + } + if origin, creds, ok := GetCORSInfo[string, any](ctx); ok || origin != "" || creds { + t.Errorf("GetCORSInfo = (%q, %v, %v), want (\"\", false, false)", origin, creds, ok) + } + if hdrs, ok := GetCORSRequestedHeaders[string, any](ctx); ok || hdrs != "" { + t.Errorf("GetCORSRequestedHeaders = (%q, %v), want (\"\", false)", hdrs, ok) + } + if err, ok := GetHandlerError[string, any](ctx); ok || err != nil { + t.Errorf("GetHandlerError = (%v, %v), want (nil, false)", err, ok) + } +} + +// TestGettersWithSRouterContextButUnsetValues verifies that getters report +// not-found when an SRouterContext exists but the specific value was never set, +// so callers can distinguish "never set" from a set zero value. +func TestGettersWithSRouterContextButUnsetValues(t *testing.T) { + ctx := WithSRouterContext(context.Background(), NewSRouterContext[string, any]()) + + if id, ok := GetUserID[string, any](ctx); ok || id != "" { + t.Errorf("GetUserID = (%q, %v), want (\"\", false)", id, ok) + } + if value, exists := GetFlag[string, any](ctx, "feature"); exists || value { + t.Errorf("GetFlag = (%v, %v), want (false, false)", value, exists) + } + if traceID := GetTraceIDFromContext[string, any](ctx); traceID != "" { + t.Errorf("GetTraceIDFromContext = %q, want \"\"", traceID) + } + if origin, creds, ok := GetCORSInfo[string, any](ctx); ok || origin != "" || creds { + t.Errorf("GetCORSInfo = (%q, %v, %v), want (\"\", false, false)", origin, creds, ok) + } +}