diff --git a/autorouter.go b/autorouter.go index 96cb45a..b17395a 100644 --- a/autorouter.go +++ b/autorouter.go @@ -961,6 +961,11 @@ func normalizeProviderRequest(raw map[string]any, providerName string) { return } + if providerName == "anthropic" { + normalizeAnthropicRequest(raw) + return + } + if providerName != "deepseek" { return } @@ -993,6 +998,49 @@ func normalizeProviderRequest(raw map[string]any, providerName string) { } } +const defaultAnthropicMaxTokens = 1024 + +func normalizeAnthropicRequest(raw map[string]any) { + if hasPositiveNumber(raw["max_tokens"]) { + return + } + raw["max_tokens"] = defaultAnthropicMaxTokens +} + +func hasPositiveNumber(value any) bool { + switch v := value.(type) { + case int: + return v > 0 + case int8: + return v > 0 + case int16: + return v > 0 + case int32: + return v > 0 + case int64: + return v > 0 + case uint: + return v > 0 + case uint8: + return v > 0 + case uint16: + return v > 0 + case uint32: + return v > 0 + case uint64: + return v > 0 + case float32: + return v > 0 + case float64: + return v > 0 + case json.Number: + f, err := v.Float64() + return err == nil && f > 0 + default: + return false + } +} + func normalizeGoogleAIRequest(raw map[string]any) { delete(raw, "stream") diff --git a/autorouter_test.go b/autorouter_test.go index a4f6de4..f32b668 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -1411,6 +1411,101 @@ func TestAutoRouter_AnthropicStreamingNoStreamOptions(t *testing.T) { } } +func TestAutoRouter_AnthropicDefaultMaxTokens(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "anthropic", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + var raw map[string]any + _ = json.Unmarshal(data, &raw) + maxTokens, _ := raw["max_tokens"].(float64) + return BodyMetadata{Model: "claude-3-opus", MaxTokens: int(maxTokens)}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "msg_test"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", w.Code) + } + if got := receivedBody["max_tokens"]; got != float64(defaultAnthropicMaxTokens) { + t.Fatalf("max_tokens = %v, want %d", got, defaultAnthropicMaxTokens) + } +} + +func TestAutoRouter_AnthropicPreservesMaxTokens(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "anthropic", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "claude-3-opus", MaxTokens: 64}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "msg_test"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","max_tokens":64,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", w.Code) + } + if got := receivedBody["max_tokens"]; got != float64(64) { + t.Fatalf("max_tokens = %v, want 64", got) + } +} + func TestAutoRouter_StreamingWritesGatewayMetadataEvent(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream")