diff --git a/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_model_info.py b/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_model_info.py index 81dc4f79638e..03d0a49d8b61 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_model_info.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_model_info.py @@ -1,3 +1,4 @@ +import re from typing import Dict from autogen_core.models import ModelFamily, ModelInfo @@ -138,15 +139,23 @@ } +def _normalize_model_id(model: str) -> str: + """Normalize provider-specific Anthropic model IDs to table keys.""" + model = re.sub(r"^(?:[a-z]+\.)?anthropic\.", "", model) + return re.sub(r"-v\d+:\d+$", "", model) + + def get_info(model: str) -> ModelInfo: """Get the model information for a specific model.""" + normalized_model = _normalize_model_id(model) + # Check for exact match first - if model in _MODEL_INFO: - return _MODEL_INFO[model] + if normalized_model in _MODEL_INFO: + return _MODEL_INFO[normalized_model] # Check for partial match (for handling model variants) for model_id in _MODEL_INFO: - if model.startswith(model_id.split("-2")[0]): # Match base name + if normalized_model.startswith(model_id.split("-2")[0]): # Match base name return _MODEL_INFO[model_id] raise KeyError(f"Model '{model}' not found in model info") @@ -154,13 +163,15 @@ def get_info(model: str) -> ModelInfo: def get_token_limit(model: str) -> int: """Get the token limit for a specific model.""" + normalized_model = _normalize_model_id(model) + # Check for exact match first - if model in _MODEL_TOKEN_LIMITS: - return _MODEL_TOKEN_LIMITS[model] + if normalized_model in _MODEL_TOKEN_LIMITS: + return _MODEL_TOKEN_LIMITS[normalized_model] # Check for partial match (for handling model variants) for model_id in _MODEL_TOKEN_LIMITS: - if model.startswith(model_id.split("-2")[0]): # Match base name + if normalized_model.startswith(model_id.split("-2")[0]): # Match base name return _MODEL_TOKEN_LIMITS[model_id] # Default to a reasonable limit if model not found diff --git a/python/packages/autogen-ext/tests/models/test_anthropic_model_info.py b/python/packages/autogen-ext/tests/models/test_anthropic_model_info.py new file mode 100644 index 000000000000..2351f384c2ad --- /dev/null +++ b/python/packages/autogen-ext/tests/models/test_anthropic_model_info.py @@ -0,0 +1,32 @@ +import pytest +from autogen_core.models import ModelFamily +from autogen_ext.models.anthropic import _model_info + + +@pytest.mark.parametrize( + "model", + [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "us.anthropic.claude-3-5-sonnet-20240620-v1:0", + "eu.anthropic.claude-3-5-sonnet-20240620-v1:0", + "apac.anthropic.claude-3-5-sonnet-20240620-v1:0", + "global.anthropic.claude-3-5-sonnet-20240620-v1:0", + ], +) +def test_bedrock_model_ids_resolve_model_info(model: str) -> None: + info = _model_info.get_info(model) + + assert info["family"] == ModelFamily.CLAUDE_3_5_SONNET + assert info == _model_info.get_info("claude-3-5-sonnet-20240620") + + +@pytest.mark.parametrize( + "model", + [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "us.anthropic.claude-3-5-sonnet-20240620-v1:0", + ], +) +def test_bedrock_model_ids_resolve_token_limit(model: str) -> None: + assert _model_info.get_token_limit(model) == 200000 +