diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index f39a7a165e..b1ebd40cd4 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -211,8 +211,11 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model arn_parts = model_pkg_arn.split(':') if len(arn_parts) >= 4: region = arn_parts[3] - # Construct hub content ARN for SageMaker public hub - base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}" + # Use SAGEMAKER_HUB_NAME if set (private hub), otherwise fall back to public hub + hub_name = os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub") + # Private hubs are account-scoped; public hub uses 'aws' as account + hub_account = "aws" if hub_name == "SageMakerPublicHub" else arn_parts[4] + base_model_arn = f"arn:aws:sagemaker:{region}:{hub_account}:hub-content/{hub_name}/Model/{hub_content_name}/{hub_content_version}" # If we couldn't extract or construct base model ARN, this is not a supported model package if not base_model_arn: