Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Dict

from autogen_core.models import ModelFamily, ModelInfo
Expand Down Expand Up @@ -138,29 +139,39 @@
}


def _normalize_model_id(model: str) -> str:
"""Normalize provider-specific Anthropic model IDs to table keys."""
model = re.sub(r"^(?:[a-z]+\.)?anthropic\.", "", model)
return re.sub(r"-v\d+:\d+$", "", model)


def get_info(model: str) -> ModelInfo:
"""Get the model information for a specific model."""
normalized_model = _normalize_model_id(model)

# Check for exact match first
if model in _MODEL_INFO:
return _MODEL_INFO[model]
if normalized_model in _MODEL_INFO:
return _MODEL_INFO[normalized_model]

# Check for partial match (for handling model variants)
for model_id in _MODEL_INFO:
if model.startswith(model_id.split("-2")[0]): # Match base name
if normalized_model.startswith(model_id.split("-2")[0]): # Match base name
return _MODEL_INFO[model_id]

raise KeyError(f"Model '{model}' not found in model info")


def get_token_limit(model: str) -> int:
"""Get the token limit for a specific model."""
normalized_model = _normalize_model_id(model)

# Check for exact match first
if model in _MODEL_TOKEN_LIMITS:
return _MODEL_TOKEN_LIMITS[model]
if normalized_model in _MODEL_TOKEN_LIMITS:
return _MODEL_TOKEN_LIMITS[normalized_model]

# Check for partial match (for handling model variants)
for model_id in _MODEL_TOKEN_LIMITS:
if model.startswith(model_id.split("-2")[0]): # Match base name
if normalized_model.startswith(model_id.split("-2")[0]): # Match base name
return _MODEL_TOKEN_LIMITS[model_id]

# Default to a reasonable limit if model not found
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from autogen_core.models import ModelFamily
from autogen_ext.models.anthropic import _model_info


@pytest.mark.parametrize(
"model",
[
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
"apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
"global.anthropic.claude-3-5-sonnet-20240620-v1:0",
],
)
def test_bedrock_model_ids_resolve_model_info(model: str) -> None:
info = _model_info.get_info(model)

assert info["family"] == ModelFamily.CLAUDE_3_5_SONNET
assert info == _model_info.get_info("claude-3-5-sonnet-20240620")


@pytest.mark.parametrize(
"model",
[
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
],
)
def test_bedrock_model_ids_resolve_token_limit(model: str) -> None:
assert _model_info.get_token_limit(model) == 200000