From 9399a1dd80cad7646898bee82dbfd5b28e2fa75c Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Fri, 8 May 2026 21:29:43 -0500 Subject: [PATCH] Fix streaming metadata and DeepSeek reasoning --- autorouter.go | 151 +++++++++++++++++- autorouter_test.go | 384 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 527 insertions(+), 8 deletions(-) diff --git a/autorouter.go b/autorouter.go index b522c8b..50d63b1 100644 --- a/autorouter.go +++ b/autorouter.go @@ -154,11 +154,12 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { raw["model"] = strippedModel model = strippedModel - var err error - body, err = json.Marshal(raw) - if err != nil { - return nil, ResponseMetadata{}, fmt.Errorf("failed to marshal request body: %w", err) - } + } + normalizeProviderRequest(raw, providerName) + var err error + body, err = json.Marshal(raw) + if err != nil { + return nil, ResponseMetadata{}, fmt.Errorf("failed to marshal request body: %w", err) } } @@ -290,6 +291,7 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w raw["model"] = strippedModel model = strippedModel } + normalizeProviderRequest(raw, providerName) if a.billingCalculator != nil { if stream, ok := raw["stream"].(bool); ok && stream { if !nativeStreamUsageProviders[providerName] && apiType != APITypeResponses { @@ -378,7 +380,14 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w w.WriteHeader(upstreamResp.StatusCode) - rc := http.NewResponseController(w) + var sseWriter *sseTerminalHoldingWriter + streamWriter := w + if a.billingCalculator != nil && IsSSEStream(upstreamResp.Header.Get("Content-Type")) { + sseWriter = newSSETerminalHoldingWriter(w) + streamWriter = sseWriter + } + + rc := http.NewResponseController(streamWriter) extractor := provider.ResponseExtractor() streamExtractor, isStreaming := extractor.(StreamingResponseExtractor) @@ -386,12 +395,12 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w var respMeta ResponseMetadata if isStreaming && streamExtractor.IsStreamingResponse(upstreamResp) { - respMeta, err = streamExtractor.ExtractStreamingWithController(upstreamResp, w, rc) + respMeta, err = streamExtractor.ExtractStreamingWithController(upstreamResp, streamWriter, rc) if err != nil { return respMeta, err } } else { - respMeta, err = a.streamResponseWithFlush(upstreamResp.Body, w, rc, extractor) + respMeta, err = a.streamResponseWithFlush(upstreamResp.Body, streamWriter, rc, extractor) if err != nil { return respMeta, err } @@ -404,12 +413,105 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost)) w.Header().Set("X-Gateway-Prompt-Tokens", fmt.Sprintf("%d", billing.PromptTokens)) w.Header().Set("X-Gateway-Completion-Tokens", fmt.Sprintf("%d", billing.CompletionTokens)) + if sseWriter != nil && sseWriter.HasTerminal() { + if err := writeGatewayMetadataEvent(w, rc, billing); err != nil { + return respMeta, err + } + } + } + } + if sseWriter != nil { + if err := sseWriter.FlushTerminal(); err != nil { + return respMeta, err + } + if err := rc.Flush(); err != nil { + return respMeta, err } } return respMeta, nil } +type sseTerminalHoldingWriter struct { + http.ResponseWriter + terminal []byte +} + +func newSSETerminalHoldingWriter(w http.ResponseWriter) *sseTerminalHoldingWriter { + return &sseTerminalHoldingWriter{ResponseWriter: w} +} + +func (w *sseTerminalHoldingWriter) Write(data []byte) (int, error) { + idx := bytes.Index(data, []byte("data: [DONE]")) + if idx < 0 { + n, err := w.ResponseWriter.Write(data) + if err != nil { + return n, err + } + return len(data), nil + } + + if idx > 0 { + if _, err := w.ResponseWriter.Write(data[:idx]); err != nil { + return 0, err + } + } + w.terminal = append(w.terminal, data[idx:]...) + return len(data), nil +} + +func (w *sseTerminalHoldingWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +func (w *sseTerminalHoldingWriter) HasTerminal() bool { + return len(w.terminal) > 0 +} + +func (w *sseTerminalHoldingWriter) FlushTerminal() error { + if len(w.terminal) == 0 { + return nil + } + _, err := w.ResponseWriter.Write(w.terminal) + w.terminal = nil + return err +} + +func writeGatewayMetadataEvent(w http.ResponseWriter, rc *http.ResponseController, billing BillingResult) error { + payload := map[string]any{ + "type": "gateway.metadata", + "data": map[string]any{ + "cost": map[string]any{ + "total": billing.TotalCost, + "promptTokens": billing.PromptTokens, + "completionTokens": billing.CompletionTokens, + }, + }, + } + + data, err := json.Marshal(payload) + if err != nil { + return err + } + + if _, err := w.Write([]byte("event: gateway.metadata\n")); err != nil { + return err + } + if _, err := w.Write([]byte("data: ")); err != nil { + return err + } + if _, err := w.Write(data); err != nil { + return err + } + if _, err := w.Write([]byte("\n\n")); err != nil { + return err + } + + return rc.Flush() +} + func (a *AutoRouter) streamResponseWithFlush(r io.Reader, w http.ResponseWriter, rc *http.ResponseController, extractor ResponseExtractor) (ResponseMetadata, error) { var buf bytes.Buffer tee := io.TeeReader(r, &buf) @@ -600,3 +702,36 @@ func stripProviderPrefix(model string) (stripped string, hasPrefix bool) { } return model, false } + +func normalizeProviderRequest(raw map[string]any, providerName string) { + if providerName != "deepseek" { + return + } + + reasoning, hasReasoning := raw["reasoning"] + if !hasReasoning { + return + } + + switch value := reasoning.(type) { + case string: + switch strings.ToLower(value) { + case "", "off", "false", "none", "disabled": + delete(raw, "reasoning") + delete(raw, "reasoning_effort") + raw["thinking"] = map[string]any{"type": "disabled"} + case "low", "medium", "high", "max", "xhigh": + delete(raw, "reasoning") + raw["thinking"] = map[string]any{"type": "enabled"} + raw["reasoning_effort"] = value + } + case bool: + delete(raw, "reasoning") + if value { + raw["thinking"] = map[string]any{"type": "enabled"} + } else { + delete(raw, "reasoning_effort") + raw["thinking"] = map[string]any{"type": "disabled"} + } + } +} diff --git a/autorouter_test.go b/autorouter_test.go index 7aa9a53..22c4251 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -345,6 +345,305 @@ func TestAutoRouter_DeepseekV4StripsProviderPrefixBeforeForwarding(t *testing.T) } } +func TestAutoRouter_DeepseekReasoningOffDisablesThinking(t *testing.T) { + var upstreamBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&upstreamBody); err != nil { + t.Fatalf("decode upstream request: %v", err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-deepseek","model":"deepseek-v4-pro","choices":[]}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "deepseek", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + var req struct { + Model string `json:"model"` + } + if err := json.Unmarshal(data, &req); err != nil { + return BodyMetadata{}, nil, err + } + return BodyMetadata{Model: req.Model}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return ParseURL(upstream.URL + "/v1/chat/completions") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "chatcmpl-deepseek", Model: "deepseek-v4-pro"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { + return "deepseek" + })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{ + "model": "deepseek/deepseek-v4-pro", + "reasoning": "off", + "messages": [{"role":"user","content":"Reply with OK and nothing else."}] + }`))) + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if _, ok := upstreamBody["reasoning"]; ok { + t.Fatalf("upstream reasoning should be removed: %#v", upstreamBody) + } + thinking, ok := upstreamBody["thinking"].(map[string]any) + if !ok { + t.Fatalf("upstream thinking missing: %#v", upstreamBody) + } + if thinking["type"] != "disabled" { + t.Fatalf("upstream thinking.type = %q, want disabled", thinking["type"]) + } +} + +func TestAutoRouter_DeepseekReasoningHighEnablesThinking(t *testing.T) { + var upstreamBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&upstreamBody); err != nil { + t.Fatalf("decode upstream request: %v", err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-deepseek","model":"deepseek-v4-pro","choices":[]}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "deepseek", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + var req struct { + Model string `json:"model"` + } + if err := json.Unmarshal(data, &req); err != nil { + return BodyMetadata{}, nil, err + } + return BodyMetadata{Model: req.Model}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return ParseURL(upstream.URL + "/v1/chat/completions") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "chatcmpl-deepseek", Model: "deepseek-v4-pro"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { + return "deepseek" + })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{ + "model": "deepseek/deepseek-v4-pro", + "reasoning": "high", + "messages": [{"role":"user","content":"Reply with OK and nothing else."}] + }`))) + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if _, ok := upstreamBody["reasoning"]; ok { + t.Fatalf("upstream reasoning should be removed: %#v", upstreamBody) + } + thinking, ok := upstreamBody["thinking"].(map[string]any) + if !ok { + t.Fatalf("upstream thinking missing: %#v", upstreamBody) + } + if thinking["type"] != "enabled" { + t.Fatalf("upstream thinking.type = %q, want enabled", thinking["type"]) + } + if upstreamBody["reasoning_effort"] != "high" { + t.Fatalf("upstream reasoning_effort = %q, want high", upstreamBody["reasoning_effort"]) + } +} + +func TestAutoRouter_DeepseekReasoningOffDisablesThinkingForStreaming(t *testing.T) { + var upstreamBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&upstreamBody); err != nil { + t.Fatalf("decode upstream request: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-deepseek\"}\n\ndata: [DONE]\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "deepseek", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + var req struct { + Model string `json:"model"` + } + if err := json.Unmarshal(data, &req); err != nil { + return BodyMetadata{}, nil, err + } + return BodyMetadata{Model: req.Model, Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return ParseURL(upstream.URL + "/v1/chat/completions") + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + _, _ = io.Copy(w, resp.Body) + _ = rc.Flush() + return ResponseMetadata{ID: "chatcmpl-deepseek", Model: "deepseek-v4-pro"}, nil + }, + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { + return "deepseek" + })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{ + "model": "deepseek/deepseek-v4-pro", + "reasoning": "off", + "reasoning_effort": "high", + "stream": true, + "messages": [{"role":"user","content":"Reply with OK and nothing else."}] + }`))) + w := httptest.NewRecorder() + _, err := router.ForwardStreaming(context.Background(), req, w) + if err != nil { + t.Fatalf("ForwardStreaming() error = %v", err) + } + + if _, ok := upstreamBody["reasoning"]; ok { + t.Fatalf("upstream reasoning should be removed: %#v", upstreamBody) + } + if _, ok := upstreamBody["reasoning_effort"]; ok { + t.Fatalf("upstream reasoning_effort should be removed: %#v", upstreamBody) + } + thinking, ok := upstreamBody["thinking"].(map[string]any) + if !ok { + t.Fatalf("upstream thinking missing: %#v", upstreamBody) + } + if thinking["type"] != "disabled" { + t.Fatalf("upstream thinking.type = %q, want disabled", thinking["type"]) + } +} + +func TestAutoRouter_DeepseekReasoningHighEnablesThinkingForStreaming(t *testing.T) { + var upstreamBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&upstreamBody); err != nil { + t.Fatalf("decode upstream request: %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-deepseek\"}\n\ndata: [DONE]\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "deepseek", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + var req struct { + Model string `json:"model"` + } + if err := json.Unmarshal(data, &req); err != nil { + return BodyMetadata{}, nil, err + } + return BodyMetadata{Model: req.Model, Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return ParseURL(upstream.URL + "/v1/chat/completions") + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + _, _ = io.Copy(w, resp.Body) + _ = rc.Flush() + return ResponseMetadata{ID: "chatcmpl-deepseek", Model: "deepseek-v4-pro"}, nil + }, + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { + return "deepseek" + })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/v1/chat/completions", bytes.NewReader([]byte(`{ + "model": "deepseek/deepseek-v4-pro", + "reasoning": "high", + "stream": true, + "messages": [{"role":"user","content":"Reply with OK and nothing else."}] + }`))) + w := httptest.NewRecorder() + _, err := router.ForwardStreaming(context.Background(), req, w) + if err != nil { + t.Fatalf("ForwardStreaming() error = %v", err) + } + + if _, ok := upstreamBody["reasoning"]; ok { + t.Fatalf("upstream reasoning should be removed: %#v", upstreamBody) + } + thinking, ok := upstreamBody["thinking"].(map[string]any) + if !ok { + t.Fatalf("upstream thinking missing: %#v", upstreamBody) + } + if thinking["type"] != "enabled" { + t.Fatalf("upstream thinking.type = %q, want enabled", thinking["type"]) + } + if upstreamBody["reasoning_effort"] != "high" { + t.Fatalf("upstream reasoning_effort = %q, want high", upstreamBody["reasoning_effort"]) + } +} + +func TestAutoRouter_DeepseekUnknownReasoningIsLeftUntouched(t *testing.T) { + raw := map[string]any{ + "reasoning": "experimental", + "reasoning_effort": "medium", + } + + normalizeProviderRequest(raw, "deepseek") + + if raw["reasoning"] != "experimental" { + t.Fatalf("reasoning = %q, want experimental", raw["reasoning"]) + } + if raw["reasoning_effort"] != "medium" { + t.Fatalf("reasoning_effort = %q, want medium", raw["reasoning_effort"]) + } + if _, ok := raw["thinking"]; ok { + t.Fatalf("thinking should not be set for unknown reasoning: %#v", raw) + } +} + func TestAutoRouter_CohereCommandRUpstreamEmptyErrorDoesNotBecomeExtractorError(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var body struct { @@ -1110,6 +1409,91 @@ func TestAutoRouter_AnthropicStreamingNoStreamOptions(t *testing.T) { } } +func TestAutoRouter_StreamingWritesGatewayMetadataEvent(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"id\":\"test\"}\n\ndata: [DONE]\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + _, _ = io.Copy(w, resp.Body) + _ = rc.Flush() + return ResponseMetadata{ + ID: "test", + Usage: Usage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150}, + }, nil + }, + }, + } + + billing := NewBillingCalculator( + func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 1, Output: 2}, true + }, + nil, + ) + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })), + WithAutoRouterBillingCalculator(billing), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", w.Code) + } + + body := w.Body.String() + metadataIndex := strings.Index(body, "event: gateway.metadata") + doneIndex := strings.Index(body, "data: [DONE]") + if metadataIndex < 0 || doneIndex < 0 { + t.Fatalf("stream body missing metadata or terminal event: %q", body) + } + if metadataIndex > doneIndex { + t.Fatalf("metadata event should be written before terminal event: %q", body) + } + nextEventIndex := strings.Index(body[metadataIndex+1:], "\nevent: ") + metadataEnd := doneIndex + if nextEventIndex >= 0 { + metadataEnd = metadataIndex + 1 + nextEventIndex + } + metadataChunk := body[metadataIndex:metadataEnd] + if !strings.Contains(metadataChunk, `"type":"gateway.metadata"`) { + t.Fatalf("metadata event missing type: %q", metadataChunk) + } + if !strings.Contains(metadataChunk, `"total":0.0002`) { + t.Fatalf("metadata event missing total cost: %q", metadataChunk) + } + if !strings.Contains(metadataChunk, `"promptTokens":100`) { + t.Fatalf("metadata event missing prompt tokens: %q", metadataChunk) + } + if !strings.Contains(metadataChunk, `"completionTokens":50`) { + t.Fatalf("metadata event missing completion tokens: %q", metadataChunk) + } +} + func TestAutoRouter_ResponsesAPIStreamingNoStreamOptions(t *testing.T) { var receivedBody map[string]any upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {