Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions autorouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1001,12 +1001,88 @@ func normalizeProviderRequest(raw map[string]any, providerName string) {
const defaultAnthropicMaxTokens = 1024

func normalizeAnthropicRequest(raw map[string]any) {
normalizeAnthropicSystemMessages(raw)

if hasPositiveNumber(raw["max_tokens"]) {
return
}
raw["max_tokens"] = defaultAnthropicMaxTokens
}

func normalizeAnthropicSystemMessages(raw map[string]any) {
messages, ok := raw["messages"].([]any)
if !ok || len(messages) == 0 {
return
}

filtered := make([]any, 0, len(messages))
systemParts := make([]any, 0, 1)
for _, item := range messages {
message, ok := item.(map[string]any)
if !ok {
filtered = append(filtered, item)
continue
}
role, _ := message["role"].(string)
if role != "system" {
filtered = append(filtered, item)
continue
}
if content, exists := message["content"]; exists {
systemParts = append(systemParts, content)
}
}

raw["messages"] = filtered
if len(systemParts) > 0 {
raw["system"] = mergeAnthropicSystem(raw["system"], systemParts)
}
}

func mergeAnthropicSystem(existing any, systemParts []any) any {
systemText := joinTextParts(systemParts)
if existing == nil {
if systemText != "" {
return systemText
}
return systemParts[0]
}

existingText := joinTextParts([]any{existing})
if existingText != "" && systemText != "" {
return existingText + "\n\n" + systemText
}
if systemText != "" {
return systemText
}
return existing
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

func joinTextParts(parts []any) string {
values := make([]string, 0, len(parts))
for _, part := range parts {
switch v := part.(type) {
case string:
if strings.TrimSpace(v) != "" {
values = append(values, v)
}
case []any:
for _, item := range v {
block, ok := item.(map[string]any)
if !ok {
continue
}
blockType, _ := block["type"].(string)
text, _ := block["text"].(string)
if blockType == "text" && strings.TrimSpace(text) != "" {
values = append(values, text)
}
}
}
}
return strings.Join(values, "\n\n")
}

func hasPositiveNumber(value any) bool {
switch v := value.(type) {
case int:
Expand Down
230 changes: 230 additions & 0 deletions autorouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,236 @@ func TestAutoRouter_AnthropicPreservesMaxTokens(t *testing.T) {
}
}

func TestAutoRouter_AnthropicMovesSystemMessageToTopLevel(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &receivedBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "anthropic",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "claude-3-opus"}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "msg_test"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","messages":[{"role":"system","content":"You are terse."},{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", w.Code)
}
if got := receivedBody["system"]; got != "You are terse." {
t.Fatalf("system = %v, want %q", got, "You are terse.")
}
messages, ok := receivedBody["messages"].([]any)
if !ok {
t.Fatalf("messages = %T, want []any", receivedBody["messages"])
}
if len(messages) != 1 {
t.Fatalf("len(messages) = %d, want 1", len(messages))
}
message, ok := messages[0].(map[string]any)
if !ok {
t.Fatalf("messages[0] = %T, want map[string]any", messages[0])
}
if got := message["role"]; got != "user" {
t.Fatalf("messages[0].role = %v, want user", got)
}
if got := receivedBody["max_tokens"]; got != float64(defaultAnthropicMaxTokens) {
t.Fatalf("max_tokens = %v, want %d", got, defaultAnthropicMaxTokens)
}
}

func TestAutoRouter_AnthropicMergesSystemMessageWithExistingSystem(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &receivedBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "anthropic",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "claude-3-opus"}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "msg_test"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","system":"Existing system.","messages":[{"role":"system","content":"Additional system."},{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", w.Code)
}
if got := receivedBody["system"]; got != "Existing system.\n\nAdditional system." {
t.Fatalf("system = %v, want merged system", got)
}
messages, ok := receivedBody["messages"].([]any)
if !ok || len(messages) != 1 {
t.Fatalf("messages = %#v, want one non-system message", receivedBody["messages"])
}
}

func TestAutoRouter_AnthropicUsesSystemMessageWhenExistingSystemEmpty(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &receivedBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "anthropic",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "claude-3-opus"}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "msg_test"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","system":"","messages":[{"role":"system","content":"Use terse answers."},{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", w.Code)
}
if got := receivedBody["system"]; got != "Use terse answers." {
t.Fatalf("system = %v, want %q", got, "Use terse answers.")
}
messages, ok := receivedBody["messages"].([]any)
if !ok || len(messages) != 1 {
t.Fatalf("messages = %#v, want one non-system message", receivedBody["messages"])
}
message, ok := messages[0].(map[string]any)
if !ok {
t.Fatalf("messages[0] = %T, want map[string]any", messages[0])
}
if got := message["role"]; got != "user" {
t.Fatalf("messages[0].role = %v, want user", got)
}
if got := receivedBody["max_tokens"]; got != float64(defaultAnthropicMaxTokens) {
t.Fatalf("max_tokens = %v, want %d", got, defaultAnthropicMaxTokens)
}
}

func TestAutoRouter_AnthropicRemovesSystemMessageWithMissingContent(t *testing.T) {
var receivedBody map[string]any
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &receivedBody)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`))
}))
defer upstream.Close()

provider := &mockProvider{
name: "anthropic",
parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
data, _ := io.ReadAll(body)
return BodyMetadata{Model: "claude-3-opus"}, data, nil
},
enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
resolveFn: func(meta BodyMetadata) (*url.URL, error) {
return url.Parse(upstream.URL)
},
extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
body, _ := io.ReadAll(resp.Body)
return ResponseMetadata{ID: "msg_test"}, body, nil
},
}

router := NewAutoRouter(
WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })),
)
router.RegisterProvider(provider)

req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","system":"Existing system.","messages":[{"role":"system"},{"role":"system","content":null},{"role":"user","content":"Hello"}]}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

router.ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("StatusCode = %d, want 200", w.Code)
}
if got := receivedBody["system"]; got != "Existing system." {
t.Fatalf("system = %v, want existing system unchanged", got)
}
messages, ok := receivedBody["messages"].([]any)
if !ok || len(messages) != 1 {
t.Fatalf("messages = %#v, want one non-system message", receivedBody["messages"])
}
message, ok := messages[0].(map[string]any)
if !ok {
t.Fatalf("messages[0] = %T, want map[string]any", messages[0])
}
if got := message["role"]; got != "user" {
t.Fatalf("messages[0].role = %v, want user", got)
}
}

func TestAutoRouter_StreamingWritesGatewayMetadataEvent(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand Down
Loading