From 9945bd84af805fba44ceeec27e9e8ed512d6fd72 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 6 May 2026 15:30:51 -0700 Subject: [PATCH] Passing kms key in mlops interfaces --- .../mlops/workflow/clarify_check_step.py | 30 ++- .../mlops/workflow/quality_check_step.py | 30 ++- .../integ/test_check_step_kms_propagation.py | 248 ++++++++++++++++++ .../workflow/test_clarify_check_step_kms.py | 205 +++++++++++++++ .../workflow/test_quality_check_step_kms.py | 201 ++++++++++++++ 5 files changed, 692 insertions(+), 22 deletions(-) create mode 100644 sagemaker-mlops/tests/integ/test_check_step_kms_propagation.py create mode 100644 sagemaker-mlops/tests/unit/workflow/test_clarify_check_step_kms.py create mode 100644 sagemaker-mlops/tests/unit/workflow/test_quality_check_step_kms.py diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/clarify_check_step.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/clarify_check_step.py index 9bd38df275..8f102a13e6 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/clarify_check_step.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/clarify_check_step.py @@ -281,25 +281,33 @@ def arguments(self) -> RequestType: } processing_inputs.append(input_dict) + s3_output_dict = { + "S3Uri": self._processing_params["result_output"].s3_output.s3_uri, + "LocalPath": self._processing_params["result_output"].s3_output.local_path, + "S3UploadMode": self._processing_params["result_output"].s3_output.s3_upload_mode, + } + if self.check_job_config.output_kms_key: + s3_output_dict["KmsKeyId"] = self.check_job_config.output_kms_key + processing_outputs = [{ "OutputName": self._processing_params["result_output"].output_name, - "S3Output": { - "S3Uri": self._processing_params["result_output"].s3_output.s3_uri, - "LocalPath": self._processing_params["result_output"].s3_output.local_path, - "S3UploadMode": self._processing_params["result_output"].s3_output.s3_upload_mode, - } + "S3Output": 
s3_output_dict, }] - + + cluster_config = { + "InstanceCount": self._baselining_processor.instance_count, + "InstanceType": self._baselining_processor.instance_type, + "VolumeSizeInGB": getattr(self._baselining_processor, 'volume_size_in_gb', 30), + } + if self.check_job_config.volume_kms_key: + cluster_config["VolumeKmsKeyId"] = self.check_job_config.volume_kms_key + request_dict = { "ProcessingInputs": processing_inputs, "ProcessingOutputConfig": {"Outputs": processing_outputs}, "ProcessingJobName": self._baselining_processor._current_job_name or "clarify-job", "ProcessingResources": { - "ClusterConfig": { - "InstanceCount": self._baselining_processor.instance_count, - "InstanceType": self._baselining_processor.instance_type, - "VolumeSizeInGB": getattr(self._baselining_processor, 'volume_size_in_gb', 30), - } + "ClusterConfig": cluster_config, }, "AppSpecification": { "ImageUri": self._baselining_processor.image_uri, diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/quality_check_step.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/quality_check_step.py index e37189f2de..280278be91 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/quality_check_step.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/quality_check_step.py @@ -259,25 +259,33 @@ def arguments(self) -> RequestType: } processing_inputs.append(input_dict) + s3_output_dict = { + "S3Uri": self._baseline_output.s3_output.s3_uri, + "LocalPath": self._baseline_output.s3_output.local_path, + "S3UploadMode": self._baseline_output.s3_output.s3_upload_mode, + } + if self.check_job_config.output_kms_key: + s3_output_dict["KmsKeyId"] = self.check_job_config.output_kms_key + processing_outputs = [{ "OutputName": self._baseline_output.output_name, - "S3Output": { - "S3Uri": self._baseline_output.s3_output.s3_uri, - "LocalPath": self._baseline_output.s3_output.local_path, - "S3UploadMode": self._baseline_output.s3_output.s3_upload_mode, - } + "S3Output": s3_output_dict, }] - + + cluster_config 
= { + "InstanceCount": self._baselining_processor.instance_count, + "InstanceType": self._baselining_processor.instance_type, + "VolumeSizeInGB": getattr(self._baselining_processor, 'volume_size_in_gb', 30), + } + if self.check_job_config.volume_kms_key: + cluster_config["VolumeKmsKeyId"] = self.check_job_config.volume_kms_key + request_dict = { "ProcessingInputs": processing_inputs, "ProcessingOutputConfig": {"Outputs": processing_outputs}, "ProcessingJobName": self._baselining_processor._current_job_name or "baseline-job", "ProcessingResources": { - "ClusterConfig": { - "InstanceCount": self._baselining_processor.instance_count, - "InstanceType": self._baselining_processor.instance_type, - "VolumeSizeInGB": getattr(self._baselining_processor, 'volume_size_in_gb', 30), - } + "ClusterConfig": cluster_config, }, "AppSpecification": { "ImageUri": self._baselining_processor.image_uri, diff --git a/sagemaker-mlops/tests/integ/test_check_step_kms_propagation.py b/sagemaker-mlops/tests/integ/test_check_step_kms_propagation.py new file mode 100644 index 0000000000..4ec767a80f --- /dev/null +++ b/sagemaker-mlops/tests/integ/test_check_step_kms_propagation.py @@ -0,0 +1,248 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration test for KMS key propagation in check steps. 
+ +This test constructs real QualityCheckStep and ClarifyCheckStep objects using +the actual SDK classes with a real SageMaker Session, then inspects the compiled +step arguments to verify KmsKeyId and VolumeKmsKeyId are present. + +No SageMaker compute resources are launched. The only AWS interaction is a small +S3 put_object for the Clarify analysis config (cleaned up in teardown). + +Prerequisites: + - AWS credentials with S3 read/write access to the default SageMaker bucket. + +Related ticket: V2184920638 +""" +import json +import pytest +import boto3 + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.mlops.workflow.quality_check_step import ( + QualityCheckStep, + DataQualityCheckConfig, +) +from sagemaker.mlops.workflow.clarify_check_step import ( + ClarifyCheckStep, + DataBiasCheckConfig, +) +from sagemaker.mlops.workflow.check_job_config import CheckJobConfig + + +# Use a fake KMS key ARN — we never actually encrypt anything, we just verify +# the key appears in the compiled request dict. 
+_TEST_OUTPUT_KMS_KEY = "arn:aws:kms:us-west-2:123456789012:key/test-output-key-id"
+_TEST_VOLUME_KMS_KEY = "arn:aws:kms:us-west-2:123456789012:key/test-volume-key-id"
+
+_S3_PREFIX = "integ-test-kms-check-step"
+
+
+@pytest.fixture(scope="module")
+def sagemaker_session():
+    """Real SageMaker session with AWS credentials, created once per module."""
+    return Session()
+
+
+@pytest.fixture(scope="module")
+def role():  # execution role handed to both CheckJobConfig fixtures
+    return get_execution_role()
+
+
+@pytest.fixture(scope="module")
+def bucket(sagemaker_session):  # the session's default bucket; every test S3 URI lives in it
+    return sagemaker_session.default_bucket()
+
+
+@pytest.fixture(scope="module")
+def s3_client(sagemaker_session):  # S3 client in the session's region; used by cleanup_s3
+    return boto3.client("s3", region_name=sagemaker_session.boto_region_name)
+
+
+@pytest.fixture
+def check_job_config_with_kms(role, sagemaker_session):
+    """CheckJobConfig with both output and volume KMS keys set to the fake test ARNs."""
+    return CheckJobConfig(
+        role=role,
+        instance_count=1,
+        instance_type="ml.m5.xlarge",
+        volume_size_in_gb=30,
+        volume_kms_key=_TEST_VOLUME_KMS_KEY,
+        output_kms_key=_TEST_OUTPUT_KMS_KEY,
+        max_runtime_in_seconds=3600,
+        sagemaker_session=sagemaker_session,
+    )
+
+
+@pytest.fixture
+def check_job_config_no_kms(role, sagemaker_session):
+    """CheckJobConfig that passes neither volume_kms_key nor output_kms_key."""
+    return CheckJobConfig(
+        role=role,
+        instance_count=1,
+        instance_type="ml.m5.xlarge",
+        volume_size_in_gb=30,
+        sagemaker_session=sagemaker_session,
+    )
+
+
+@pytest.fixture(autouse=True, scope="module")
+def cleanup_s3(bucket, s3_client):
+    """Best-effort teardown: delete every object under ``_S3_PREFIX`` once the module's tests finish."""
+    yield
+    import logging  # local import keeps the fixture self-contained
+    try:  # paginate: a single list_objects_v2 call returns at most 1000 keys
+        for page in s3_client.get_paginator("list_objects_v2").paginate(Bucket=bucket, Prefix=_S3_PREFIX):
+            objects = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
+            if objects:  # pages without "Contents" have nothing to delete
+                s3_client.delete_objects(Bucket=bucket, Delete={"Objects": objects})
+    except Exception as exc:  # teardown must not fail the run, but must not be silent either
+        logging.getLogger(__name__).warning("S3 cleanup under %s failed: %s", _S3_PREFIX, exc)
+
+
+class TestDataQualityCheckStepKms:
+    """Verify KMS key propagation in DataQualityCheckStep using real SDK objects."""
+
+    def _build_step(self, check_job_config, bucket):
+        """Build a QualityCheckStep from a DataQualityCheckConfig whose S3 URIs sit under _S3_PREFIX."""
+        quality_check_config = DataQualityCheckConfig(
+            baseline_dataset=f"s3://{bucket}/{_S3_PREFIX}/input/data.csv",
+            dataset_format={"csv": {"header": True}},
+            output_s3_uri=f"s3://{bucket}/{_S3_PREFIX}/output/quality-results",
+        )
+        return QualityCheckStep(
+            name="TestDataQualityCheck",
+            quality_check_config=quality_check_config,
+            check_job_config=check_job_config,
+            skip_check=False,
+            fail_on_violation=True,
+            register_new_baseline=False,
+        )
+
+    def test_output_kms_key_in_arguments(self, check_job_config_with_kms, bucket):
+        """output_kms_key from CheckJobConfig appears as KmsKeyId in S3Output."""
+        step = self._build_step(check_job_config_with_kms, bucket)
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" in s3_output, (
+            f"Expected KmsKeyId in S3Output but got: {s3_output}"
+        )
+        assert s3_output["KmsKeyId"] == _TEST_OUTPUT_KMS_KEY
+
+    def test_volume_kms_key_in_arguments(self, check_job_config_with_kms, bucket):
+        """volume_kms_key from CheckJobConfig appears as VolumeKmsKeyId in ClusterConfig."""
+        step = self._build_step(check_job_config_with_kms, bucket)
+        args = step.arguments
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" in cluster_config, (
+            f"Expected VolumeKmsKeyId in ClusterConfig but got: {cluster_config}"
+        )
+        assert cluster_config["VolumeKmsKeyId"] == _TEST_VOLUME_KMS_KEY
+
+    def test_no_kms_keys_when_not_configured(self, check_job_config_no_kms, bucket):
+        """KMS keys are absent from arguments when not set in CheckJobConfig."""
+        step = self._build_step(check_job_config_no_kms, bucket)
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" not in s3_output
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" not in cluster_config
+
+    def test_arguments_are_json_serializable(self, check_job_config_with_kms, bucket):
+        """The KMS values survive a json.dumps/json.loads round trip of the arguments dict."""
+        step = self._build_step(check_job_config_with_kms, bucket)
+        args = step.arguments
+
+        json_str = json.dumps(args, default=str)  # NOTE(review): default=str masks non-serializable values — confirm intent
+        parsed = json.loads(json_str)
+        assert parsed["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["KmsKeyId"] == _TEST_OUTPUT_KMS_KEY
+        assert parsed["ProcessingResources"]["ClusterConfig"]["VolumeKmsKeyId"] == _TEST_VOLUME_KMS_KEY
+
+
+class TestDataBiasCheckStepKms:
+    """Verify KMS key propagation in DataBiasCheckStep (ClarifyCheckStep) using real SDK objects."""
+
+    def _build_step(self, check_job_config, bucket):
+        """Build a ClarifyCheckStep from a DataBiasCheckConfig whose S3 URIs sit under _S3_PREFIX."""
+        from sagemaker.core.clarify import DataConfig, BiasConfig
+
+        data_config = DataConfig(
+            s3_data_input_path=f"s3://{bucket}/{_S3_PREFIX}/input/bias-data.csv",
+            s3_output_path=f"s3://{bucket}/{_S3_PREFIX}/output/bias-results",
+            label="target",
+            dataset_type="text/csv",
+        )
+        bias_config = BiasConfig(
+            label_values_or_threshold=[1],
+            facet_name="gender",
+            facet_values_or_threshold=[0],
+        )
+        clarify_check_config = DataBiasCheckConfig(
+            data_config=data_config,
+            data_bias_config=bias_config,
+        )
+        return ClarifyCheckStep(
+            name="TestDataBiasCheck",
+            clarify_check_config=clarify_check_config,
+            check_job_config=check_job_config,
+            skip_check=False,
+            fail_on_violation=True,
+            register_new_baseline=False,
+        )
+
+    def test_output_kms_key_in_arguments(self, check_job_config_with_kms, bucket):
+        """output_kms_key from CheckJobConfig appears as KmsKeyId in S3Output."""
+        step = self._build_step(check_job_config_with_kms, bucket)
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" in s3_output, (
+            f"Expected KmsKeyId in S3Output but got: {s3_output}"
+        )
+        assert s3_output["KmsKeyId"] == _TEST_OUTPUT_KMS_KEY
+
+    def test_volume_kms_key_in_arguments(self, check_job_config_with_kms, bucket):
+        """volume_kms_key from CheckJobConfig appears as VolumeKmsKeyId in ClusterConfig."""
+        step = self._build_step(check_job_config_with_kms, bucket)
+        args = step.arguments
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" in cluster_config, (
+            f"Expected VolumeKmsKeyId in ClusterConfig but got: {cluster_config}"
+        )
+        assert cluster_config["VolumeKmsKeyId"] == _TEST_VOLUME_KMS_KEY
+
+    def test_no_kms_keys_when_not_configured(self, check_job_config_no_kms, bucket):
+        """KMS keys are absent from arguments when not set in CheckJobConfig."""
+        step = self._build_step(check_job_config_no_kms, bucket)
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" not in s3_output
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" not in cluster_config
+
+    def test_arguments_are_json_serializable(self, check_job_config_with_kms, bucket):
+        """The KMS values survive a json.dumps/json.loads round trip of the arguments dict."""
+        step = self._build_step(check_job_config_with_kms, bucket)
+        args = step.arguments
+
+        json_str = json.dumps(args, default=str)  # NOTE(review): default=str masks non-serializable values — confirm intent
+        parsed = json.loads(json_str)
+        assert parsed["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["KmsKeyId"] == _TEST_OUTPUT_KMS_KEY
+        assert parsed["ProcessingResources"]["ClusterConfig"]["VolumeKmsKeyId"] == _TEST_VOLUME_KMS_KEY
diff --git a/sagemaker-mlops/tests/unit/workflow/test_clarify_check_step_kms.py b/sagemaker-mlops/tests/unit/workflow/test_clarify_check_step_kms.py
new file mode 100644
index 0000000000..eb92734d90
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/workflow/test_clarify_check_step_kms.py
@@ -0,0 +1,205 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Unit tests for KMS key propagation in ClarifyCheckStep."""
+from __future__ import absolute_import
+
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+
+from sagemaker.mlops.workflow.clarify_check_step import (
+    ClarifyCheckStep,
+    DataBiasCheckConfig,
+    ModelBiasCheckConfig,
+    ModelExplainabilityCheckConfig,
+)
+from sagemaker.mlops.workflow.check_job_config import CheckJobConfig
+
+
+_OUTPUT_KMS_KEY = "arn:aws:kms:us-east-1:123456789012:key/output-key-id"
+_VOLUME_KMS_KEY = "arn:aws:kms:us-east-1:123456789012:key/volume-key-id"
+
+# Patch trim_request_dict to be a no-op since it only trims job names
+# and is not relevant to KMS key propagation testing.
+_TRIM_PATCH = "sagemaker.mlops.workflow.clarify_check_step.trim_request_dict"
+
+
+def _noop_trim(request_dict, *args, **kwargs):
+    """Return ``request_dict`` unchanged; installed as the ``side_effect`` when patching ``_TRIM_PATCH``."""
+    return request_dict
+
+
+def _create_mock_clarify_check_step(output_kms_key=None, volume_kms_key=None):
+    """Return a ClarifyCheckStep built without ``__init__`` whose stubbed attributes carry the given KMS keys."""
+    step = object.__new__(ClarifyCheckStep)  # allocate the instance without running __init__
+
+    # Mock check_job_config; only the two KMS attributes are set explicitly
+    step.check_job_config = Mock()
+    step.check_job_config.output_kms_key = output_kms_key
+    step.check_job_config.volume_kms_key = volume_kms_key
+
+    # Mock processing params (config_input, data_input, result_output)
+    config_input = Mock()
+    config_input.input_name = "analysis_config"
+    config_input.s3_input = Mock()
+    config_input.s3_input.s3_uri = "s3://bucket/config/analysis_config.json"
+    config_input.s3_input.local_path = "/opt/ml/processing/input/config"
+    config_input.s3_input.s3_data_type = "S3Prefix"
+    config_input.s3_input.s3_input_mode = "File"
+
+    data_input = Mock()
+    data_input.input_name = "dataset"
+    data_input.s3_input = Mock()
+    data_input.s3_input.s3_uri = "s3://bucket/input/data.csv"
+    data_input.s3_input.local_path = "/opt/ml/processing/input/data"
+    data_input.s3_input.s3_data_type = "S3Prefix"
+    data_input.s3_input.s3_input_mode = "File"
+
+    result_output = Mock()
+    result_output.output_name = "analysis_result"
+    result_output.s3_output = Mock()
+    result_output.s3_output.s3_uri = "s3://bucket/output/results"
+    result_output.s3_output.local_path = "/opt/ml/processing/output"
+    result_output.s3_output.s3_upload_mode = "EndOfJob"
+
+    step._processing_params = {
+        "config_input": config_input,
+        "data_input": data_input,
+        "result_output": result_output,
+    }
+
+    # Mock baselining processor
+    step._baselining_processor = Mock()
+    step._baselining_processor._current_job_name = "clarify-check-job"
+    step._baselining_processor.instance_count = 1
+    step._baselining_processor.instance_type = "ml.m5.xlarge"
+    step._baselining_processor.volume_size_in_gb = 30
+    step._baselining_processor.image_uri = "123456789012.dkr.ecr.us-east-1.amazonaws.com/clarify:latest"
+    step._baselining_processor.role = "arn:aws:iam::123456789012:role/SageMakerRole"
+    step._baselining_processor.max_runtime_in_seconds = 3600
+    step._baselining_processor.env = None
+    step._baselining_processor.network_config = None
+    step._baselining_processor.entrypoint = None
+    step._baselining_processor.arguments = None
+
+    return step
+
+
+class TestClarifyCheckStepKmsKeyPropagation:
+    """Tests for KMS key propagation in ClarifyCheckStep.arguments; every test patches _TRIM_PATCH with _noop_trim."""
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_output_kms_key_propagated_in_s3_output(self, mock_trim):
+        """Test that output_kms_key from CheckJobConfig is included in S3Output."""
+        step = _create_mock_clarify_check_step(output_kms_key=_OUTPUT_KMS_KEY)
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" in s3_output
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_volume_kms_key_propagated_in_cluster_config(self, mock_trim):
+        """Test that volume_kms_key from CheckJobConfig is included in ClusterConfig."""
+        step = _create_mock_clarify_check_step(volume_kms_key=_VOLUME_KMS_KEY)
+
+        args = step.arguments
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" in cluster_config
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_both_kms_keys_propagated(self, mock_trim):
+        """Test that both output_kms_key and volume_kms_key are propagated together."""
+        step = _create_mock_clarify_check_step(
+            output_kms_key=_OUTPUT_KMS_KEY,
+            volume_kms_key=_VOLUME_KMS_KEY,
+        )
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_no_kms_keys_when_not_set(self, mock_trim):
+        """Test that KmsKeyId and VolumeKmsKeyId are absent when not configured."""
+        step = _create_mock_clarify_check_step(output_kms_key=None, volume_kms_key=None)
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" not in s3_output
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" not in cluster_config
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_s3_output_retains_other_fields_with_kms(self, mock_trim):
+        """Test that S3Output still contains S3Uri, LocalPath, S3UploadMode alongside KmsKeyId."""
+        step = _create_mock_clarify_check_step(output_kms_key=_OUTPUT_KMS_KEY)
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert s3_output["S3Uri"] == "s3://bucket/output/results"
+        assert s3_output["LocalPath"] == "/opt/ml/processing/output"
+        assert s3_output["S3UploadMode"] == "EndOfJob"
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_cluster_config_retains_other_fields_with_kms(self, mock_trim):
+        """Test that ClusterConfig still contains instance fields alongside VolumeKmsKeyId."""
+        step = _create_mock_clarify_check_step(volume_kms_key=_VOLUME_KMS_KEY)
+
+        args = step.arguments
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert cluster_config["InstanceCount"] == 1
+        assert cluster_config["InstanceType"] == "ml.m5.xlarge"
+        assert cluster_config["VolumeSizeInGB"] == 30
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_output_kms_key_only_without_volume_kms(self, mock_trim):
+        """Test output_kms_key set but volume_kms_key not set."""
+        step = _create_mock_clarify_check_step(
+            output_kms_key=_OUTPUT_KMS_KEY, volume_kms_key=None
+        )
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" not in cluster_config
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_volume_kms_key_only_without_output_kms(self, mock_trim):
+        """Test volume_kms_key set but output_kms_key not set."""
+        step = _create_mock_clarify_check_step(
+            output_kms_key=None, volume_kms_key=_VOLUME_KMS_KEY
+        )
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" not in s3_output
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY
diff --git a/sagemaker-mlops/tests/unit/workflow/test_quality_check_step_kms.py b/sagemaker-mlops/tests/unit/workflow/test_quality_check_step_kms.py
new file mode 100644
index 0000000000..1f00611f00
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/workflow/test_quality_check_step_kms.py
@@ -0,0 +1,201 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Unit tests for KMS key propagation in QualityCheckStep."""
+from __future__ import absolute_import
+
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+
+from sagemaker.mlops.workflow.quality_check_step import (
+    QualityCheckStep,
+    DataQualityCheckConfig,
+    ModelQualityCheckConfig,
+)
+from sagemaker.mlops.workflow.check_job_config import CheckJobConfig
+
+
+_OUTPUT_KMS_KEY = "arn:aws:kms:us-east-1:123456789012:key/output-key-id"
+_VOLUME_KMS_KEY = "arn:aws:kms:us-east-1:123456789012:key/volume-key-id"
+
+# Patch trim_request_dict to be a no-op since it only trims job names
+# and is not relevant to KMS key propagation testing.
+_TRIM_PATCH = "sagemaker.mlops.workflow.quality_check_step.trim_request_dict"
+
+
+def _noop_trim(request_dict, *args, **kwargs):
+    """Return ``request_dict`` unchanged; installed as the ``side_effect`` when patching ``_TRIM_PATCH``."""
+    return request_dict
+
+
+def _create_mock_quality_check_step(output_kms_key=None, volume_kms_key=None):
+    """Return a QualityCheckStep built without ``__init__`` whose stubbed attributes carry the given KMS keys."""
+    step = object.__new__(QualityCheckStep)  # allocate the instance without running __init__
+
+    # Mock check_job_config; only the two KMS attributes are set explicitly
+    step.check_job_config = Mock()
+    step.check_job_config.output_kms_key = output_kms_key
+    step.check_job_config.volume_kms_key = volume_kms_key
+
+    # Mock baseline output
+    step._baseline_output = Mock()
+    step._baseline_output.output_name = "quality-check-output"
+    step._baseline_output.s3_output = Mock()
+    step._baseline_output.s3_output.s3_uri = "s3://bucket/output"
+    step._baseline_output.s3_output.local_path = "/opt/ml/processing/output"
+    step._baseline_output.s3_output.s3_upload_mode = "EndOfJob"
+
+    # Mock baseline job inputs
+    mock_input = Mock()
+    mock_input.input_name = "baseline-dataset"
+    mock_input.s3_input = Mock()
+    mock_input.s3_input.s3_uri = "s3://bucket/input/data.csv"
+    mock_input.s3_input.local_path = "/opt/ml/processing/input"
+    mock_input.s3_input.s3_data_type = "S3Prefix"
+    mock_input.s3_input.s3_input_mode = "File"
+    step._baseline_job_inputs = [mock_input]
+
+    # Mock baselining processor
+    step._baselining_processor = Mock()
+    step._baselining_processor._current_job_name = "quality-check-job"
+    step._baselining_processor.instance_count = 1
+    step._baselining_processor.instance_type = "ml.m5.xlarge"
+    step._baselining_processor.volume_size_in_gb = 30
+    step._baselining_processor.image_uri = "123456789012.dkr.ecr.us-east-1.amazonaws.com/monitor:latest"
+    step._baselining_processor.role = "arn:aws:iam::123456789012:role/SageMakerRole"
+    step._baselining_processor.max_runtime_in_seconds = 3600
+    step._baselining_processor.env = None
+    step._baselining_processor.network_config = None
+    step._baselining_processor.entrypoint = None
+    step._baselining_processor.arguments = None
+
+    # Mock step properties needed by arguments
+    step.skip_check = False
+    step.fail_on_violation = True
+    step.register_new_baseline = False
+    step.supplied_baseline_statistics = None
+    step.supplied_baseline_constraints = None
+
+    return step
+
+
+class TestQualityCheckStepKmsKeyPropagation:
+    """Tests for KMS key propagation in QualityCheckStep.arguments; every test patches _TRIM_PATCH with _noop_trim."""
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_output_kms_key_propagated_in_s3_output(self, mock_trim):
+        """Test that output_kms_key from CheckJobConfig is included in S3Output."""
+        step = _create_mock_quality_check_step(output_kms_key=_OUTPUT_KMS_KEY)
+
+        args = step.arguments
+
+        # Verify KmsKeyId is present in S3Output
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" in s3_output
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_volume_kms_key_propagated_in_cluster_config(self, mock_trim):
+        """Test that volume_kms_key from CheckJobConfig is included in ClusterConfig."""
+        step = _create_mock_quality_check_step(volume_kms_key=_VOLUME_KMS_KEY)
+
+        args = step.arguments
+
+        # Verify VolumeKmsKeyId is present in ClusterConfig
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" in cluster_config
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_both_kms_keys_propagated(self, mock_trim):
+        """Test that both output_kms_key and volume_kms_key are propagated together."""
+        step = _create_mock_quality_check_step(
+            output_kms_key=_OUTPUT_KMS_KEY,
+            volume_kms_key=_VOLUME_KMS_KEY,
+        )
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_no_kms_keys_when_not_set(self, mock_trim):
+        """Test that KmsKeyId and VolumeKmsKeyId are absent when not configured."""
+        step = _create_mock_quality_check_step(output_kms_key=None, volume_kms_key=None)
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" not in s3_output
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" not in cluster_config
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_s3_output_retains_other_fields_with_kms(self, mock_trim):
+        """Test that S3Output still contains S3Uri, LocalPath, S3UploadMode alongside KmsKeyId."""
+        step = _create_mock_quality_check_step(output_kms_key=_OUTPUT_KMS_KEY)
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert s3_output["S3Uri"] == "s3://bucket/output"
+        assert s3_output["LocalPath"] == "/opt/ml/processing/output"
+        assert s3_output["S3UploadMode"] == "EndOfJob"
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_cluster_config_retains_other_fields_with_kms(self, mock_trim):
+        """Test that ClusterConfig still contains instance fields alongside VolumeKmsKeyId."""
+        step = _create_mock_quality_check_step(volume_kms_key=_VOLUME_KMS_KEY)
+
+        args = step.arguments
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert cluster_config["InstanceCount"] == 1
+        assert cluster_config["InstanceType"] == "ml.m5.xlarge"
+        assert cluster_config["VolumeSizeInGB"] == 30
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_output_kms_key_only_without_volume_kms(self, mock_trim):
+        """Test output_kms_key set but volume_kms_key not set."""
+        step = _create_mock_quality_check_step(
+            output_kms_key=_OUTPUT_KMS_KEY, volume_kms_key=None
+        )
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert s3_output["KmsKeyId"] == _OUTPUT_KMS_KEY
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert "VolumeKmsKeyId" not in cluster_config
+
+    @patch(_TRIM_PATCH, side_effect=_noop_trim)
+    def test_volume_kms_key_only_without_output_kms(self, mock_trim):
+        """Test volume_kms_key set but output_kms_key not set."""
+        step = _create_mock_quality_check_step(
+            output_kms_key=None, volume_kms_key=_VOLUME_KMS_KEY
+        )
+
+        args = step.arguments
+
+        s3_output = args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]
+        assert "KmsKeyId" not in s3_output
+
+        cluster_config = args["ProcessingResources"]["ClusterConfig"]
+        assert cluster_config["VolumeKmsKeyId"] == _VOLUME_KMS_KEY