diff --git a/apitype.go b/apitype.go index 0de2b48..1a16cc2 100644 --- a/apitype.go +++ b/apitype.go @@ -7,12 +7,18 @@ import ( type APIType string const ( - APITypeChatCompletions APIType = "chat_completions" - APITypeResponses APIType = "responses" - APITypeCompletions APIType = "completions" - APITypeMessages APIType = "messages" - APITypeGenerateContent APIType = "generate_content" - APITypeConverse APIType = "converse" + APITypeChatCompletions APIType = "chat_completions" + APITypeResponses APIType = "responses" + APITypeCompletions APIType = "completions" + APITypeMessages APIType = "messages" + APITypeGenerateContent APIType = "generate_content" + APITypeStreamGenerateContent APIType = "stream_generate_content" + APITypePredictLongRunning APIType = "predict_long_running" + APITypeEmbeddings APIType = "embeddings" + APITypeImagesGenerations APIType = "images_generations" + APITypeAudioSpeech APIType = "audio_speech" + APITypeAudioTranscriptions APIType = "audio_transcriptions" + APITypeConverse APIType = "converse" ) func DetectAPIType(body []byte) APIType { @@ -44,8 +50,20 @@ func DetectAPITypeFromPath(path string) APIType { return APITypeResponses case containsPath(path, "/v1/completions"): return APITypeCompletions + case containsPath(path, "/v1/embeddings"): + return APITypeEmbeddings + case containsPath(path, "/v1/images/generations"): + return APITypeImagesGenerations + case containsPath(path, "/v1/audio/speech"): + return APITypeAudioSpeech + case containsPath(path, "/v1/audio/transcriptions"): + return APITypeAudioTranscriptions case containsPath(path, "/v1/messages"): return APITypeMessages + case containsPath(path, ":streamGenerateContent"): + return APITypeStreamGenerateContent + case containsPath(path, ":predictLongRunning"): + return APITypePredictLongRunning case containsPath(path, ":generateContent"): return APITypeGenerateContent case containsPath(path, "/converse"): diff --git a/autorouter.go b/autorouter.go index 50d63b1..96cb45a 100644 --- a/autorouter.go +++ b/autorouter.go @@ -7,7 +7,10 @@ import ( "errors" "fmt" "io" + "mime" + "mime/multipart" "net/http" + "net/url" "strings" "github.com/agentuity/go-common/slice" @@ -37,6 +40,7 @@ type AutoRouter struct { registry Registry detector ProviderDetector modelProviderLookup ModelProviderLookup + modelMetadataLookup ModelMetadataLookup interceptors InterceptorChain client *http.Client fallbackProvider Provider @@ -72,6 +76,10 @@ func WithAutoRouterModelProviderLookup(lookup ModelProviderLookup) AutoRouterOpt return func(a *AutoRouter) { a.modelProviderLookup = lookup } } +func WithAutoRouterModelMetadataLookup(lookup ModelMetadataLookup) AutoRouterOption { + return func(a *AutoRouter) { a.modelMetadataLookup = lookup } +} + func WithAutoRouterBillingCalculator(calculator *BillingCalculator) AutoRouterOption { return func(a *AutoRouter) { a.billingCalculator = calculator } } @@ -119,12 +127,9 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp } req.Body.Close() - var raw map[string]any - var model string - if err := json.Unmarshal(body, &raw); err == nil { - if m, ok := raw["model"].(string); ok { - model = m - } + model, raw := extractRequestModel(req.Header, body) + if model == "" { + model = extractModelFromPath(req.URL.Path) } hint := ProviderHint{ @@ -136,6 +141,9 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp if providerName == "" && a.modelProviderLookup != nil && model != "" { providerName = a.modelProviderLookup(model) } + if providerName == "" { + providerName = providerFromAPIType(DetectAPITypeFromPath(req.URL.Path)) + } var provider Provider if providerName != "" { @@ -161,16 +169,31 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp if err != nil { return nil, ResponseMetadata{}, fmt.Errorf("failed to marshal request body: %w", err) } + } else if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { + if rewrittenBody, contentType, ok := rewriteRequestModel(req.Header.Get("Content-Type"), body, strippedModel); ok { + body = rewrittenBody + req.Header.Set("Content-Type", contentType) + } + model = strippedModel } apiType := DetectAPITypeFromPath(req.URL.Path) if apiType == "" { apiType = DetectAPITypeFromBodyAndProvider(body, providerName) } + if err := a.validateModelSurface(providerName, model, apiType); err != nil { + return nil, ResponseMetadata{}, err + } meta, _, err := provider.BodyParser().Parse(io.NopCloser(bytes.NewReader(body))) if err != nil { - return nil, ResponseMetadata{}, err + if !canBuildPassthroughMetadata(apiType, model) { + return nil, ResponseMetadata{}, err + } + meta = BodyMetadata{Model: model, Custom: make(map[string]any)} + } + if meta.Model == "" && model != "" { + meta.Model = model } if meta.Custom == nil { @@ -250,12 +273,9 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w } req.Body.Close() - var raw map[string]any - var model string - if err := json.Unmarshal(body, &raw); err == nil { - if m, ok := raw["model"].(string); ok { - model = m - } + model, raw := extractRequestModel(req.Header, body) + if model == "" { + model = extractModelFromPath(req.URL.Path) } hint := ProviderHint{ @@ -267,6 +287,9 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w if providerName == "" && a.modelProviderLookup != nil && model != "" { providerName = a.modelProviderLookup(model) } + if providerName == "" { + providerName = providerFromAPIType(DetectAPITypeFromPath(req.URL.Path)) + } var provider Provider if providerName != "" { @@ -285,6 +308,9 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w if apiType == "" { apiType = DetectAPITypeFromBodyAndProvider(body, providerName) } + if err := a.validateModelSurface(providerName, model, apiType); err != nil { + return ResponseMetadata{}, err + } if raw != nil { if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { @@ -310,11 +336,23 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w if err != nil { return ResponseMetadata{}, fmt.Errorf("failed to marshal request body: %w", err) } + } else if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { + if rewrittenBody, contentType, ok := rewriteRequestModel(req.Header.Get("Content-Type"), body, strippedModel); ok { + body = rewrittenBody + req.Header.Set("Content-Type", contentType) + } + model = strippedModel } meta, _, err := provider.BodyParser().Parse(io.NopCloser(bytes.NewReader(body))) if err != nil { - return ResponseMetadata{}, err + if !canBuildPassthroughMetadata(apiType, model) { + return ResponseMetadata{}, err + } + meta = BodyMetadata{Model: model, Custom: make(map[string]any)} + } + if meta.Model == "" && model != "" { + meta.Model = model } if meta.Custom == nil { @@ -373,7 +411,7 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w // Declare HTTP trailers for billing headers (must be before WriteHeader) if a.billingCalculator != nil { - w.Header().Set("Trailer", "X-Gateway-Cost,X-Gateway-Prompt-Tokens,X-Gateway-Completion-Tokens") + w.Header().Set("Trailer", gatewayBillingTrailerHeader()) } copyResponseHeaders(w, upstreamResp.Header) @@ -413,7 +451,8 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost)) w.Header().Set("X-Gateway-Prompt-Tokens", fmt.Sprintf("%d", billing.PromptTokens)) w.Header().Set("X-Gateway-Completion-Tokens", fmt.Sprintf("%d", billing.CompletionTokens)) - if sseWriter != nil && sseWriter.HasTerminal() { + setGatewayMeteredBillingHeaders(w.Header(), billing) + if sseWriter != nil { if err := writeGatewayMetadataEvent(w, rc, billing); err != nil { return respMeta, err } @@ -487,6 +526,9 @@ func writeGatewayMetadataEvent(w http.ResponseWriter, rc *http.ResponseControlle "total": billing.TotalCost, "promptTokens": billing.PromptTokens, "completionTokens": billing.CompletionTokens, + "unit": billing.Unit, + "inputQuantity": billing.InputQuantity, + "outputQuantity": billing.OutputQuantity, }, }, } @@ -512,6 +554,29 @@ func writeGatewayMetadataEvent(w http.ResponseWriter, rc *http.ResponseControlle return rc.Flush() } +func gatewayBillingTrailerHeader() string { + return strings.Join([]string{ + "X-Gateway-Cost", + "X-Gateway-Prompt-Tokens", + "X-Gateway-Completion-Tokens", + "X-Gateway-Billing-Unit", + "X-Gateway-Input-Quantity", + "X-Gateway-Output-Quantity", + }, ",") +} + +func setGatewayMeteredBillingHeaders(header http.Header, billing BillingResult) { + if billing.Unit != "" { + header.Set("X-Gateway-Billing-Unit", billing.Unit) + } + if billing.InputQuantity != 0 { + header.Set("X-Gateway-Input-Quantity", fmt.Sprintf("%.6f", billing.InputQuantity)) + } + if billing.OutputQuantity != 0 { + header.Set("X-Gateway-Output-Quantity", fmt.Sprintf("%.6f", billing.OutputQuantity)) + } +} + func (a *AutoRouter) streamResponseWithFlush(r io.Reader, w http.ResponseWriter, rc *http.ResponseController, extractor ResponseExtractor) (ResponseMetadata, error) { var buf bytes.Buffer tee := io.TeeReader(r, &buf) @@ -574,6 +639,10 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { var raw map[string]any var isStreamingRequest bool + if DetectAPITypeFromPath(r.URL.Path) == APITypeStreamGenerateContent || + strings.Contains(r.Header.Get("Accept"), "text/event-stream") { + isStreamingRequest = true + } if err := json.Unmarshal(body, &raw); err == nil { if stream, ok := raw["stream"].(bool); ok && stream { isStreamingRequest = true @@ -607,6 +676,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost)) w.Header().Set("X-Gateway-Prompt-Tokens", fmt.Sprintf("%d", billing.PromptTokens)) w.Header().Set("X-Gateway-Completion-Tokens", fmt.Sprintf("%d", billing.CompletionTokens)) + setGatewayMeteredBillingHeaders(w.Header(), billing) } w.WriteHeader(resp.StatusCode) @@ -687,6 +757,188 @@ var knownProviderPrefixes = map[string]bool{ "deepseek": true, } +func (a *AutoRouter) validateModelSurface(providerName string, model string, apiType APIType) error { + if a == nil || a.modelMetadataLookup == nil || model == "" { + return nil + } + lookupModel := model + if stripped, ok := stripProviderPrefix(lookupModel); ok { + lookupModel = stripped + } + metadata, ok := a.modelMetadataLookup(providerName, lookupModel) + if !ok { + return nil + } + return validateAPITypeAgainstModel(apiType, providerName, lookupModel, metadata) +} + +func providerFromAPIType(apiType APIType) string { + switch apiType { + case APITypeGenerateContent, APITypeStreamGenerateContent, APITypePredictLongRunning: + return "googleai" + default: + return "" + } +} + +func extractRequestModel(headers http.Header, body []byte) (string, map[string]any) { + var raw map[string]any + if err := json.Unmarshal(body, &raw); err == nil { + if model, ok := raw["model"].(string); ok { + return model, raw + } + return "", raw + } + + contentType := headers.Get("Content-Type") + mediaType, params, err := mime.ParseMediaType(contentType) + if err != nil { + return "", nil + } + + switch mediaType { + case "multipart/form-data": + boundary := params["boundary"] + if boundary == "" { + return "", nil + } + return extractMultipartModel(body, boundary), nil + case "application/x-www-form-urlencoded": + values, err := url.ParseQuery(string(body)) + if err != nil { + return "", nil + } + return values.Get("model"), nil + default: + return "", nil + } +} + +func extractMultipartModel(body []byte, boundary string) string { + reader := multipart.NewReader(bytes.NewReader(body), boundary) + for { + part, err := reader.NextPart() + if err != nil { + return "" + } + if part.FormName() != "model" { + _ = part.Close() + continue + } + data, err := io.ReadAll(part) + _ = part.Close() + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) + } +} + +func rewriteRequestModel(contentType string, body []byte, model string) ([]byte, string, bool) { + mediaType, params, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, "", false + } + + switch mediaType { + case "multipart/form-data": + boundary := params["boundary"] + if boundary == "" { + return nil, "", false + } + rewrittenBody, rewrittenContentType, err := rewriteMultipartModel(body, boundary, model) + if err != nil { + return nil, "", false + } + return rewrittenBody, rewrittenContentType, true + case "application/x-www-form-urlencoded": + values, err := url.ParseQuery(string(body)) + if err != nil { + return nil, "", false + } + values.Set("model", model) + return []byte(values.Encode()), contentType, true + default: + return nil, "", false + } +} + +func rewriteMultipartModel(body []byte, boundary string, model string) ([]byte, string, error) { + reader := multipart.NewReader(bytes.NewReader(body), boundary) + var output bytes.Buffer + writer := multipart.NewWriter(&output) + wroteModel := false + + for { + part, err := reader.NextPart() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, "", err + } + + if part.FormName() == "model" { + if err := writer.WriteField("model", model); err != nil { + _ = part.Close() + return nil, "", err + } + wroteModel = true + _ = part.Close() + continue + } + + dst, err := writer.CreatePart(part.Header) + if err != nil { + _ = part.Close() + return nil, "", err + } + if _, err := io.Copy(dst, part); err != nil { + _ = part.Close() + return nil, "", err + } + _ = part.Close() + } + + if !wroteModel { + if err := writer.WriteField("model", model); err != nil { + return nil, "", err + } + } + if err := writer.Close(); err != nil { + return nil, "", err + } + return output.Bytes(), writer.FormDataContentType(), nil +} + +func canBuildPassthroughMetadata(apiType APIType, model string) bool { + if model == "" { + return false + } + switch apiType { + case APITypeAudioTranscriptions: + return true + default: + return false + } +} + +func extractModelFromPath(path string) string { + const marker = "/models/" + idx := strings.Index(path, marker) + if idx < 0 { + return "" + } + value := path[idx+len(marker):] + if slash := strings.Index(value, "/"); slash >= 0 { + value = value[:slash] + } + if colon := strings.Index(value, ":"); colon >= 0 { + value = value[:colon] + } + return strings.TrimSpace(value) +} + func stripProviderPrefix(model string) (stripped string, hasPrefix bool) { idx := strings.Index(model, "/") if idx < 0 { @@ -704,6 +956,11 @@ func stripProviderPrefix(model string) (stripped string, hasPrefix bool) { } func normalizeProviderRequest(raw map[string]any, providerName string) { + if providerName == "googleai" { + normalizeGoogleAIRequest(raw) + return + } + if providerName != "deepseek" { return } @@ -735,3 +992,56 @@ func normalizeProviderRequest(raw map[string]any, providerName string) { } } } + +func normalizeGoogleAIRequest(raw map[string]any) { + delete(raw, "stream") + + if maxTokens, ok := raw["max_tokens"]; ok { + generationConfig, _ := raw["generationConfig"].(map[string]any) + if generationConfig == nil { + generationConfig = make(map[string]any) + raw["generationConfig"] = generationConfig + } + if _, exists := generationConfig["maxOutputTokens"]; !exists { + generationConfig["maxOutputTokens"] = maxTokens + } + delete(raw, "max_tokens") + } + + if systemInstruction, ok := raw["system_instruction"].(string); ok { + if _, exists := raw["systemInstruction"]; !exists && systemInstruction != "" { + raw["systemInstruction"] = map[string]any{ + "parts": []any{map[string]any{"text": systemInstruction}}, + } + } + delete(raw, "system_instruction") + } + + if messages, ok := raw["messages"].([]any); ok { + if _, exists := raw["contents"]; !exists { + contents := make([]any, 0, len(messages)) + for _, item := range messages { + message, ok := item.(map[string]any) + if !ok { + continue + } + content, ok := message["content"].(string) + if !ok || content == "" { + continue + } + role, _ := message["role"].(string) + if role == "assistant" { + role = "model" + } else { + role = "user" + } + contents = append(contents, map[string]any{ + "role": role, + "parts": []any{map[string]any{"text": content}}, + }) + } + raw["contents"] = contents + } + delete(raw, "messages") + } +} diff --git a/autorouter_test.go b/autorouter_test.go index 22c4251..a4f6de4 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -7,6 +7,8 @@ import ( "encoding/json" "errors" "io" + "mime" + "mime/multipart" "net/http" "net/http/httptest" "net/url" @@ -1494,6 +1496,84 @@ func TestAutoRouter_StreamingWritesGatewayMetadataEvent(t *testing.T) { } } +func TestAutoRouter_StreamingWritesGatewayMetadataEventWithoutTerminalMarker(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("event: response.completed\n")) + _, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":100,\"output_tokens\":50}}}\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + _, _ = io.Copy(w, resp.Body) + _ = rc.Flush() + return ResponseMetadata{ + ID: "test", + Usage: Usage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150}, + }, nil + }, + }, + } + + billing := NewBillingCalculator( + func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 1, Output: 2}, true + }, + nil, + ) + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })), + WithAutoRouterBillingCalculator(billing), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","stream":true,"input":"Hello"}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", w.Code) + } + + body := w.Body.String() + metadataIndex := strings.Index(body, "event: gateway.metadata") + if metadataIndex < 0 { + t.Fatalf("stream body missing metadata event: %q", body) + } + + metadataChunk := body[metadataIndex:] + if !strings.Contains(metadataChunk, `"type":"gateway.metadata"`) { + t.Fatalf("metadata event missing type: %q", metadataChunk) + } + if !strings.Contains(metadataChunk, `"total":0.0002`) { + t.Fatalf("metadata event missing total cost: %q", metadataChunk) + } + if !strings.Contains(metadataChunk, `"promptTokens":100`) { + t.Fatalf("metadata event missing prompt tokens: %q", metadataChunk) + } + if !strings.Contains(metadataChunk, `"completionTokens":50`) { + t.Fatalf("metadata event missing completion tokens: %q", metadataChunk) + } +} + func TestAutoRouter_ResponsesAPIStreamingNoStreamOptions(t *testing.T) { var receivedBody map[string]any upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1579,6 +1659,63 @@ func TestAutoRouter_ResponsesAPIStreamingNoStreamOptions(t *testing.T) { }) } +func TestNormalizeProviderRequest_GoogleAI(t *testing.T) { + raw := map[string]any{ + "model": "googleai/gemini-2.5-flash", + "max_tokens": float64(256), + "stream": true, + "system_instruction": "You are concise.", + "messages": []any{ + map[string]any{"role": "user", "content": "Say hello"}, + map[string]any{"role": "assistant", "content": "Hello"}, + }, + } + + normalizeProviderRequest(raw, "googleai") + + if _, ok := raw["stream"]; ok { + t.Fatal("stream should be removed from Google AI upstream body") + } + if _, ok := raw["max_tokens"]; ok { + t.Fatal("max_tokens should be converted for Google AI") + } + if _, ok := raw["system_instruction"]; ok { + t.Fatal("system_instruction should be converted for Google AI") + } + if _, ok := raw["messages"]; ok { + t.Fatal("messages should be converted for Google AI") + } + + generationConfig, ok := raw["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("generationConfig missing or invalid: %#v", raw["generationConfig"]) + } + if generationConfig["maxOutputTokens"] != float64(256) { + t.Fatalf("maxOutputTokens = %#v, want 256", generationConfig["maxOutputTokens"]) + } + + systemInstruction, ok := raw["systemInstruction"].(map[string]any) + if !ok { + t.Fatalf("systemInstruction missing or invalid: %#v", raw["systemInstruction"]) + } + systemParts, ok := systemInstruction["parts"].([]any) + if !ok || len(systemParts) != 1 { + t.Fatalf("systemInstruction parts invalid: %#v", systemInstruction["parts"]) + } + + contents, ok := raw["contents"].([]any) + if !ok || len(contents) != 2 { + t.Fatalf("contents missing or invalid: %#v", raw["contents"]) + } + second, ok := contents[1].(map[string]any) + if !ok { + t.Fatalf("second content invalid: %#v", contents[1]) + } + if second["role"] != "model" { + t.Fatalf("assistant role should map to model, got %#v", second["role"]) + } +} + func TestAutoRouter_copyResponseHeaders(t *testing.T) { w := httptest.NewRecorder() copyResponseHeaders(w, http.Header{}) @@ -1651,3 +1788,307 @@ func TestAutoRouter_copyResponseHeaders(t *testing.T) { w = httptest.NewRecorder() } + +func TestExtractModelFromPath(t *testing.T) { + tests := []struct { + path string + want string + }{ + {"/v1beta/models/gemini-3.1-flash-lite:generateContent", "gemini-3.1-flash-lite"}, + {"/v1beta/models/veo-3.1-generate-preview:predictLongRunning", "veo-3.1-generate-preview"}, + {"/v1/chat/completions", ""}, + } + for _, tt := range tests { + if got := extractModelFromPath(tt.path); got != tt.want { + t.Fatalf("extractModelFromPath(%q) = %q, want %q", tt.path, got, tt.want) + } + } +} + +func TestAutoRouter_ModelMetadataValidationRejectsUnsupportedSurface(t *testing.T) { + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(h ProviderHint) string { + return "googleai" + })), + WithAutoRouterFallbackProvider(&mockProvider{name: "googleai"}), + WithAutoRouterModelMetadataLookup(func(provider, model string) (ModelMetadata, bool) { + return ModelMetadata{ + APICompatibility: "google-generative-ai-long-running", + InputModalities: []string{"text", "image"}, + OutputModalities: []string{"video"}, + }, true + }), + ) + router.RegisterProvider(&mockProvider{name: "googleai"}) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"googleai/veo-3.1-generate-preview","messages":[{"role":"user","content":"hello"}]}`)) + _, _, err := router.Forward(context.Background(), req) + if err == nil { + t.Fatal("expected unsupported surface error") + } + if !strings.Contains(err.Error(), "does not support chat_completions") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestAutoRouter_ModelMetadataValidationAllowsGoogleLongRunningPathModel(t *testing.T) { + var parsedModel string + provider := &mockProvider{ + name: "googleai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, err := io.ReadAll(body) + return BodyMetadata{Custom: map[string]any{}}, data, err + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { + return nil + }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + parsedModel = meta.Model + return url.Parse("https://example.com/v1beta/models/" + meta.Model + ":predictLongRunning") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, err := io.ReadAll(resp.Body) + return ResponseMetadata{Custom: map[string]any{}}, body, err + }, + } + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"name":"operations/test"}`)) + })) + defer upstream.Close() + provider.resolveFn = func(meta BodyMetadata) (*url.URL, error) { + parsedModel = meta.Model + return url.Parse(upstream.URL + "/v1beta/models/" + meta.Model + ":predictLongRunning") + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(h ProviderHint) string { + return "googleai" + })), + WithAutoRouterHTTPClient(upstream.Client()), + WithAutoRouterFallbackProvider(provider), + WithAutoRouterModelMetadataLookup(func(provider, model string) (ModelMetadata, bool) { + return ModelMetadata{ + APICompatibility: "google-generative-ai-long-running", + InputModalities: []string{"text", "image"}, + OutputModalities: []string{"video"}, + }, true + }), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/veo-3.1-generate-preview:predictLongRunning", strings.NewReader(`{"instances":[{"prompt":"hello"}]}`)) + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + if parsedModel != "veo-3.1-generate-preview" { + t.Fatalf("parsed model = %q, want veo-3.1-generate-preview", parsedModel) + } +} + +func TestAutoRouter_GoogleNativePathSelectsGoogleProviderWithoutModelPrefix(t *testing.T) { + var parsedModel string + provider := &mockProvider{ + name: "googleai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, err := io.ReadAll(body) + return BodyMetadata{Custom: map[string]any{}}, data, err + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { + return nil + }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + parsedModel = meta.Model + return url.Parse("https://example.com/v1beta/models/" + meta.Model + ":generateContent") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, err := io.ReadAll(resp.Body) + return ResponseMetadata{Custom: map[string]any{}}, body, err + }, + } + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"candidates":[{"content":{"parts":[{"text":"pong"}]}}]}`)) + })) + defer upstream.Close() + provider.resolveFn = func(meta BodyMetadata) (*url.URL, error) { + parsedModel = meta.Model + return url.Parse(upstream.URL + "/v1beta/models/" + meta.Model + ":generateContent") + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(h ProviderHint) string { + return "" + })), + WithAutoRouterHTTPClient(upstream.Client()), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3.1-flash-lite:generateContent", strings.NewReader(`{ + "contents": [ + {"role": "user", "parts": [{"text": "Say pong in one word."}]} + ] + }`)) + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + if parsedModel != "gemini-3.1-flash-lite" { + t.Fatalf("parsed model = %q, want gemini-3.1-flash-lite", parsedModel) + } +} + +func TestAutoRouter_GoogleNativeStreamPathUsesStreamingForwarder(t *testing.T) { + var parsedModel string + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "googleai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, err := io.ReadAll(body) + return BodyMetadata{Custom: map[string]any{}}, data, err + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { + return nil + }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + parsedModel = meta.Model + return url.Parse("https://example.com/v1beta/models/" + meta.Model + ":streamGenerateContent?alt=sse") + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + _, _ = io.Copy(w, resp.Body) + _ = rc.Flush() + return ResponseMetadata{Custom: map[string]any{}}, nil + }, + }, + } + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Pong\"}]}}]}\n\n")) + })) + defer upstream.Close() + provider.resolveFn = func(meta BodyMetadata) (*url.URL, error) { + parsedModel = meta.Model + return url.Parse(upstream.URL + "/v1beta/models/" + meta.Model + ":streamGenerateContent?alt=sse") + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(h ProviderHint) string { + return "" + })), + WithAutoRouterHTTPClient(upstream.Client()), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3.1-flash-lite:streamGenerateContent", strings.NewReader(`{ + "contents": [ + {"role": "user", "parts": [{"text": "Say pong in one word."}]} + ] + }`)) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, body = %q", recorder.Code, recorder.Body.String()) + } + if parsedModel != "gemini-3.1-flash-lite" { + t.Fatalf("parsed model = %q, want gemini-3.1-flash-lite", parsedModel) + } + if !strings.Contains(recorder.Body.String(), `"Pong"`) { + t.Fatalf("stream body = %q, want Pong token", recorder.Body.String()) + } +} + +func TestAutoRouter_OpenAIAudioTranscriptionMultipartPassesThrough(t *testing.T) { + var upstreamBody []byte + var upstreamContentType string + var resolvedModel string + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamContentType = r.Header.Get("Content-Type") + upstreamBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"text":"hello"}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{}, data, errors.New("json parser should not block multipart pass-through") + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { + return nil + }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + resolvedModel = meta.Model + return url.Parse(upstream.URL + "/v1/audio/transcriptions") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, err := io.ReadAll(resp.Body) + return ResponseMetadata{Custom: map[string]any{}}, body, err + }, + } + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if err := writer.WriteField("model", "openai/whisper-1"); err != nil { + t.Fatal(err) + } + file, err := writer.CreateFormFile("file", "hello.wav") + if err != nil { + t.Fatal(err) + } + if _, err := file.Write([]byte("fake wav data")); err != nil { + t.Fatal(err) + } + if err := writer.Close(); err != nil { + t.Fatal(err) + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(h ProviderHint) string { + return "openai" + })), + WithAutoRouterHTTPClient(upstream.Client()), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + defer resp.Body.Close() + + if resolvedModel != "whisper-1" { + t.Fatalf("resolved model = %q, want whisper-1", resolvedModel) + } + + reader := multipart.NewReader(bytes.NewReader(upstreamBody), boundaryFromContentType(t, upstreamContentType)) + form, err := reader.ReadForm(1024 * 1024) + if err != nil { + t.Fatalf("ReadForm() error = %v", err) + } + if got := form.Value["model"][0]; got != "whisper-1" { + t.Fatalf("upstream model = %q, want whisper-1", got) + } + if len(form.File["file"]) != 1 { + t.Fatalf("upstream file parts = %d, want 1", len(form.File["file"])) + } +} + +func boundaryFromContentType(t *testing.T, contentType string) string { + t.Helper() + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + t.Fatalf("ParseMediaType() error = %v", err) + } + return params["boundary"] +} diff --git a/billing.go b/billing.go index 194617e..d443f26 100644 --- a/billing.go +++ b/billing.go @@ -10,6 +10,9 @@ type CostInfo struct { CacheRead float64 // CacheWrite is the cost per 1M cache write tokens (optional, Anthropic). CacheWrite float64 + // Unit describes the billing unit for the cost values. Token-based models use + // "per_million_tokens"; non-token modalities may use provider-specific units. + Unit string } // CostLookup is a function that returns the cost for a given provider and model. @@ -33,6 +36,12 @@ type BillingResult struct { CachedTokens int // TotalTokens is the sum of prompt and completion tokens. TotalTokens int + // Unit is the billing unit used for this calculation. + Unit string + // InputQuantity is the input quantity for non-token units. + InputQuantity float64 + // OutputQuantity is the output quantity for non-token units. + OutputQuantity float64 // InputCost is the calculated input cost in USD (non-cached prompt tokens). InputCost float64 // CachedInputCost is the cost for cached prompt tokens in USD. @@ -47,6 +56,32 @@ type BillingResult struct { // Cached tokens are billed at the CacheRead rate (if available), and non-cached prompt // tokens are billed at the full Input rate. func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, completionTokens int, cacheUsage *CacheUsage) BillingResult { + return CalculateCostWithMeteredUsage(provider, model, costInfo, promptTokens, completionTokens, cacheUsage, MeteredUsage{}) +} + +func CalculateCostWithMeteredUsage(provider, model string, costInfo CostInfo, promptTokens, completionTokens int, cacheUsage *CacheUsage, meteredUsage MeteredUsage) BillingResult { + unit := costInfo.Unit + if unit == "" { + unit = "per_million_tokens" + } + if costInfo.Unit != "" && costInfo.Unit != "per_million_tokens" { + inputQuantity, outputQuantity := billingQuantitiesForUnit(unit, meteredUsage) + inputCost, outputCost := meteredCostForUnit(costInfo, unit, inputQuantity, outputQuantity) + return BillingResult{ + Provider: provider, + Model: model, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + Unit: unit, + InputQuantity: inputQuantity, + OutputQuantity: outputQuantity, + InputCost: inputCost, + OutputCost: outputCost, + TotalCost: inputCost + outputCost, + } + } + cachedTokens := 0 if cacheUsage != nil { // Providers populate only one of these fields β€” OpenAI/Fireworks/Bedrock @@ -97,9 +132,38 @@ func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, comp CompletionTokens: completionTokens, CachedTokens: cachedTokens, TotalTokens: promptTokens + completionTokens, + Unit: unit, + InputQuantity: float64(promptTokens), + OutputQuantity: float64(completionTokens), InputCost: inputCost, CachedInputCost: cachedInputCost, OutputCost: outputCost, TotalCost: inputCost + cachedInputCost + outputCost, } } + +func billingQuantitiesForUnit(unit string, usage MeteredUsage) (float64, float64) { + switch unit { + case "per_million_characters": + return float64(usage.InputCharacters), float64(usage.OutputCharacters) + case "per_minute_audio": + return usage.InputAudioSeconds / 60, usage.OutputAudioSeconds / 60 + case "per_second_720p_video", "per_second_720p_1080p_video": + return 0, usage.OutputVideoSeconds + default: + return 0, 0 + } +} + +func meteredCostForUnit(costInfo CostInfo, unit string, inputQuantity float64, outputQuantity float64) (float64, float64) { + switch unit { + case "per_million_characters": + return costInfo.Input * inputQuantity / 1_000_000, costInfo.Output * outputQuantity / 1_000_000 + case "per_minute_audio": + return costInfo.Input * inputQuantity, costInfo.Output * outputQuantity + case "per_second_720p_video", "per_second_720p_1080p_video": + return costInfo.Input * inputQuantity, costInfo.Output * outputQuantity + default: + return 0, 0 + } +} diff --git a/billing_calculator.go b/billing_calculator.go index e8e7e22..5c62de4 100644 --- a/billing_calculator.go +++ b/billing_calculator.go @@ -42,7 +42,8 @@ func (c *BillingCalculator) Calculate(meta BodyMetadata, respMeta *ResponseMetad } } - result := CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage) + meteredUsage := mergeMeteredUsage(meta.MeteredUsage, respMeta.MeteredUsage) + result := CalculateCostWithMeteredUsage(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage, meteredUsage) if respMeta.Custom == nil { respMeta.Custom = make(map[string]any) @@ -56,6 +57,37 @@ func (c *BillingCalculator) Calculate(meta BodyMetadata, respMeta *ResponseMetad return &result } +func mergeMeteredUsage(requestUsage MeteredUsage, responseUsage MeteredUsage) MeteredUsage { + return MeteredUsage{ + InputCharacters: selectInt(responseUsage.InputCharacters, responseUsage.HasInputCharacters, requestUsage.InputCharacters), + OutputCharacters: selectInt(responseUsage.OutputCharacters, responseUsage.HasOutputCharacters, requestUsage.OutputCharacters), + InputAudioSeconds: selectFloat(responseUsage.InputAudioSeconds, responseUsage.HasInputAudioSeconds, requestUsage.InputAudioSeconds), + OutputAudioSeconds: selectFloat(responseUsage.OutputAudioSeconds, responseUsage.HasOutputAudioSeconds, requestUsage.OutputAudioSeconds), + OutputVideoSeconds: selectFloat(responseUsage.OutputVideoSeconds, responseUsage.HasOutputVideoSeconds, requestUsage.OutputVideoSeconds), + GeneratedImages: selectInt(responseUsage.GeneratedImages, responseUsage.HasGeneratedImages, requestUsage.GeneratedImages), + HasInputCharacters: responseUsage.HasInputCharacters || requestUsage.HasInputCharacters, + HasOutputCharacters: responseUsage.HasOutputCharacters || requestUsage.HasOutputCharacters, + HasInputAudioSeconds: responseUsage.HasInputAudioSeconds || requestUsage.HasInputAudioSeconds, + HasOutputAudioSeconds: responseUsage.HasOutputAudioSeconds || requestUsage.HasOutputAudioSeconds, + HasOutputVideoSeconds: responseUsage.HasOutputVideoSeconds || requestUsage.HasOutputVideoSeconds, + HasGeneratedImages: responseUsage.HasGeneratedImages || requestUsage.HasGeneratedImages, + } +} + +func selectInt(responseValue int, responsePresent bool, requestValue int) int { + if responsePresent { + return responseValue + } + return requestValue +} + +func selectFloat(responseValue float64, responsePresent bool, requestValue float64) float64 { + if responsePresent { + return responseValue + } + return requestValue +} + func (c *BillingCalculator) Lookup() CostLookup { return c.lookup } diff --git a/billing_test.go b/billing_test.go index 72ce02a..9e6a2f2 100644 --- a/billing_test.go +++ b/billing_test.go @@ -183,6 +183,74 @@ func TestCalculateCost_ZeroTokens(t *testing.T) { assertFloat(t, "TotalCost", result.TotalCost, 0) } +func TestCalculateCost_PerMillionCharacters(t *testing.T) { + result := CalculateCostWithMeteredUsage( + "openai", + "tts-1", + CostInfo{Input: 15, Unit: "per_million_characters"}, + 0, + 0, + nil, + MeteredUsage{InputCharacters: 10}, + ) + + assertFloat(t, "InputCost", result.InputCost, 0.00015) + assertFloat(t, "TotalCost", result.TotalCost, 0.00015) + if result.InputQuantity != 10 { + t.Fatalf("InputQuantity = %f, want 10", result.InputQuantity) + } +} + +func TestCalculateCost_PerMinuteAudio(t *testing.T) { + result := CalculateCostWithMeteredUsage( + "openai", + "whisper-1", + CostInfo{Input: 0.006, Unit: "per_minute_audio"}, + 0, + 0, + nil, + MeteredUsage{InputAudioSeconds: 30}, + ) + + assertFloat(t, "InputCost", result.InputCost, 0.003) + assertFloat(t, "TotalCost", result.TotalCost, 0.003) + if result.InputQuantity != 0.5 { + t.Fatalf("InputQuantity = %f, want 0.5", result.InputQuantity) + } +} + +func TestCalculateCost_PerSecondVideo(t *testing.T) { + result := CalculateCostWithMeteredUsage( + "googleai", + "veo-3.1-generate-preview", + CostInfo{Output: 0.4, Unit: "per_second_720p_1080p_video"}, + 0, + 0, + nil, + MeteredUsage{OutputVideoSeconds: 16}, + ) + + assertFloat(t, "OutputCost", result.OutputCost, 6.4) + assertFloat(t, "TotalCost", result.TotalCost, 6.4) + if result.OutputQuantity != 16 { + t.Fatalf("OutputQuantity = %f, want 16", result.OutputQuantity) + } +} + +func TestMergeMeteredUsage_ResponseZeroOverridesRequestEstimate(t *testing.T) { + merged := mergeMeteredUsage( + MeteredUsage{InputAudioSeconds: 30, HasInputAudioSeconds: true}, + MeteredUsage{InputAudioSeconds: 0, HasInputAudioSeconds: true}, + ) + + if merged.InputAudioSeconds != 0 { + t.Fatalf("InputAudioSeconds = %f, want response-reported zero", merged.InputAudioSeconds) + } + if !merged.HasInputAudioSeconds { + t.Fatal("HasInputAudioSeconds = false, want true") + } +} + func TestCalculateCost_MixedProviderCacheFields(t *testing.T) { // Both CachedTokens and CacheReadInputTokens set (shouldn't happen, but test summing) costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} diff --git a/interceptors/billing.go b/interceptors/billing.go index 8e092e8..f129475 100644 --- a/interceptors/billing.go +++ b/interceptors/billing.go @@ -50,7 +50,7 @@ func (i *BillingInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMeta cacheUsage = &usage } } - result := llmproxy.CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage) + result := llmproxy.CalculateCostWithMeteredUsage(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage, mergeMeteredUsage(meta.MeteredUsage, respMeta.MeteredUsage)) if respMeta.Custom == nil { respMeta.Custom = make(map[string]any) } @@ -63,6 +63,37 @@ func (i *BillingInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMeta return resp, respMeta, rawRespBody, nil } +func mergeMeteredUsage(requestUsage llmproxy.MeteredUsage, responseUsage llmproxy.MeteredUsage) llmproxy.MeteredUsage { + return llmproxy.MeteredUsage{ + InputCharacters: selectInt(responseUsage.InputCharacters, responseUsage.HasInputCharacters, requestUsage.InputCharacters), + OutputCharacters: selectInt(responseUsage.OutputCharacters, responseUsage.HasOutputCharacters, requestUsage.OutputCharacters), + InputAudioSeconds: selectFloat(responseUsage.InputAudioSeconds, responseUsage.HasInputAudioSeconds, requestUsage.InputAudioSeconds), + OutputAudioSeconds: selectFloat(responseUsage.OutputAudioSeconds, responseUsage.HasOutputAudioSeconds, requestUsage.OutputAudioSeconds), + OutputVideoSeconds: selectFloat(responseUsage.OutputVideoSeconds, responseUsage.HasOutputVideoSeconds, requestUsage.OutputVideoSeconds), + GeneratedImages: selectInt(responseUsage.GeneratedImages, responseUsage.HasGeneratedImages, requestUsage.GeneratedImages), + HasInputCharacters: responseUsage.HasInputCharacters || requestUsage.HasInputCharacters, + HasOutputCharacters: responseUsage.HasOutputCharacters || requestUsage.HasOutputCharacters, + HasInputAudioSeconds: responseUsage.HasInputAudioSeconds || requestUsage.HasInputAudioSeconds, + HasOutputAudioSeconds: responseUsage.HasOutputAudioSeconds || requestUsage.HasOutputAudioSeconds, + HasOutputVideoSeconds: responseUsage.HasOutputVideoSeconds || requestUsage.HasOutputVideoSeconds, + HasGeneratedImages: responseUsage.HasGeneratedImages || requestUsage.HasGeneratedImages, + } +} + +func selectInt(responseValue int, responsePresent bool, requestValue int) int { + if responsePresent { + return responseValue + } + return requestValue +} + +func selectFloat(responseValue float64, responsePresent bool, requestValue float64) float64 { + if responsePresent { + return responseValue + } + return requestValue +} + // NewBilling creates a new billing interceptor with the given lookup function. // // Example: diff --git a/interceptors/coverage_test.go b/interceptors/coverage_test.go index 4865c8b..a918213 100644 --- a/interceptors/coverage_test.go +++ b/interceptors/coverage_test.go @@ -1,6 +1,7 @@ package interceptors import ( + "io" "net/http" "net/http/httptest" "strconv" @@ -242,6 +243,44 @@ func TestRetryInterceptor_ExhaustedAttempts(t *testing.T) { } } +func TestRetryInterceptor_ExhaustedAttemptsKeepsFinalBodyOpen(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte("rate limited")) + })) + defer upstream.Close() + + retry := NewRetry(3, time.Millisecond) + + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + resp, _, _, err := retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + defer resp.Body.Close() + + if callCount != 3 { + t.Errorf("callCount = %d, want 3", callCount) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("final response body should remain readable: %v", err) + } + if string(body) != "rate limited" { + t.Fatalf("final response body = %q, want %q", string(body), "rate limited") + } +} + func TestRetryInterceptor_NoRetryOn200(t *testing.T) { callCount := 0 upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/interceptors/retry.go b/interceptors/retry.go index b536750..fbd576d 100644 --- a/interceptors/retry.go +++ b/interceptors/retry.go @@ -53,7 +53,7 @@ func (i *RetryInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMetada return lastResp, lastMeta, lastRawRespBody, lastErr } - if lastResp != nil && lastResp.Body != nil { + if attempt < i.MaxAttempts && lastResp != nil && lastResp.Body != nil { io.Copy(io.Discard, lastResp.Body) lastResp.Body.Close() } diff --git a/metadata.go b/metadata.go index 6ceb9a8..1eca7c3 100644 --- a/metadata.go +++ b/metadata.go @@ -106,6 +106,8 @@ type BodyMetadata struct { MaxTokens int `json:"max_tokens,omitempty"` // Stream indicates whether streaming is requested. Stream bool `json:"stream"` + // MeteredUsage contains non-token metered request consumption. + MeteredUsage MeteredUsage `json:"metered_usage,omitempty"` // Custom holds provider-specific fields that don't map to standard fields. Custom map[string]any `json:"-"` } @@ -120,6 +122,23 @@ type Usage struct { TotalTokens int `json:"total_tokens"` } +// MeteredUsage tracks non-token consumption for modality-specific APIs. +type MeteredUsage struct { + InputCharacters int `json:"input_characters,omitempty"` + OutputCharacters int `json:"output_characters,omitempty"` + InputAudioSeconds float64 `json:"input_audio_seconds,omitempty"` + OutputAudioSeconds float64 `json:"output_audio_seconds,omitempty"` + OutputVideoSeconds float64 `json:"output_video_seconds,omitempty"` + GeneratedImages int `json:"generated_images,omitempty"` + + HasInputCharacters bool `json:"-"` + HasOutputCharacters bool `json:"-"` + HasInputAudioSeconds bool `json:"-"` + HasOutputAudioSeconds bool `json:"-"` + HasOutputVideoSeconds bool `json:"-"` + HasGeneratedImages bool `json:"-"` +} + // CacheUsage tracks prompt caching token consumption. type CacheUsage struct { // CachedTokens is the number of tokens served from cache (OpenAI). @@ -169,6 +188,8 @@ type ResponseMetadata struct { Model string `json:"model,omitempty"` // Usage contains token consumption statistics. Usage Usage `json:"usage"` + // MeteredUsage contains non-token metered response consumption. + MeteredUsage MeteredUsage `json:"metered_usage,omitempty"` // Choices contains the completion choices. Choices []Choice `json:"choices,omitempty"` // Custom holds provider-specific response fields. diff --git a/model_metadata.go b/model_metadata.go new file mode 100644 index 0000000..f53cc6d --- /dev/null +++ b/model_metadata.go @@ -0,0 +1,68 @@ +package llmproxy + +import ( + "fmt" + "strings" +) + +// ModelMetadata describes catalog capabilities for a model. +type ModelMetadata struct { + APICompatibility string + InputModalities []string + OutputModalities []string +} + +// ModelMetadataLookup resolves catalog metadata for a provider/model pair. +type ModelMetadataLookup func(provider string, model string) (ModelMetadata, bool) + +func validateAPITypeAgainstModel(apiType APIType, provider string, model string, metadata ModelMetadata) error { + if apiType == "" { + return nil + } + + switch apiType { + case APITypeChatCompletions, APITypeResponses, APITypeCompletions, APITypeMessages, APITypeGenerateContent, APITypeStreamGenerateContent, APITypeConverse: + if !hasModality(metadata.OutputModalities, "text") { + return unsupportedModelSurfaceError(provider, model, apiType, "text output") + } + case APITypeEmbeddings: + if !hasModality(metadata.OutputModalities, "embedding") { + return unsupportedModelSurfaceError(provider, model, apiType, "embedding output") + } + case APITypeImagesGenerations: + if !hasModality(metadata.OutputModalities, "image") { + return unsupportedModelSurfaceError(provider, model, apiType, "image output") + } + case APITypeAudioSpeech: + if !hasModality(metadata.OutputModalities, "audio") { + return unsupportedModelSurfaceError(provider, model, apiType, "audio output") + } + case APITypeAudioTranscriptions: + if !hasModality(metadata.InputModalities, "audio") || !hasModality(metadata.OutputModalities, "text") { + return unsupportedModelSurfaceError(provider, model, apiType, "audio input and text output") + } + case APITypePredictLongRunning: + if !hasModality(metadata.OutputModalities, "video") { + return unsupportedModelSurfaceError(provider, model, apiType, "video output") + } + } + + return nil +} + +func hasModality(values []string, wanted string) bool { + wanted = strings.ToLower(strings.TrimSpace(wanted)) + for _, value := range values { + if strings.ToLower(strings.TrimSpace(value)) == wanted { + return true + } + } + return false +} + +func unsupportedModelSurfaceError(provider string, model string, apiType APIType, required string) error { + id := strings.Trim(strings.TrimSpace(provider)+"/"+strings.TrimSpace(model), "/") + return &ProviderError{ + Message: fmt.Sprintf("model %s does not support %s; %s requires %s", id, apiType, apiType, required), + } +} diff --git a/pricing/modelsdev/adapter.go b/pricing/modelsdev/adapter.go index 5600e53..9eb52f4 100644 --- a/pricing/modelsdev/adapter.go +++ b/pricing/modelsdev/adapter.go @@ -327,6 +327,7 @@ type Cost struct { CacheWrite float64 `json:"cache_write,omitempty"` InputAudio float64 `json:"input_audio,omitempty"` OutputAudio float64 `json:"output_audio,omitempty"` + Unit string `json:"unit,omitempty"` } // Limit represents token limits. @@ -348,6 +349,7 @@ func costToInfo(c Cost, markup float64) llmproxy.CostInfo { Output: c.Output, CacheRead: c.CacheRead, CacheWrite: c.CacheWrite, + Unit: c.Unit, } if markup > 0 && markup != 1.0 { info.Input *= markup diff --git a/providers/googleai/parser.go b/providers/googleai/parser.go index 691122a..d5c2932 100644 --- a/providers/googleai/parser.go +++ b/providers/googleai/parser.go @@ -41,11 +41,18 @@ func (p *Parser) Parse(body io.ReadCloser) (llmproxy.BodyMetadata, []byte, error return llmproxy.BodyMetadata{}, nil, err } + meteredUsage := llmproxy.MeteredUsage{} + if outputVideoSeconds := req.VideoOutputSeconds(); outputVideoSeconds != 0 { + meteredUsage.OutputVideoSeconds = outputVideoSeconds + meteredUsage.HasOutputVideoSeconds = true + } + meta := llmproxy.BodyMetadata{ - Model: req.Model, - Messages: make([]llmproxy.Message, 0, len(req.Contents)), - MaxTokens: req.GenerationConfig.MaxOutputTokens, - Custom: make(map[string]any), + Model: req.Model, + Messages: make([]llmproxy.Message, 0, len(req.Contents)), + MaxTokens: req.GenerationConfig.MaxOutputTokens, + MeteredUsage: meteredUsage, + Custom: make(map[string]any), } for _, content := range req.Contents { @@ -90,6 +97,8 @@ func extractTextFromParts(parts []Part) string { type Request struct { Model string `json:"model,omitempty"` Contents []Content `json:"contents,omitempty"` + Instances []VideoInstance `json:"instances,omitempty"` + Parameters VideoParameters `json:"parameters,omitempty"` SystemInstruction *Content `json:"systemInstruction,omitempty"` GenerationConfig GenerationConfig `json:"generationConfig,omitempty"` SafetySettings []SafetySetting `json:"safetySettings,omitempty"` @@ -130,6 +139,29 @@ type SafetySetting struct { Threshold string `json:"threshold"` } +// VideoInstance represents a Google long-running video generation prompt. +type VideoInstance struct { + Prompt string `json:"prompt,omitempty"` +} + +// VideoParameters contains Google long-running video generation controls. +type VideoParameters struct { + SampleCount int `json:"sampleCount,omitempty"` + DurationSeconds float64 `json:"durationSeconds,omitempty"` + Resolution string `json:"resolution,omitempty"` +} + +func (r Request) VideoOutputSeconds() float64 { + if r.Parameters.DurationSeconds <= 0 { + return 0 + } + sampleCount := r.Parameters.SampleCount + if sampleCount <= 0 { + sampleCount = 1 + } + return r.Parameters.DurationSeconds * float64(sampleCount) +} + // UnmarshalJSON captures unknown fields into Custom. func (r *Request) UnmarshalJSON(data []byte) error { type Alias Request @@ -149,7 +181,7 @@ func (r *Request) UnmarshalJSON(data []byte) error { r.Custom = make(map[string]interface{}) known := map[string]bool{ - "model": true, "contents": true, "systemInstruction": true, + "model": true, "contents": true, "instances": true, "parameters": true, "systemInstruction": true, "generationConfig": true, "safetySettings": true, "tools": true, "toolConfig": true, } diff --git a/providers/googleai/parser_test.go b/providers/googleai/parser_test.go index e4f7d0a..7d10a32 100644 --- a/providers/googleai/parser_test.go +++ b/providers/googleai/parser_test.go @@ -94,6 +94,25 @@ func TestParser(t *testing.T) { t.Fatal("expected error") } }) + + t.Run("parses video metered usage", func(t *testing.T) { + body := `{"instances":[{"prompt":"make a short clip"}],"parameters":{"durationSeconds":8,"sampleCount":2,"resolution":"720p"}}` + parser := &Parser{} + + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.MeteredUsage.OutputVideoSeconds != 16 { + t.Errorf("OutputVideoSeconds = %f, want 16", meta.MeteredUsage.OutputVideoSeconds) + } + if !meta.MeteredUsage.HasOutputVideoSeconds { + t.Error("HasOutputVideoSeconds = false, want true") + } + if _, ok := meta.Custom["parameters"]; ok { + t.Error("parameters should not be captured as a custom field") + } + }) } func TestEnricher(t *testing.T) { @@ -163,6 +182,46 @@ func TestResolver(t *testing.T) { t.Errorf("expected %s, got %s", expected, u.String()) } }) + + t.Run("resolves long-running endpoint from string api type", func(t *testing.T) { + resolver, err := NewResolver("https://generativelanguage.googleapis.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + meta := llmproxy.BodyMetadata{ + Model: "veo-3.1-generate-preview", + Custom: map[string]any{"api_type": string(llmproxy.APITypePredictLongRunning)}, + } + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := "https://generativelanguage.googleapis.com/v1beta/models/veo-3.1-generate-preview:predictLongRunning" + if u.String() != expected { + t.Errorf("expected %s, got %s", expected, u.String()) + } + }) + + t.Run("resolves streaming endpoint from string api type", func(t *testing.T) { + resolver, err := NewResolver("https://generativelanguage.googleapis.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + meta := llmproxy.BodyMetadata{ + Model: "gemini-pro", + Custom: map[string]any{"api_type": string(llmproxy.APITypeStreamGenerateContent)}, + } + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?alt=sse" + if u.String() != expected { + t.Errorf("expected %s, got %s", expected, u.String()) + } + }) } func TestExtractor(t *testing.T) { diff --git a/providers/googleai/resolver.go b/providers/googleai/resolver.go index 0a6b9fd..3793450 100644 --- a/providers/googleai/resolver.go +++ b/providers/googleai/resolver.go @@ -30,17 +30,32 @@ func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { } var endpoint *url.URL - if meta.Stream { + apiType := resolveAPIType(meta.Custom["api_type"]) + switch { + case apiType == llmproxy.APITypePredictLongRunning: + endpoint = r.BaseURL.JoinPath("v1beta", "models", fmt.Sprintf("%s:predictLongRunning", model)) + case meta.Stream || apiType == llmproxy.APITypeStreamGenerateContent: endpoint = r.BaseURL.JoinPath("v1beta", "models", fmt.Sprintf("%s:streamGenerateContent", model)) q := endpoint.Query() q.Set("alt", "sse") endpoint.RawQuery = q.Encode() - } else { + default: endpoint = r.BaseURL.JoinPath("v1beta", "models", fmt.Sprintf("%s:generateContent", model)) } return endpoint, nil } +func resolveAPIType(value any) llmproxy.APIType { + switch v := value.(type) { + case llmproxy.APIType: + return v + case string: + return llmproxy.APIType(v) + default: + return "" + } +} + // NewResolver creates a new resolver with the given base URL. func NewResolver(baseURL string) (*Resolver, error) { u, err := url.Parse(baseURL) diff --git a/providers/openai_compatible/enricher.go b/providers/openai_compatible/enricher.go index d4bf251..9d1dcf4 100644 --- a/providers/openai_compatible/enricher.go +++ b/providers/openai_compatible/enricher.go @@ -13,12 +13,13 @@ type Enricher struct { APIKey string } -// Enrich adds the Authorization and Content-Type headers to the request. -// It sets: -// - Authorization: Bearer -// - Content-Type: application/json +// Enrich adds the Authorization header and defaults Content-Type to JSON. +// Existing content types, such as multipart/form-data for audio uploads, are +// preserved. func (e *Enricher) Enrich(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte) error { - req.Header.Set("Content-Type", "application/json") + if req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/json") + } if e.APIKey != "" { req.Header.Set("Authorization", "Bearer "+e.APIKey) } else { diff --git a/providers/openai_compatible/extractor.go b/providers/openai_compatible/extractor.go index 557cc4f..b2bd5a3 100644 --- a/providers/openai_compatible/extractor.go +++ b/providers/openai_compatible/extractor.go @@ -31,6 +31,13 @@ func (e *Extractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []b var openaiResp OpenAIResponse if err := json.Unmarshal(body, &openaiResp); err != nil { + contentType := resp.Header.Get("Content-Type") + if contentType != "" && !strings.Contains(contentType, "json") { + return llmproxy.ResponseMetadata{ + MeteredUsage: llmproxy.MeteredUsage{}, + Custom: make(map[string]any), + }, body, nil + } return llmproxy.ResponseMetadata{}, nil, err } @@ -39,22 +46,39 @@ func (e *Extractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []b Object: openaiResp.Object, Model: openaiResp.Model, Usage: llmproxy.Usage{ - PromptTokens: openaiResp.Usage.PromptTokens, - CompletionTokens: openaiResp.Usage.CompletionTokens, - TotalTokens: openaiResp.Usage.TotalTokens, + PromptTokens: openaiResp.Usage.PromptTokenCount(), + CompletionTokens: openaiResp.Usage.CompletionTokenCount(), + TotalTokens: openaiResp.Usage.TotalTokenCount(), }, Choices: make([]llmproxy.Choice, len(openaiResp.Choices)), Custom: make(map[string]any), } + meta.MeteredUsage = llmproxy.MeteredUsage{} + if openaiResp.Usage.Type == "duration" { + meta.MeteredUsage.InputAudioSeconds = openaiResp.Usage.Seconds + meta.MeteredUsage.HasInputAudioSeconds = true + } + if openaiResp.Data != nil { + meta.MeteredUsage.GeneratedImages = len(openaiResp.Data) + meta.MeteredUsage.HasGeneratedImages = true + } - if openaiResp.Usage.PromptTokensDetails != nil && openaiResp.Usage.PromptTokensDetails.CachedTokens > 0 { + promptDetails := openaiResp.Usage.PromptTokensDetails + if promptDetails == nil { + promptDetails = openaiResp.Usage.InputTokensDetails + } + if promptDetails != nil && promptDetails.CachedTokens > 0 { meta.Custom["cache_usage"] = llmproxy.CacheUsage{ - CachedTokens: openaiResp.Usage.PromptTokensDetails.CachedTokens, + CachedTokens: promptDetails.CachedTokens, } } - if openaiResp.Usage.CompletionTokensDetails != nil && openaiResp.Usage.CompletionTokensDetails.ReasoningTokens > 0 { - meta.Custom["reasoning_tokens"] = openaiResp.Usage.CompletionTokensDetails.ReasoningTokens + completionDetails := openaiResp.Usage.CompletionTokensDetails + if completionDetails == nil { + completionDetails = openaiResp.Usage.OutputTokensDetails + } + if completionDetails != nil && completionDetails.ReasoningTokens > 0 { + meta.Custom["reasoning_tokens"] = completionDetails.ReasoningTokens } for i, c := range openaiResp.Choices { @@ -93,15 +117,51 @@ type OpenAIResponse struct { Usage UsageInfo `json:"usage"` // Choices contains the completion choices. Choices []ResponseChoice `json:"choices"` + // Data contains non-chat output records such as generated images. + Data []json.RawMessage `json:"data,omitempty"` } // UsageInfo tracks token usage in an OpenAI-compatible response. type UsageInfo struct { PromptTokens int `json:"prompt_tokens"` + InputTokens int `json:"input_tokens"` CompletionTokens int `json:"completion_tokens"` + OutputTokens int `json:"output_tokens"` TotalTokens int `json:"total_tokens"` PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` + InputTokensDetails *PromptTokensDetails `json:"input_tokens_details,omitempty"` + OutputTokensDetails *CompletionTokensDetails `json:"output_tokens_details,omitempty"` + Type string `json:"type,omitempty"` + Seconds float64 `json:"seconds,omitempty"` +} + +func (u UsageInfo) PromptTokenCount() int { + if u.PromptTokens > 0 { + return u.PromptTokens + } + return u.InputTokens +} + +func (u UsageInfo) CompletionTokenCount() int { + if u.CompletionTokens > 0 { + return u.CompletionTokens + } + return u.OutputTokens +} + +func (u UsageInfo) TotalTokenCount() int { + if u.TotalTokens > 0 { + return u.TotalTokens + } + return u.PromptTokenCount() + u.CompletionTokenCount() +} + +func (u UsageInfo) InputAudioSeconds() float64 { + if u.Type == "duration" { + return u.Seconds + } + return 0 } // PromptTokensDetails contains detailed prompt token breakdown. diff --git a/providers/openai_compatible/extractor_test.go b/providers/openai_compatible/extractor_test.go index 818b2e5..0988e14 100644 --- a/providers/openai_compatible/extractor_test.go +++ b/providers/openai_compatible/extractor_test.go @@ -53,6 +53,99 @@ func TestExtractor_ReasoningTokens(t *testing.T) { } } +func TestExtractor_ImageGenerationUsage(t *testing.T) { + body := `{ + "created": 1778342333, + "data": [{"b64_json": "abc"}], + "usage": { + "input_tokens": 15, + "input_tokens_details": { + "image_tokens": 0, + "text_tokens": 15 + }, + "output_tokens": 272, + "output_tokens_details": { + "image_tokens": 272, + "text_tokens": 0 + }, + "total_tokens": 287 + } + }` + + extractor := NewExtractor() + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + if meta.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 272 { + t.Errorf("CompletionTokens = %d, want 272", meta.Usage.CompletionTokens) + } + if meta.Usage.TotalTokens != 287 { + t.Errorf("TotalTokens = %d, want 287", meta.Usage.TotalTokens) + } + if meta.MeteredUsage.GeneratedImages != 1 { + t.Errorf("GeneratedImages = %d, want 1", meta.MeteredUsage.GeneratedImages) + } + if !meta.MeteredUsage.HasGeneratedImages { + t.Error("HasGeneratedImages = false, want true") + } +} + +func TestExtractor_DurationUsage(t *testing.T) { + body := `{"text":"You","usage":{"type":"duration","seconds":1}}` + + extractor := NewExtractor() + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + if meta.MeteredUsage.InputAudioSeconds != 1 { + t.Errorf("InputAudioSeconds = %f, want 1", meta.MeteredUsage.InputAudioSeconds) + } + if !meta.MeteredUsage.HasInputAudioSeconds { + t.Error("HasInputAudioSeconds = false, want true") + } +} + +func TestExtractor_NonJSONResponsePassesThrough(t *testing.T) { + body := []byte{0xff, 0xfb, 0x90, 0x64} + + extractor := NewExtractor() + resp := &http.Response{ + StatusCode: 200, + Header: http.Header{ + "Content-Type": []string{"audio/mpeg"}, + }, + Body: io.NopCloser(bytes.NewReader(body)), + } + + meta, raw, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("Extract() error = %v", err) + } + if !bytes.Equal(raw, body) { + t.Fatalf("raw body = %v, want %v", raw, body) + } + if meta.Custom == nil { + t.Fatal("expected custom metadata map") + } +} + func TestExtractor_ReasoningTokensZero(t *testing.T) { body := `{ "id": "chatcmpl-abc", diff --git a/providers/openai_compatible/multiapi.go b/providers/openai_compatible/multiapi.go index 4ef6ea3..abca376 100644 --- a/providers/openai_compatible/multiapi.go +++ b/providers/openai_compatible/multiapi.go @@ -101,7 +101,8 @@ func (e *StreamingMultiAPIExtractor) ExtractStreamingWithController(resp *http.R if resp.Request != nil { metaCtx := llmproxy.GetMetaFromContext(resp.Request.Context()) - if apiType, ok := metaCtx.Meta.Custom["api_type"].(llmproxy.APIType); ok && apiType == llmproxy.APITypeResponses { + apiType := metaCtx.Meta.Custom["api_type"] + if apiType == llmproxy.APITypeResponses || apiType == string(llmproxy.APITypeResponses) { return e.responsesStreaming.ExtractStreamingWithController(resp, w, rc) } } diff --git a/providers/openai_compatible/parser.go b/providers/openai_compatible/parser.go index 1dd7972..43c21fe 100644 --- a/providers/openai_compatible/parser.go +++ b/providers/openai_compatible/parser.go @@ -14,6 +14,7 @@ import ( "bytes" "encoding/json" "io" + "unicode/utf8" "github.com/agentuity/llmproxy" ) @@ -39,12 +40,19 @@ func (p *Parser) Parse(body io.ReadCloser) (llmproxy.BodyMetadata, []byte, error return llmproxy.BodyMetadata{}, nil, err } + meteredUsage := llmproxy.MeteredUsage{} + if req.Input != "" { + meteredUsage.InputCharacters = utf8.RuneCountInString(req.Input) + meteredUsage.HasInputCharacters = true + } + meta := llmproxy.BodyMetadata{ - Model: req.Model, - Messages: req.Messages, - MaxTokens: req.MaxTokens, - Stream: req.Stream, - Custom: make(map[string]any), + Model: req.Model, + Messages: req.Messages, + MaxTokens: req.MaxTokens, + Stream: req.Stream, + MeteredUsage: meteredUsage, + Custom: make(map[string]any), } for k, v := range req.Custom { @@ -65,6 +73,8 @@ type OpenAIRequest struct { MaxTokens int `json:"max_tokens,omitempty"` // Stream enables streaming responses. Stream bool `json:"stream"` + // Input is used by non-chat APIs such as audio speech generation. + Input string `json:"input,omitempty"` // Custom holds provider-specific parameters not in the standard schema. Custom map[string]interface{} `json:"-"` } @@ -88,7 +98,7 @@ func (r *OpenAIRequest) UnmarshalJSON(data []byte) error { r.Custom = make(map[string]interface{}) known := map[string]bool{ - "model": true, "messages": true, "max_tokens": true, + "model": true, "messages": true, "max_tokens": true, "input": true, "stream": true, "temperature": true, "top_p": true, "n": true, "stop": true, "presence_penalty": true, "frequency_penalty": true, "logit_bias": true, "user": true, diff --git a/providers/openai_compatible/parser_test.go b/providers/openai_compatible/parser_test.go index 9b63032..458d4e4 100644 --- a/providers/openai_compatible/parser_test.go +++ b/providers/openai_compatible/parser_test.go @@ -151,6 +151,23 @@ func TestParser_UnicodeContent(t *testing.T) { } } +func TestParser_InputCharacterUsage(t *testing.T) { + body := `{"model":"tts-1","input":"Hello δΈ–η•Œ"}` + parser := &Parser{} + + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.MeteredUsage.InputCharacters != 8 { + t.Errorf("InputCharacters = %d, want 8", meta.MeteredUsage.InputCharacters) + } + if !meta.MeteredUsage.HasInputCharacters { + t.Error("HasInputCharacters = false, want true") + } +} + func TestEnricher_SetsHeaders(t *testing.T) { enricher := NewEnricher("test-api-key") req := httptest.NewRequest("POST", "https://api.example.com/v1/chat/completions", nil) @@ -168,6 +185,21 @@ func TestEnricher_SetsHeaders(t *testing.T) { } } +func TestEnricher_PreservesExistingContentType(t *testing.T) { + enricher := NewEnricher("test-api-key") + req := httptest.NewRequest("POST", "https://api.example.com/v1/audio/transcriptions", nil) + req.Header.Set("Content-Type", "multipart/form-data; boundary=test") + + err := enricher.Enrich(req, llmproxy.BodyMetadata{}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ct := req.Header.Get("Content-Type"); ct != "multipart/form-data; boundary=test" { + t.Errorf("Content-Type = %q, want multipart form", ct) + } +} + func TestEnricher_EmptyKey(t *testing.T) { enricher := NewEnricher("") req := httptest.NewRequest("POST", "https://example.com", nil) diff --git a/providers/openai_compatible/resolver.go b/providers/openai_compatible/resolver.go index bb3d5b0..74399be 100644 --- a/providers/openai_compatible/resolver.go +++ b/providers/openai_compatible/resolver.go @@ -27,6 +27,14 @@ func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { return r.BaseURL.JoinPath("v1", "responses"), nil case llmproxy.APITypeCompletions: return r.BaseURL.JoinPath("v1", "completions"), nil + case llmproxy.APITypeEmbeddings: + return r.BaseURL.JoinPath("v1", "embeddings"), nil + case llmproxy.APITypeImagesGenerations: + return r.BaseURL.JoinPath("v1", "images", "generations"), nil + case llmproxy.APITypeAudioSpeech: + return r.BaseURL.JoinPath("v1", "audio", "speech"), nil + case llmproxy.APITypeAudioTranscriptions: + return r.BaseURL.JoinPath("v1", "audio", "transcriptions"), nil default: return r.BaseURL.JoinPath("v1", "chat", "completions"), nil } diff --git a/providers/openai_compatible/responses_parser.go b/providers/openai_compatible/responses_parser.go index 0b5fded..8ef7f20 100644 --- a/providers/openai_compatible/responses_parser.go +++ b/providers/openai_compatible/responses_parser.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "io" + "unicode/utf8" "github.com/agentuity/llmproxy" ) @@ -33,6 +34,8 @@ func (p *ResponsesParser) Parse(body io.ReadCloser) (llmproxy.BodyMetadata, []by switch v := req.Input.(type) { case string: meta.Messages = []llmproxy.Message{{Role: "user", Content: v}} + meta.MeteredUsage.InputCharacters = utf8.RuneCountInString(v) + meta.MeteredUsage.HasInputCharacters = true case []interface{}: msgs := make([]llmproxy.Message, 0, len(v)) for _, item := range v { diff --git a/providers/openai_compatible/responses_streaming_extractor_test.go b/providers/openai_compatible/responses_streaming_extractor_test.go index 60fbb15..11e9d1e 100644 --- a/providers/openai_compatible/responses_streaming_extractor_test.go +++ b/providers/openai_compatible/responses_streaming_extractor_test.go @@ -85,6 +85,44 @@ func TestResponsesStreamingExtractor_UsageExtraction(t *testing.T) { } } +func TestResponsesStreamingExtractor_TopLevelUsageExtraction(t *testing.T) { + stream := "data: {\"type\":\"response.completed\",\"usage\":{\"input_tokens\":101,\"output_tokens\":44,\"total_tokens\":145}}\n\n" + + meta, _, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Usage.PromptTokens != 101 { + t.Errorf("PromptTokens = %d, want 101", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 44 { + t.Errorf("CompletionTokens = %d, want 44", meta.Usage.CompletionTokens) + } + if meta.Usage.TotalTokens != 145 { + t.Errorf("TotalTokens = %d, want 145", meta.Usage.TotalTokens) + } +} + +func TestResponsesStreamingExtractor_IncompleteUsageExtraction(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello\"}\n\n" + + "data: {\"type\":\"response.incomplete\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-4o\",\"status\":\"incomplete\",\"usage\":{\"input_tokens\":20,\"output_tokens\":30,\"total_tokens\":50}}}\n\n" + + meta, _, err := runResponsesStreamExtraction(t, "text/event-stream", stream) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Usage.PromptTokens != 20 { + t.Errorf("PromptTokens = %d, want 20", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 30 { + t.Errorf("CompletionTokens = %d, want 30", meta.Usage.CompletionTokens) + } + if meta.Usage.TotalTokens != 50 { + t.Errorf("TotalTokens = %d, want 50", meta.Usage.TotalTokens) + } +} + func TestResponsesStreamingExtractor_CacheUsageExtraction(t *testing.T) { stream := "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":100,\"output_tokens\":20,\"total_tokens\":120,\"input_tokens_details\":{\"cached_tokens\":80}}}}\n\n" + "data: [DONE]\n\n" diff --git a/providers/openai_compatible/responses_test.go b/providers/openai_compatible/responses_test.go index 5f995fc..21eb74a 100644 --- a/providers/openai_compatible/responses_test.go +++ b/providers/openai_compatible/responses_test.go @@ -56,6 +56,13 @@ func TestResponsesParser(t *testing.T) { t.Errorf("api_type = %v, want responses", meta.Custom["api_type"]) } + if meta.MeteredUsage.InputCharacters != 13 { + t.Errorf("InputCharacters = %d, want 13", meta.MeteredUsage.InputCharacters) + } + if !meta.MeteredUsage.HasInputCharacters { + t.Error("HasInputCharacters = false, want true") + } + if len(data) == 0 { t.Error("data is empty") } @@ -1758,6 +1765,47 @@ func TestStreamingMultiAPIExtractor_ResponsesAPIDispatch(t *testing.T) { } } +func TestStreamingMultiAPIExtractor_ResponsesAPIStringDispatch(t *testing.T) { + stream := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_dispatch\",\"model\":\"gpt-4o\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_dispatch\",\"model\":\"gpt-4o\",\"usage\":{\"input_tokens\":10,\"output_tokens\":5,\"total_tokens\":15}}}\n\n" + + "data: [DONE]\n\n" + + req, err := http.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil) + if err != nil { + t.Fatalf("NewRequest() error = %v", err) + } + ctxValue := llmproxy.MetaContextValue{ + Meta: llmproxy.BodyMetadata{Custom: map[string]any{"api_type": string(llmproxy.APITypeResponses)}}, + } + req = req.WithContext(context.WithValue(req.Context(), llmproxy.MetaContextKey{}, ctxValue)) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(stream)), + Request: req, + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingMultiAPIExtractor() + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("ExtractStreamingWithController() error = %v", err) + } + + if meta.ID != "resp_dispatch" { + t.Errorf("ID = %q, want resp_dispatch", meta.ID) + } + if meta.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", meta.Usage.TotalTokens) + } + if meta.Custom["api_type"] != llmproxy.APITypeResponses { + t.Errorf("api_type = %v, want responses", meta.Custom["api_type"]) + } +} + func TestStreamingMultiAPIExtractor_ChatCompletionsDispatch(t *testing.T) { stream := "data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n" + "data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":3,\"total_tokens\":12}}\n\n" + diff --git a/streaming.go b/streaming.go index 45d7ffc..1eb16c8 100644 --- a/streaming.go +++ b/streaming.go @@ -202,7 +202,13 @@ type AnthropicStreamMessage struct { } type ResponsesStreamEvent struct { - Type string `json:"type"` + Type string `json:"type"` + ID string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Usage *ResponsesStreamUsage `json:"usage,omitempty"` + Response json.RawMessage `json:"response,omitempty"` } @@ -332,33 +338,39 @@ func ExtractUsageFromAnthropicEvent(event *AnthropicStreamEvent) *StreamingUsage } func ExtractUsageFromResponsesEvent(event *ResponsesStreamEvent) *StreamingUsage { - if event == nil || event.Type != "response.completed" || len(event.Response) == 0 { + if event == nil || (event.Type != "response.completed" && event.Type != "response.incomplete") { return nil } - var response ResponsesStreamResponse - if err := json.Unmarshal(event.Response, &response); err != nil { - return nil + var responseUsage *ResponsesStreamUsage + if len(event.Response) > 0 { + var response ResponsesStreamResponse + if err := json.Unmarshal(event.Response, &response); err != nil { + return nil + } + responseUsage = response.Usage + } else { + responseUsage = event.Usage } - if response.Usage == nil { + if responseUsage == nil { return nil } usage := &StreamingUsage{ - PromptTokens: response.Usage.InputTokens, - CompletionTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.TotalTokens, + PromptTokens: responseUsage.InputTokens, + CompletionTokens: responseUsage.OutputTokens, + TotalTokens: responseUsage.TotalTokens, } - if response.Usage.InputTokensDetails != nil && response.Usage.InputTokensDetails.CachedTokens > 0 { + if responseUsage.InputTokensDetails != nil && responseUsage.InputTokensDetails.CachedTokens > 0 { usage.CacheUsage = &CacheUsage{ - CachedTokens: response.Usage.InputTokensDetails.CachedTokens, + CachedTokens: responseUsage.InputTokensDetails.CachedTokens, } } - if response.Usage.OutputTokensDetails != nil && response.Usage.OutputTokensDetails.ReasoningTokens > 0 { - usage.ReasoningTokens = response.Usage.OutputTokensDetails.ReasoningTokens + if responseUsage.OutputTokensDetails != nil && responseUsage.OutputTokensDetails.ReasoningTokens > 0 { + usage.ReasoningTokens = responseUsage.OutputTokensDetails.ReasoningTokens } return usage diff --git a/streaming_test.go b/streaming_test.go index ceeaf42..442050c 100644 --- a/streaming_test.go +++ b/streaming_test.go @@ -712,6 +712,30 @@ func TestExtractUsageFromResponsesEvent(t *testing.T) { expectedCompletion: 5, expectedTotal: 15, }, + { + name: "completed with top-level usage", + event: &ResponsesStreamEvent{ + Type: "response.completed", + Usage: &ResponsesStreamUsage{ + InputTokens: 11, + OutputTokens: 6, + TotalTokens: 17, + }, + }, + expectedPrompt: 11, + expectedCompletion: 6, + expectedTotal: 17, + }, + { + name: "incomplete with usage", + event: &ResponsesStreamEvent{ + Type: "response.incomplete", + Response: []byte(`{"usage":{"input_tokens":13,"output_tokens":8,"total_tokens":21}}`), + }, + expectedPrompt: 13, + expectedCompletion: 8, + expectedTotal: 21, + }, { name: "completed without usage", event: &ResponsesStreamEvent{