From a580f9d5b3557a8ab97c4cab21c320b620ae9b18 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 25 May 2026 19:41:55 -0500 Subject: [PATCH 1/3] Skip OpenAI cache keys for Groq OSS models --- interceptors/promptcaching.go | 15 +++++- interceptors/promptcaching_test.go | 74 ++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/interceptors/promptcaching.go b/interceptors/promptcaching.go index 3b7be27..bd4c49b 100644 --- a/interceptors/promptcaching.go +++ b/interceptors/promptcaching.go @@ -66,7 +66,7 @@ func (i *PromptCachingInterceptor) Intercept(req *http.Request, meta llmproxy.Bo case "anthropic": shouldApply = strings.Contains(modelLower, "claude") case "openai": - shouldApply = isOpenAIModel(modelLower) + shouldApply = isOpenAICacheProvider(meta) && isOpenAIModel(modelLower) case "xai": shouldApply = isXAIModel(modelLower) case "fireworks": @@ -648,6 +648,19 @@ func isOpenAIModel(modelLower string) bool { strings.Contains(modelLower, "chatgpt") } +func isOpenAICacheProvider(meta llmproxy.BodyMetadata) bool { + if meta.Custom == nil { + return true + } + provider, _ := meta.Custom["provider"].(string) + switch provider { + case "", "openai", "azure": + return true + default: + return false + } +} + func isXAIModel(modelLower string) bool { return strings.Contains(modelLower, "grok") } diff --git a/interceptors/promptcaching_test.go b/interceptors/promptcaching_test.go index b674de8..e26a93b 100644 --- a/interceptors/promptcaching_test.go +++ b/interceptors/promptcaching_test.go @@ -368,6 +368,80 @@ func TestPromptCachingInterceptor_SkipsNonOpenAI(t *testing.T) { } } +func TestPromptCachingInterceptor_SkipsGroqOpenAIOSSModel(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"prompt_cache_key"`)) { + t.Error("Request body should NOT contain prompt_cache_key for Groq-routed OpenAI OSS model") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "test-key") + + reqBody := []byte(`{"model":"openai/gpt-oss-20b","messages":[]}`) + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader(reqBody)) + meta := llmproxy.BodyMetadata{ + Model: "openai/gpt-oss-20b", + Custom: map[string]any{ + "provider": "groq", + }, + } + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, reqBody, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_AllowsAzureOpenAIModel(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if !bytes.Contains(body, []byte(`"prompt_cache_key"`)) { + t.Error("Request body should contain prompt_cache_key for Azure OpenAI model") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "test-key") + + reqBody := []byte(`{"model":"gpt-4","messages":[]}`) + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader(reqBody)) + meta := llmproxy.BodyMetadata{ + Model: "gpt-4", + Custom: map[string]any{ + "provider": "azure", + }, + } + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, reqBody, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + func TestPromptCachingInterceptor_AnthropicExistingCacheControl(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) From 24079cdd395ae0e7d0674136591b6a5b08d14757 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 25 May 2026 19:44:30 -0500 Subject: [PATCH 2/3] Map upstream timeouts to gateway timeout --- autorouter.go | 28 ++++++++++++++++-- autorouter_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/autorouter.go b/autorouter.go index afa08e7..6ea5262 100644 --- a/autorouter.go +++ b/autorouter.go @@ -9,6 +9,7 @@ import ( "io" "mime" "mime/multipart" + "net" "net/http" "net/url" "strings" @@ -624,7 +625,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { if isWebSocketUpgrade(r) && a.wsUpgrader != nil && a.wsDialer != nil { if err := a.ForwardWebSocket(r.Context(), w, r); err != nil { if !headerSent(w) { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), statusCodeForForwardError(err)) } } return @@ -655,7 +656,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { _, err := a.ForwardStreaming(r.Context(), r, w) if err != nil { if !headerSent(w) { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), statusCodeForForwardError(err)) } return } @@ -665,7 +666,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { resp, meta, err := a.Forward(r.Context(), r) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), statusCodeForForwardError(err)) return } defer resp.Body.Close() @@ -709,6 +710,27 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func statusCodeForForwardError(err error) int { + if isForwardTimeoutError(err) { + return http.StatusGatewayTimeout + } + return http.StatusInternalServerError +} + +func isForwardTimeoutError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + return strings.Contains(strings.ToLower(err.Error()), "timeout awaiting response headers") +} + func isWebSocketUpgrade(r *http.Request) bool { connection := strings.ToLower(r.Header.Get("Connection")) upgrade := strings.ToLower(r.Header.Get("Upgrade")) diff --git a/autorouter_test.go b/autorouter_test.go index 2645986..a1fe379 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -878,6 +878,80 @@ func TestAutoRouter_ServeHTTP(t *testing.T) { } } +func TestAutoRouter_ServeHTTPMapsUpstreamTimeoutToGatewayTimeout(t *testing.T) { + provider := &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-5"}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse("https://api.openai.com") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + return ResponseMetadata{}, nil, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(ProviderHint) string { return "test" })), + WithAutoRouterHTTPClient(&http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("Post https://api.openai.com/v1/chat/completions: net/http: timeout awaiting response headers") + })}), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-5","messages":[{"role":"user","content":"hi"}]}`))) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusGatewayTimeout { + t.Fatalf("StatusCode = %d, want 504", w.Code) + } +} + +func TestAutoRouter_ServeHTTPMapsStreamingUpstreamTimeoutToGatewayTimeout(t *testing.T) { + provider := &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-5", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse("https://api.openai.com") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + return ResponseMetadata{}, nil, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(ProviderHint) string { return "test" })), + WithAutoRouterHTTPClient(&http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, context.DeadlineExceeded + })}), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-5","stream":true,"messages":[{"role":"user","content":"hi"}]}`))) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusGatewayTimeout { + t.Fatalf("StatusCode = %d, want 504", w.Code) + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + func ParseURL(s string) (*url.URL, error) { return url.Parse(s) } From 33ed3bd94cb7970804b1698c829a0e56d0c99093 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 25 May 2026 19:56:45 -0500 Subject: [PATCH 3/3] Normalize OpenAI max token parameters --- autorouter.go | 35 +++++++++++ autorouter_test.go | 147 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+) diff --git a/autorouter.go b/autorouter.go index 6ea5262..c528ac2 100644 --- a/autorouter.go +++ b/autorouter.go @@ -988,6 +988,11 @@ func normalizeProviderRequest(raw map[string]any, providerName string) { return } + if providerName == "openai" { + normalizeOpenAIRequest(raw) + return + } + if providerName != "deepseek" { return } @@ -1139,6 +1144,36 @@ func hasPositiveNumber(value any) bool { } } +func normalizeOpenAIRequest(raw map[string]any) { + if !openAIModelUsesMaxCompletionTokens(raw["model"]) { + return + } + + maxTokens, ok := raw["max_tokens"] + if !ok { + return + } + if _, exists := raw["max_completion_tokens"]; !exists { + raw["max_completion_tokens"] = maxTokens + } + delete(raw, "max_tokens") +} + +func openAIModelUsesMaxCompletionTokens(model any) bool { + modelName, ok := model.(string) + if !ok { + return false + } + modelName = strings.ToLower(strings.TrimSpace(modelName)) + if stripped, hasPrefix := stripProviderPrefix(modelName); hasPrefix { + modelName = stripped + } + return strings.HasPrefix(modelName, "gpt-5") || + strings.HasPrefix(modelName, "o1") || + strings.HasPrefix(modelName, "o3") || + strings.HasPrefix(modelName, "o4") +} + func normalizeGoogleAIRequest(raw map[string]any) { delete(raw, "stream") diff --git a/autorouter_test.go b/autorouter_test.go index a1fe379..5045556 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -1810,6 +1810,153 @@ func TestAutoRouter_AnthropicRemovesSystemMessageWithMissingContent(t *testing.T } } +func TestAutoRouter_OpenAIGPT5UsesMaxCompletionTokens(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"chatcmpl_test","object":"chat.completion","model":"gpt-5","choices":[{"message":{"role":"assistant","content":"Hello"}}],"usage":{"prompt_tokens":8,"completion_tokens":1,"total_tokens":9}}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-5", MaxTokens: 64}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "chatcmpl_test"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "openai" })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"gpt-5","max_tokens":64,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", w.Code) + } + if _, exists := receivedBody["max_tokens"]; exists { + t.Fatalf("max_tokens should be removed for gpt-5: %#v", receivedBody) + } + if got := receivedBody["max_completion_tokens"]; got != float64(64) { + t.Fatalf("max_completion_tokens = %v, want 64", got) + } +} + +func TestAutoRouter_OpenAILegacyModelPreservesMaxTokens(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"chatcmpl_test","object":"chat.completion","model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"Hello"}}],"usage":{"prompt_tokens":8,"completion_tokens":1,"total_tokens":9}}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4o", MaxTokens: 64}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "chatcmpl_test"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "openai" })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4o","max_tokens":64,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", w.Code) + } + if got := receivedBody["max_tokens"]; got != float64(64) { + t.Fatalf("max_tokens = %v, want 64", got) + } + if _, exists := receivedBody["max_completion_tokens"]; exists { + t.Fatalf("max_completion_tokens should not be injected for gpt-4o: %#v", receivedBody) + } +} + +func TestAutoRouter_OpenAIPreservesExplicitMaxCompletionTokens(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"chatcmpl_test","object":"chat.completion","model":"o4-mini","choices":[{"message":{"role":"assistant","content":"Hello"}}],"usage":{"prompt_tokens":8,"completion_tokens":1,"total_tokens":9}}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "o4-mini", MaxTokens: 64}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "chatcmpl_test"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "openai" })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"o4-mini","max_tokens":64,"max_completion_tokens":32,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", w.Code) + } + if _, exists := receivedBody["max_tokens"]; exists { + t.Fatalf("max_tokens should be removed for o4-mini: %#v", receivedBody) + } + if got := receivedBody["max_completion_tokens"]; got != float64(32) { + t.Fatalf("max_completion_tokens = %v, want 32", got) + } +} + func TestAutoRouter_StreamingWritesGatewayMetadataEvent(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream")