Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions autorouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"mime"
"mime/multipart"
"net"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -624,7 +625,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if isWebSocketUpgrade(r) && a.wsUpgrader != nil && a.wsDialer != nil {
if err := a.ForwardWebSocket(r.Context(), w, r); err != nil {
if !headerSent(w) {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, err.Error(), statusCodeForForwardError(err))
}
}
return
Expand Down Expand Up @@ -655,7 +656,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_, err := a.ForwardStreaming(r.Context(), r, w)
if err != nil {
if !headerSent(w) {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, err.Error(), statusCodeForForwardError(err))
}
return
}
Expand All @@ -665,7 +666,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {

resp, meta, err := a.Forward(r.Context(), r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, err.Error(), statusCodeForForwardError(err))
return
}
defer resp.Body.Close()
Expand Down Expand Up @@ -709,6 +710,27 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

func statusCodeForForwardError(err error) int {
if isForwardTimeoutError(err) {
return http.StatusGatewayTimeout
}
return http.StatusInternalServerError
}

func isForwardTimeoutError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
return strings.Contains(strings.ToLower(err.Error()), "timeout awaiting response headers")
}

func isWebSocketUpgrade(r *http.Request) bool {
connection := strings.ToLower(r.Header.Get("Connection"))
upgrade := strings.ToLower(r.Header.Get("Upgrade"))
Expand Down Expand Up @@ -966,6 +988,11 @@ func normalizeProviderRequest(raw map[string]any, providerName string) {
return
}

if providerName == "openai" {
normalizeOpenAIRequest(raw)
return
}

if providerName != "deepseek" {
return
}
Expand Down Expand Up @@ -1117,6 +1144,36 @@ func hasPositiveNumber(value any) bool {
}
}

func normalizeOpenAIRequest(raw map[string]any) {
if !openAIModelUsesMaxCompletionTokens(raw["model"]) {
return
}

maxTokens, ok := raw["max_tokens"]
if !ok {
return
}
if _, exists := raw["max_completion_tokens"]; !exists {
raw["max_completion_tokens"] = maxTokens
}
delete(raw, "max_tokens")
}

func openAIModelUsesMaxCompletionTokens(model any) bool {
modelName, ok := model.(string)
if !ok {
return false
}
modelName = strings.ToLower(strings.TrimSpace(modelName))
if stripped, hasPrefix := stripProviderPrefix(modelName); hasPrefix {
modelName = stripped
}
return strings.HasPrefix(modelName, "gpt-5") ||
strings.HasPrefix(modelName, "o1") ||
strings.HasPrefix(modelName, "o3") ||
strings.HasPrefix(modelName, "o4")
}

func normalizeGoogleAIRequest(raw map[string]any) {
delete(raw, "stream")

Expand Down
221 changes: 221 additions & 0 deletions autorouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,80 @@ func TestAutoRouter_ServeHTTP(t *testing.T) {
}
}

func TestAutoRouter_ServeHTTPMapsUpstreamTimeoutToGatewayTimeout(t *testing.T) {
provider := &mockProvider{
name: "test",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "gpt-5"}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse("https://api.openai.com")
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
return ResponseMetadata{}, nil, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(ProviderHint) string { return "test" })),
WithAutoRouterHTTPClient(&http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("Post https://api.openai.com/v1/chat/completions: net/http: timeout awaiting response headers")
})}),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-5","messages":[{"role":"user","content":"hi"}]}`)))
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusGatewayTimeout {
t.Fatalf("StatusCode = %d, want 504", w.Code)
}
}

func TestAutoRouter_ServeHTTPMapsStreamingUpstreamTimeoutToGatewayTimeout(t *testing.T) {
provider := &mockProvider{
name: "test",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "gpt-5", Stream: true}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse("https://api.openai.com")
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
return ResponseMetadata{}, nil, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(ProviderHint) string { return "test" })),
WithAutoRouterHTTPClient(&http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, context.DeadlineExceeded
})}),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-5","stream":true,"messages":[{"role":"user","content":"hi"}]}`)))
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusGatewayTimeout {
t.Fatalf("StatusCode = %d, want 504", w.Code)
}
}

type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func ParseURL(s string) (*url.URL, error) {
return url.Parse(s)
}
Expand Down Expand Up @@ -1736,6 +1810,153 @@ func TestAutoRouter_AnthropicRemovesSystemMessageWithMissingContent(t *testing.T
}
}

func TestAutoRouter_OpenAIGPT5UsesMaxCompletionTokens(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &receivedBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"chatcmpl_test","object":"chat.completion","model":"gpt-5","choices":[{"message":{"role":"assistant","content":"Hello"}}],"usage":{"prompt_tokens":8,"completion_tokens":1,"total_tokens":9}}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "openai",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "gpt-5", MaxTokens: 64}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "chatcmpl_test"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "openai" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"gpt-5","max_tokens":64,"messages":[{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", w.Code)
}
if _, exists := receivedBody["max_tokens"]; exists {
t.Fatalf("max_tokens should be removed for gpt-5: %#v", receivedBody)
}
if got := receivedBody["max_completion_tokens"]; got != float64(64) {
t.Fatalf("max_completion_tokens = %v, want 64", got)
}
}

func TestAutoRouter_OpenAILegacyModelPreservesMaxTokens(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &receivedBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"chatcmpl_test","object":"chat.completion","model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"Hello"}}],"usage":{"prompt_tokens":8,"completion_tokens":1,"total_tokens":9}}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "openai",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "gpt-4o", MaxTokens: 64}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "chatcmpl_test"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "openai" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4o","max_tokens":64,"messages":[{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", w.Code)
}
if got := receivedBody["max_tokens"]; got != float64(64) {
t.Fatalf("max_tokens = %v, want 64", got)
}
if _, exists := receivedBody["max_completion_tokens"]; exists {
t.Fatalf("max_completion_tokens should not be injected for gpt-4o: %#v", receivedBody)
}
}

func TestAutoRouter_OpenAIPreservesExplicitMaxCompletionTokens(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &receivedBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"chatcmpl_test","object":"chat.completion","model":"o4-mini","choices":[{"message":{"role":"assistant","content":"Hello"}}],"usage":{"prompt_tokens":8,"completion_tokens":1,"total_tokens":9}}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "openai",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "o4-mini", MaxTokens: 64}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "chatcmpl_test"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "openai" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"o4-mini","max_tokens":64,"max_completion_tokens":32,"messages":[{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", w.Code)
}
if _, exists := receivedBody["max_tokens"]; exists {
t.Fatalf("max_tokens should be removed for o4-mini: %#v", receivedBody)
}
if got := receivedBody["max_completion_tokens"]; got != float64(32) {
t.Fatalf("max_completion_tokens = %v, want 32", got)
}
}

func TestAutoRouter_StreamingWritesGatewayMetadataEvent(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand Down
15 changes: 14 additions & 1 deletion interceptors/promptcaching.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (i *PromptCachingInterceptor) Intercept(req *http.Request, meta llmproxy.Bo
case "anthropic":
shouldApply = strings.Contains(modelLower, "claude")
case "openai":
shouldApply = isOpenAIModel(modelLower)
shouldApply = isOpenAICacheProvider(meta) && isOpenAIModel(modelLower)
case "xai":
shouldApply = isXAIModel(modelLower)
case "fireworks":
Expand Down Expand Up @@ -648,6 +648,19 @@ func isOpenAIModel(modelLower string) bool {
strings.Contains(modelLower, "chatgpt")
}

func isOpenAICacheProvider(meta llmproxy.BodyMetadata) bool {
if meta.Custom == nil {
return true
}
provider, _ := meta.Custom["provider"].(string)
switch provider {
case "", "openai", "azure":
return true
default:
return false
}
}

func isXAIModel(modelLower string) bool {
return strings.Contains(modelLower, "grok")
}
Expand Down
Loading
Loading