NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Added ("+") lines are re-indented per Python syntax and their counts match the
hunk headers. The indentation of the context lines in the model_builder.py
hunks is a best guess -- confirm against the target file, or apply with
`git apply --ignore-whitespace`.

diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py
index 7c7af2defc..a48edea31a 100644
--- a/sagemaker-serve/src/sagemaker/serve/model_builder.py
+++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py
@@ -2411,6 +2411,12 @@ def _build_single_modelbuilder(
                             f"{self.model._latest_training_job.model_artifacts.s3_model_artifacts}"
                             "/checkpoints/hf/"
                         )
+                    elif isinstance(self.model, ModelPackage):
+                        self._adapter_s3_uri = (
+                            model_package.inference_specification.containers[
+                                0
+                            ].model_data_source.s3_data_source.s3_uri
+                        )
                 else:
                     # Non-LORA: Model points at training output
                     self.s3_upload_path = model_package.inference_specification.containers[
@@ -2418,6 +2424,7 @@ def _build_single_modelbuilder(
                     ].model_data_source.s3_data_source.s3_uri
                     container_def = ContainerDefinition(
                         image=self.image_uri,
+                        environment=self.env_vars,
                         model_data_source={
                             "s3_data_source": {
                                 "s3_uri": self.s3_upload_path.rstrip("/") + "/",
@@ -4554,6 +4561,16 @@ def _fetch_peft(self) -> Optional[str]:
             training_job = self.model
         elif isinstance(self.model, ModelTrainer):
             training_job = self.model._latest_training_job
+        elif isinstance(self.model, ModelPackage):
+            try:
+                recipe_name = (
+                    self.model.inference_specification.containers[0].base_model.recipe_name
+                )
+                if recipe_name and "lora" in recipe_name.lower():
+                    return "LORA"
+            except (AttributeError, IndexError, TypeError):
+                pass  # missing or malformed container metadata: treat as non-LoRA
+            return None
         else:
             return None
 
diff --git a/sagemaker-serve/tests/integ/test_model_package_lora_detection.py b/sagemaker-serve/tests/integ/test_model_package_lora_detection.py
new file mode 100644
index 0000000000..b762836262
--- /dev/null
+++ b/sagemaker-serve/tests/integ/test_model_package_lora_detection.py
@@ -0,0 +1,179 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Integration tests for ModelPackage LoRA detection in ModelBuilder.
+
+Build-only tests that verify _fetch_peft() correctly detects LoRA from
+ModelPackage recipe names, that the build path sets the right attributes,
+and that the created SageMaker Model has the correct container configuration.
+No deployment or GPU instances required.
+"""
+from __future__ import absolute_import
+
+import os
+import time
+import random
+import pytest
+import boto3
+
+
+# LoRA model package (recipe name contains "lora")
+LORA_MODEL_PACKAGE_ARN = (
+    "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1"
+)
+
+# Non-LoRA model package (DPO recipe, no "lora" in name)
+NON_LORA_MODEL_PACKAGE_ARN = (
+    "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/264"
+)
+
+REGION = "us-west-2"
+
+
+@pytest.fixture(scope="module", autouse=True)
+def set_region():
+    """Pin SAGEMAKER_REGION to us-west-2 for this module, then restore the prior value."""
+    original = os.environ.get("SAGEMAKER_REGION")
+    os.environ["SAGEMAKER_REGION"] = REGION
+    yield
+    if original is not None:
+        os.environ["SAGEMAKER_REGION"] = original
+    else:
+        os.environ.pop("SAGEMAKER_REGION", None)
+
+
+@pytest.fixture(scope="session")
+def sm_client():
+    """Boto3 SageMaker client for validating created models."""
+    return boto3.client("sagemaker", region_name=REGION)
+
+
+class TestModelPackageLoraBuild:
+    """Test build() from LoRA and non-LoRA ModelPackages."""
+
+    def test_build_lora_model_package(self, sm_client):
+        """Build from a LoRA ModelPackage: verify peft detection, adapter URI, and model config."""
+        from sagemaker.core.resources import ModelPackage
+        from sagemaker.serve import ModelBuilder
+
+        model_package = ModelPackage.get(model_package_name=LORA_MODEL_PACKAGE_ARN)
+        model_builder = ModelBuilder(model=model_package)
+        model_builder.accept_eula = True
+
+        # Verify _fetch_peft detects LoRA
+        peft_type = model_builder._fetch_peft()
+        assert peft_type == "LORA", (
+            f"Expected 'LORA' but got '{peft_type}' for {LORA_MODEL_PACKAGE_ARN}"
+        )
+
+        model_name = f"integ-lora-test-{int(time.time())}-{random.randint(100, 10000)}"
+        model = model_builder.build(model_name=model_name)
+
+        try:
+            assert model is not None
+            assert model.model_arn is not None
+
+            # Verify _adapter_s3_uri is set to a valid S3 path
+            assert hasattr(model_builder, "_adapter_s3_uri"), (
+                "_adapter_s3_uri should be set after build() for LoRA ModelPackage"
+            )
+            assert model_builder._adapter_s3_uri is not None
+            assert model_builder._adapter_s3_uri.startswith("s3://"), (
+                f"_adapter_s3_uri should be an S3 URI, got: {model_builder._adapter_s3_uri}"
+            )
+
+            # Use boto3 to validate the actual model configuration
+            describe_resp = sm_client.describe_model(ModelName=model_name)
+            containers = describe_resp.get("Containers", [])
+            assert len(containers) == 1, f"Expected 1 container, got {len(containers)}"
+
+            container = containers[0]
+
+            # LoRA path: container should point at JumpStart base model, NOT the adapter
+            model_data = container.get("ModelDataSource", {})
+            s3_source = model_data.get("S3DataSource", {})
+            s3_uri = s3_source.get("S3Uri", "")
+            assert s3_uri, "Container should have an S3 model data source"
+            # The S3 URI should NOT be the adapter URI — it should be the base model
+            assert s3_uri != model_builder._adapter_s3_uri, (
+                f"LoRA container S3 URI should point to base model, not adapter. "
+                f"Got: {s3_uri}, adapter: {model_builder._adapter_s3_uri}"
+            )
+
+            # LoRA path: container should have accept_eula in model access config
+            access_config = s3_source.get("ModelAccessConfig", {})
+            assert access_config.get("AcceptEula") is True, (
+                "LoRA container should have AcceptEula=True in ModelAccessConfig"
+            )
+
+            # LoRA path: container should have environment variables set
+            env_vars = container.get("Environment", {})
+            assert len(env_vars) > 0, (
+                "LoRA container should have environment variables set"
+            )
+        finally:
+            try:
+                model.delete()
+            except Exception:
+                pass  # best-effort cleanup: a failed delete must not mask the test outcome
+
+    def test_build_non_lora_model_package(self, sm_client):
+        """Build from a non-LoRA ModelPackage: verify no adapter URI and env vars in container."""
+        from sagemaker.core.resources import ModelPackage
+        from sagemaker.serve import ModelBuilder
+
+        model_package = ModelPackage.get(model_package_name=NON_LORA_MODEL_PACKAGE_ARN)
+        model_builder = ModelBuilder(model=model_package)
+        model_builder.accept_eula = True
+
+        # Verify _fetch_peft does NOT detect LoRA
+        peft_type = model_builder._fetch_peft()
+        assert peft_type is None, (
+            f"Expected None but got '{peft_type}' for {NON_LORA_MODEL_PACKAGE_ARN}"
+        )
+
+        model_name = f"integ-nonlora-test-{int(time.time())}-{random.randint(100, 10000)}"
+        model = model_builder.build(model_name=model_name)
+
+        try:
+            assert model is not None
+            assert model.model_arn is not None
+
+            # Verify no adapter URI is set
+            adapter_uri = getattr(model_builder, "_adapter_s3_uri", None)
+            assert adapter_uri is None, (
+                f"_adapter_s3_uri should not be set for non-LoRA, got: {adapter_uri}"
+            )
+
+            # Use boto3 to validate the actual model configuration
+            describe_resp = sm_client.describe_model(ModelName=model_name)
+            containers = describe_resp.get("Containers", [])
+            assert len(containers) == 1, f"Expected 1 container, got {len(containers)}"
+
+            container = containers[0]
+
+            # Non-LoRA path: container should have environment variables (bug fix validation)
+            env_vars = container.get("Environment", {})
+            assert len(env_vars) > 0, (
+                "Non-LoRA container should have environment variables set (env_vars bug fix)"
+            )
+
+            # Non-LoRA path: container should point at training output S3 URI
+            model_data = container.get("ModelDataSource", {})
+            s3_source = model_data.get("S3DataSource", {})
+            s3_uri = s3_source.get("S3Uri", "")
+            assert s3_uri, "Non-LoRA container should have an S3 model data source"
+        finally:
+            try:
+                model.delete()
+            except Exception:
+                pass  # best-effort cleanup: a failed delete must not mask the test outcome
diff --git a/sagemaker-serve/tests/unit/test_model_package_peft_detection.py b/sagemaker-serve/tests/unit/test_model_package_peft_detection.py
new file mode 100644
index 0000000000..3fc18d07b6
--- /dev/null
+++ b/sagemaker-serve/tests/unit/test_model_package_peft_detection.py
@@ -0,0 +1,222 @@
+"""Unit tests for ModelPackage LoRA detection in _fetch_peft() and related paths.
+
+Tests verify that:
+1. _fetch_peft() returns "LORA" for ModelPackage with lora recipe name
+2. _fetch_peft() returns None for ModelPackage with non-lora recipe name
+3. _fetch_peft() returns None for ModelPackage with no recipe name
+4. _adapter_s3_uri is correctly set from ModelPackage container S3 URI
+5. env vars are applied in the non-LoRA ContainerDefinition path
+"""
+
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+from sagemaker.serve.model_builder import ModelBuilder
+from sagemaker.core.resources import ModelPackage, TrainingJob
+
+
+class TestModelPackagePeftDetection:
+    """Test _fetch_peft() behavior with ModelPackage input."""
+
+    def _create_model_package_mock(self, recipe_name=None):
+        """Helper to create a mock ModelPackage with a given recipe name."""
+        mock_package = Mock(spec=ModelPackage)
+        mock_container = Mock()
+        mock_container.base_model = Mock()
+        mock_container.base_model.recipe_name = recipe_name
+        mock_package.inference_specification = Mock()
+        mock_package.inference_specification.containers = [mock_container]
+        return mock_package
+
+    def test_fetch_peft_returns_lora_for_lora_recipe(self):
+        """_fetch_peft() returns 'LORA' when recipe name contains 'lora'."""
+        mock_package = self._create_model_package_mock(
+            recipe_name="verl-grpo-rlvr-qwen-3-32b-lora"
+        )
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+        )
+        assert builder._fetch_peft() == "LORA"
+
+    def test_fetch_peft_returns_lora_case_insensitive(self):
+        """_fetch_peft() matches 'lora' case-insensitively."""
+        mock_package = self._create_model_package_mock(
+            recipe_name="some-model-LoRA-adapter"
+        )
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+        )
+        assert builder._fetch_peft() == "LORA"
+
+    def test_fetch_peft_returns_none_for_fft_recipe(self):
+        """_fetch_peft() returns None when recipe name does not contain 'lora'."""
+        mock_package = self._create_model_package_mock(
+            recipe_name="verl-grpo-rlvr-qwen-3-32b-fft"
+        )
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+        )
+        assert builder._fetch_peft() is None
+
+    def test_fetch_peft_returns_none_for_no_recipe_name(self):
+        """_fetch_peft() returns None when recipe name is None."""
+        mock_package = self._create_model_package_mock(recipe_name=None)
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+        )
+        assert builder._fetch_peft() is None
+
+    def test_fetch_peft_returns_none_when_base_model_missing(self):
+        """_fetch_peft() returns None when the container's base_model is None."""
+        mock_package = Mock(spec=ModelPackage)
+        mock_container = Mock()
+        mock_container.base_model = None
+        mock_package.inference_specification = Mock()
+        mock_package.inference_specification.containers = [mock_container]
+
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+        )
+        assert builder._fetch_peft() is None
+
+    def test_fetch_peft_returns_none_when_containers_empty(self):
+        """_fetch_peft() returns None when containers list is empty."""
+        mock_package = Mock(spec=ModelPackage)
+        mock_package.inference_specification = Mock()
+        mock_package.inference_specification.containers = []
+
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+        )
+        assert builder._fetch_peft() is None
+
+
+class TestModelPackageAdapterS3Uri:
+    """Test _adapter_s3_uri is correctly set from ModelPackage."""
+
+    @patch.object(ModelBuilder, "_fetch_model_package_arn")
+    @patch.object(ModelBuilder, "_fetch_model_package")
+    @patch.object(ModelBuilder, "_fetch_peft")
+    @patch.object(ModelBuilder, "_fetch_hub_document_for_custom_model")
+    @patch.object(ModelBuilder, "_fetch_and_cache_recipe_config")
+    @patch.object(ModelBuilder, "_is_nova_model", return_value=False)
+    @patch.object(ModelBuilder, "_is_model_customization")
+    @patch("sagemaker.core.resources.Model.create")
+    def test_adapter_s3_uri_set_from_model_package(
+        self,
+        mock_model_create,
+        mock_is_customization,
+        mock_is_nova_model,
+        mock_fetch_and_cache_recipe,
+        mock_fetch_hub,
+        mock_fetch_peft,
+        mock_fetch_package,
+        mock_fetch_package_arn,
+    ):
+        """_adapter_s3_uri is set from ModelPackage container S3 URI for LORA."""
+        mock_is_customization.return_value = True
+        mock_fetch_peft.return_value = "LORA"
+
+        expected_adapter_uri = "s3://bucket/adapter-weights/"
+
+        mock_package = Mock(spec=ModelPackage)
+        mock_package.model_package_arn = (
+            "arn:aws:sagemaker:us-west-2:123456789012:model-package/test-package"
+        )
+        mock_container = Mock()
+        mock_container.base_model = Mock()
+        mock_container.base_model.recipe_name = "verl-grpo-rlvr-qwen-3-32b-lora"
+        mock_container.model_data_source = Mock()
+        mock_container.model_data_source.s3_data_source = Mock()
+        mock_container.model_data_source.s3_data_source.s3_uri = expected_adapter_uri
+        mock_package.inference_specification = Mock()
+        mock_package.inference_specification.containers = [mock_container]
+        mock_fetch_package.return_value = mock_package
+        mock_fetch_package_arn.return_value = mock_package.model_package_arn
+
+        mock_fetch_hub.return_value = {
+            "HostingArtifactUri": "s3://jumpstart-bucket/base-model-artifacts/"
+        }
+
+        mock_model_create.return_value = Mock()
+
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+        )
+        builder.accept_eula = True
+        builder._build_single_modelbuilder()
+
+        assert builder._adapter_s3_uri == expected_adapter_uri
+
+
+class TestNonLoraEnvVars:
+    """Test env vars are applied in the non-LoRA ContainerDefinition path."""
+
+    @patch.object(ModelBuilder, "_fetch_model_package_arn")
+    @patch.object(ModelBuilder, "_fetch_model_package")
+    @patch.object(ModelBuilder, "_fetch_peft")
+    @patch.object(ModelBuilder, "_fetch_and_cache_recipe_config")
+    @patch.object(ModelBuilder, "_is_nova_model", return_value=False)
+    @patch.object(ModelBuilder, "_is_model_customization")
+    @patch("sagemaker.core.resources.Model.create")
+    def test_env_vars_passed_to_non_lora_container_def(
+        self,
+        mock_model_create,
+        mock_is_customization,
+        mock_is_nova_model,
+        mock_fetch_and_cache_recipe,
+        mock_fetch_peft,
+        mock_fetch_package,
+        mock_fetch_package_arn,
+    ):
+        """Non-LoRA ContainerDefinition includes environment vars."""
+        mock_is_customization.return_value = True
+        mock_fetch_peft.return_value = None  # Not LORA
+
+        mock_package = Mock(spec=ModelPackage)
+        mock_package.model_package_arn = (
+            "arn:aws:sagemaker:us-west-2:123456789012:model-package/test-package"
+        )
+        mock_container = Mock()
+        mock_container.base_model = Mock()
+        mock_container.base_model.recipe_name = "verl-grpo-rlvr-qwen-3-32b-fft"
+        mock_container.model_data_source = Mock()
+        mock_container.model_data_source.s3_data_source = Mock()
+        mock_container.model_data_source.s3_data_source.s3_uri = "s3://bucket/model/"
+        mock_package.inference_specification = Mock()
+        mock_package.inference_specification.containers = [mock_container]
+        mock_fetch_package.return_value = mock_package
+        mock_fetch_package_arn.return_value = mock_package.model_package_arn
+
+        mock_model_create.return_value = Mock()
+
+        expected_env = {"SM_MODEL_ID": "test-model", "CUSTOM_VAR": "value"}
+
+        builder = ModelBuilder(
+            model=mock_package,
+            role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
+            instance_type="ml.g5.12xlarge",
+            env_vars=expected_env,
+        )
+        builder._build_single_modelbuilder()
+
+        # Verify Model.create was called and the container has environment set
+        assert mock_model_create.called
+        create_call = mock_model_create.call_args
+        containers = create_call[1].get("containers", [])
+        assert len(containers) == 1
+        container_def = containers[0]
+        assert container_def.environment == expected_env