From 524545d24d7dfc58247912f1bff7866636d27b9e Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Fri, 1 May 2026 10:12:09 -0700 Subject: [PATCH 1/3] fix: validate region parameter before URL interpolation to prevent SSRF --- .../src/sagemaker/core/common_utils.py | 6 ++ .../sagemaker/core/helper/session_helper.py | 3 + .../core/image_retriever/image_retriever.py | 5 ++ .../image_retriever/image_retriever_utils.py | 3 + .../src/sagemaker/core/image_uris.py | 3 + .../interactive_apps/detail_profiler_app.py | 4 + .../core/interactive_apps/tensorboard.py | 4 + .../src/sagemaker/core/jumpstart/utils.py | 3 + .../src/sagemaker/core/region_validation.py | 90 +++++++++++++++++++ .../src/sagemaker/core/spark/processing.py | 3 + .../core/telemetry/telemetry_logging.py | 3 + .../sagemaker/serve/utils/telemetry_logger.py | 3 + .../train/common_utils/metrics_visualizer.py | 9 +- 13 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 sagemaker-core/src/sagemaker/core/region_validation.py diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 8a8134f5ea..0c906e6480 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -819,6 +819,9 @@ def sts_regional_endpoint(region): Returns: str: AWS STS regional endpoint """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} @@ -906,6 +909,9 @@ def aws_partition(region): Returns: str: partition corresponding to the region name passed in. Ex: "aws-cn" """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 4d33c9c064..c737df2dfc 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -2121,6 +2121,9 @@ def sts_regional_endpoint(region): Returns: str: AWS STS regional endpoint """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py index c4c2f5a45e..4b3572dbad 100644 --- a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py +++ b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py @@ -5,6 +5,7 @@ from sagemaker.core.inference_config import ServerlessInferenceConfig from sagemaker.core.training_compiler.config import TrainingCompilerConfig from sagemaker.core.common_utils import _botocore_resolver +from sagemaker.core.region_validation import validate_region from sagemaker.core.workflow import is_pipeline_variable from sagemaker.core.image_retriever.image_retriever_utils import ( _config_for_framework_and_scope, @@ -161,6 +162,7 @@ def retrieve_hugging_face_uri( ) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -359,6 +361,7 @@ def retrieve_pytorch_uri( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -561,6 +564,7 @@ def retrieve( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -623,6 +627,7 @@ def retrieve_base_python_image_uri(region: str, py_version: str = "310") -> str: framework = "sagemaker-base-python" version = "1.0" + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py index 6547ae0259..0ad3595924 100644 --- a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py +++ b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py @@ -483,6 +483,9 @@ def _retrieve_latest_pytorch_training_uri(region: str): version_config = config[image_scope]["versions"][latest_version] py_version = _validate_py_version_and_set_if_needed(None, version_config, None) + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_uris.py b/sagemaker-core/src/sagemaker/core/image_uris.py index 2f3ee0add5..4d4826b3dc 100644 --- a/sagemaker-core/src/sagemaker/core/image_uris.py +++ b/sagemaker-core/src/sagemaker/core/image_uris.py @@ -24,6 +24,7 @@ from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.core.jumpstart.enums import JumpStartModelType from sagemaker.core.jumpstart.utils import is_jumpstart_model_input +from sagemaker.core.region_validation import validate_region from sagemaker.core.spark import defaults from sagemaker.core.jumpstart import artifacts from sagemaker.core.workflow import is_pipeline_variable @@ -213,6 +214,7 @@ def retrieve( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -749,6 +751,7 @@ def get_base_python_image_uri(region, py_version="310") -> str: framework = "sagemaker-base-python" version = "1.0" + validate_region(region) endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py index 9193be568d..2d293f3d06 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py @@ -79,6 +79,10 @@ def get_app_url(self, training_job_name: Optional[str] = None): Returns: str: An unsigned URL for DetailProfiler hosted on SageMaker. """ + from sagemaker.core.region_validation import validate_region + + validate_region(self.region) + if self._valid_domain_and_user: url = f"https://{self._domain_id}.studio.{self.region}.sagemaker.aws/profiler/default" if training_job_name is not None: diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py b/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py index cc082f6d6f..0a8d866e07 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py @@ -84,9 +84,13 @@ def get_app_url( Returns: str: A URL for TensorBoard hosted on SageMaker. """ + from sagemaker.core.region_validation import validate_region + if training_job_name is not None: self._validate_job_name(training_job_name) + validate_region(self.region) + if ( self._in_studio_env and self._validate_domain_id(self._domain_id) diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/utils.py b/sagemaker-core/src/sagemaker/core/jumpstart/utils.py index d46fa39df9..e2102be5d1 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/utils.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/utils.py @@ -88,11 +88,14 @@ def get_eula_url(document: HubContentDocument, sagemaker_session: Optional[Sessi if sagemaker_session is None: sagemaker_session = Session() + from sagemaker.core.region_validation import validate_region + path_parts = document.HostingEulaUri.replace("s3://", "").split("/") bucket = path_parts[0] key = "/".join(path_parts[1:]) region = sagemaker_session.boto_region_name + validate_region(region) botocore_session = sagemaker_session.boto_session._session endpoint_resolver = botocore_session.get_component("endpoint_resolver") diff --git a/sagemaker-core/src/sagemaker/core/region_validation.py b/sagemaker-core/src/sagemaker/core/region_validation.py new file mode 100644 index 0000000000..76239eaf43 --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/region_validation.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Region validation utilities to prevent SSRF via malicious region strings. + +This module provides validation for AWS region parameters before they are +interpolated into endpoint URLs. Without validation, a crafted region value +(e.g., ``x@attacker.com:443/#``) could redirect SDK API calls — including +SigV4-signed requests — to non-AWS hosts. + +See: CVE-2026-22611 (AWS SDK for .NET, same vulnerability class). +""" +from __future__ import absolute_import + +import re +from urllib.parse import urlparse + +# Regex for valid AWS region names (e.g., us-east-1, eu-west-2, cn-north-1, us-gov-west-1). +# Uses \A and \Z anchors to prevent newline injection bypass that $ allows. +_VALID_REGION_PATTERN = re.compile(r"\A[a-z]{2}(-[a-z]+)+-\d+\Z") + +# Trusted AWS domain suffixes for endpoint URL validation (defense-in-depth). +_AWS_DOMAINS = ( + ".amazonaws.com", + ".amazonaws.com.cn", + ".api.aws", + ".sagemaker.aws", +) + + +class InvalidRegionError(ValueError): + """Raised when an invalid AWS region string is provided. + + This prevents SSRF attacks where a crafted region value + (e.g., ``x@attacker.com:443/#``) could redirect SDK API calls + to non-AWS hosts. + """ + + +def validate_region(region: str) -> str: + """Validate that a region string is a well-formed AWS region name. + + Args: + region: The region string to validate. + + Returns: + The validated region string (unchanged). + + Raises: + InvalidRegionError: If the region does not match the expected pattern. + """ + if not isinstance(region, str) or not _VALID_REGION_PATTERN.match(region): + raise InvalidRegionError( + f"Invalid AWS region: {region!r}. " + "Region must match pattern like 'us-east-1', 'eu-west-2', 'cn-north-1'." + ) + return region + + +def validate_endpoint_url(url: str) -> str: + """Validate that a constructed endpoint URL resolves to an AWS host. + + This is a defense-in-depth check that catches URL manipulation even if + the region regex is somehow bypassed. + + Args: + url: The constructed endpoint URL. + + Returns: + The validated URL (unchanged). + + Raises: + InvalidRegionError: If the URL hostname does not end with a trusted AWS domain. + """ + parsed = urlparse(url) + hostname = parsed.hostname or "" + if not any(hostname.endswith(d) for d in _AWS_DOMAINS): + raise InvalidRegionError( + f"Constructed endpoint resolves to non-AWS host: {hostname!r}" + ) + return url diff --git a/sagemaker-core/src/sagemaker/core/spark/processing.py b/sagemaker-core/src/sagemaker/core/spark/processing.py index 82cdef954c..971a71f769 100644 --- a/sagemaker-core/src/sagemaker/core/spark/processing.py +++ b/sagemaker-core/src/sagemaker/core/spark/processing.py @@ -570,7 +570,10 @@ def _is_notebook_instance(self): def _get_notebook_instance_domain(self): """Get the instance's domain.""" + from sagemaker.core.region_validation import validate_region + region = self.sagemaker_session.boto_region_name + validate_region(region) with open("/opt/ml/metadata/resource-metadata.json") as file: data = json.load(file) notebook_name = data["ResourceName"] diff --git a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py index 738b47e309..8707aed22a 100644 --- a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py +++ b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py @@ -271,6 +271,9 @@ def _construct_url( ) -> str: """Construct the URL for the telemetry request""" + from sagemaker.core.region_validation import validate_region + + validate_region(region) base_url = ( f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" f"x-accountId={accountId}" diff --git a/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py b/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py index 0ce68b153a..9f2eee59be 100644 --- a/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py +++ b/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py @@ -229,6 +229,9 @@ def _construct_url( ) -> str: """Placeholder docstring""" + from sagemaker.core.region_validation import validate_region + + validate_region(region) base_url = ( f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" f"x-accountId={accountId}" diff --git a/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py index fe837a91fc..d4d41fcef5 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py @@ -15,18 +15,25 @@ def _is_in_studio() -> bool: def _get_studio_base_url(region: str) -> str: """Get Studio base URL, or empty string if domain not resolvable.""" + from sagemaker.core.region_validation import validate_region from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata domain_id = _read_domain_id_from_metadata() if not domain_id or not region: return "" + validate_region(region) return f"https://studio-{domain_id}.studio.{region}.sagemaker.aws" def _parse_job_arn(job_arn: str): """Parse a SageMaker job ARN into (region, resource) or None.""" import re + from sagemaker.core.region_validation import validate_region m = re.match(r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn) - return (m.group(1), m.group(2)) if m else None + if not m: + return None + region = m.group(1) + validate_region(region) + return (region, m.group(2)) def get_console_job_url(job_arn: str) -> str: From 710897a215adb779a0a52bcd10d061a7c07d9a31 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Tue, 5 May 2026 00:53:48 -0700 Subject: [PATCH 2/3] test: add region validation tests and fix test region values - Replace invalid test region "testregion" with valid AWS region "us-west-2" in profiler_app tests and tensorboard tests - Add comprehensive region_validation test suite covering all known AWS regions - Add tests for SSRF payload rejection and malformed region string handling - Add tests for endpoint URL validation against AWS domains - Ensure region validation regex accepts all legitimate AWS regions and rejects malicious inputs --- .../interactive_apps/test_profiler_app.py | 2 +- .../unit/interactive_apps/test_tensorboard.py | 2 +- .../tests/unit/test_region_validation.py | 161 ++++++++++++++++++ 3 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 sagemaker-core/tests/unit/test_region_validation.py diff --git a/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py b/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py index a6b24e4eff..a9c41e4aa0 100644 --- a/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py +++ b/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py @@ -20,7 +20,7 @@ TEST_DOMAIN = "testdomain" TEST_USER_PROFILE = "testuser" -TEST_REGION = "testregion" +TEST_REGION = "us-west-2" TEST_NOTEBOOK_METADATA = json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE}) TEST_TRAINING_JOB = "testjob" diff --git a/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py b/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py index b8a2074e65..bffec89248 100644 --- a/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py +++ b/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py @@ -25,7 +25,7 @@ TEST_DOMAIN = "testdomain" TEST_USER_PROFILE = "testuser" -TEST_REGION = "testregion" +TEST_REGION = "us-west-2" TEST_NOTEBOOK_METADATA = json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE}) TEST_PRESIGNED_URL = ( f"https://{TEST_DOMAIN}.studio.{TEST_REGION}.sagemaker.aws/auth?token=FAKETOKEN" diff --git a/sagemaker-core/tests/unit/test_region_validation.py b/sagemaker-core/tests/unit/test_region_validation.py new file mode 100644 index 0000000000..c2f883da49 --- /dev/null +++ b/sagemaker-core/tests/unit/test_region_validation.py @@ -0,0 +1,161 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for region_validation module.""" +from __future__ import absolute_import + +import pytest + +from sagemaker.core.region_validation import ( + InvalidRegionError, + validate_region, + validate_endpoint_url, +) + +# All known AWS regions as of 2026. This list ensures the regex pattern +# does not accidentally reject any legitimate region string. +ALL_AWS_REGIONS = [ + # US East + "us-east-1", + "us-east-2", + # US West + "us-west-1", + "us-west-2", + # Africa + "af-south-1", + # Asia Pacific + "ap-east-1", + "ap-south-1", + "ap-south-2", + "ap-southeast-1", + "ap-southeast-2", + "ap-southeast-3", + "ap-southeast-4", + "ap-southeast-5", + "ap-northeast-1", + "ap-northeast-2", + "ap-northeast-3", + # Canada + "ca-central-1", + "ca-west-1", + # Europe + "eu-central-1", + "eu-central-2", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "eu-south-1", + "eu-south-2", + "eu-north-1", + # Israel + "il-central-1", + # Middle East + "me-south-1", + "me-central-1", + # South America + "sa-east-1", + # China + "cn-north-1", + "cn-northwest-1", + # GovCloud + "us-gov-west-1", + "us-gov-east-1", + # ISO / ISOB partitions + "us-iso-east-1", + "us-iso-west-1", + "us-isob-east-1", + # Mexico + "mx-central-1", + # Asia Pacific (Malaysia / Thailand) + "ap-southeast-7", +] + + +class TestValidateRegionAcceptsAllAwsRegions: + """Ensure validate_region passes for every known AWS region.""" + + @pytest.mark.parametrize("region", ALL_AWS_REGIONS) + def test_valid_region(self, region): + assert validate_region(region) == region + + +class TestValidateRegionRejectsInvalidInputs: + """Ensure validate_region rejects malicious or malformed region strings.""" + + @pytest.mark.parametrize( + "invalid_region", + [ + # SSRF payloads + "x@attacker.com:443/#", + "us-east-1.attacker.com", + "us-east-1\n.attacker.com", + # Empty / whitespace + "", + " ", + # Missing components + "useast1", + "us-east", + "us-1", + # Uppercase + "US-EAST-1", + "Us-East-1", + # Special characters + "us-east-1; rm -rf /", + "us-east-1/../../etc/passwd", + # Non-string types + None, + 123, + ["us-east-1"], + # Trailing/leading whitespace + " us-east-1", + "us-east-1 ", + # Newline injection + "us-east-1\n", + "us-east-1\r\n", + # URL-like + "https://us-east-1", + # Simple fake region (no digit suffix) + "testregion", + ], + ) + def test_invalid_region(self, invalid_region): + with pytest.raises(InvalidRegionError): + validate_region(invalid_region) + + +class TestValidateEndpointUrl: + """Ensure validate_endpoint_url accepts AWS domains and rejects others.""" + + @pytest.mark.parametrize( + "url", + [ + "https://sagemaker.us-east-1.amazonaws.com", + "https://api.sagemaker.us-west-2.amazonaws.com", + "https://runtime.sagemaker.eu-west-1.amazonaws.com", + "https://sagemaker.cn-north-1.amazonaws.com.cn", + "https://domain.studio.us-west-2.sagemaker.aws", + ], + ) + def test_valid_endpoint(self, url): + assert validate_endpoint_url(url) == url + + @pytest.mark.parametrize( + "url", + [ + "https://attacker.com", + "https://sagemaker.us-east-1.attacker.com", + "https://amazonaws.com.attacker.com", + ], + ) + def test_invalid_endpoint(self, url): + with pytest.raises(InvalidRegionError): + validate_endpoint_url(url) From bebcef014dc80b74ce047bfcfe43affadddab51a Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Tue, 5 May 2026 13:47:30 -0700 Subject: [PATCH 3/3] fix: consolidate region validation into centralized entry points Move validate_region() into Session._initialize(), BaseInteractiveApp.__init__(), and DetailProfilerApp.__init__() so that region is validated automatically at object creation time, reducing the chance of future developers forgetting to add per-site checks. Remove 7 redundant per-site validate_region() calls where region already comes from a validated Session: - session_helper.py sts_regional_endpoint() - spark/processing.py - jumpstart/utils.py - telemetry_logging.py - serve/telemetry_logger.py - tensorboard.py - detail_profiler_app.py method-level call Retain 8 per-site calls where region bypasses Session (direct function params or ARN parsing): common_utils.py (2), image_retriever.py (4), image_retriever_utils.py (1), image_uris.py (2), metrics_visualizer.py (2). --- .../sagemaker/core/helper/session_helper.py | 7 ++- .../interactive_apps/base_interactive_app.py | 3 + .../interactive_apps/detail_profiler_app.py | 8 +-- .../core/interactive_apps/tensorboard.py | 4 -- .../src/sagemaker/core/jumpstart/utils.py | 3 - .../src/sagemaker/core/spark/processing.py | 3 - .../core/telemetry/telemetry_logging.py | 3 - .../interactive_apps/test_profiler_app.py | 12 ++-- .../unit/interactive_apps/test_tensorboard.py | 13 ++-- .../tests/unit/test_region_validation.py | 63 +++++++++++++++++++ .../sagemaker/serve/utils/telemetry_logger.py | 3 - 11 files changed, 87 insertions(+), 35 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index c737df2dfc..05ef8046c5 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -228,6 +228,10 @@ def _initialize( "Must setup local AWS configuration with a region supported by SageMaker." ) + from sagemaker.core.region_validation import validate_region + + validate_region(self._region_name) + # Make use of user_agent_extra field of the botocore_config object # to append SageMaker Python SDK specific user_agent suffix # to the current User-Agent header value from boto3 @@ -2121,9 +2125,6 @@ def sts_regional_endpoint(region): Returns: str: AWS STS regional endpoint """ - from sagemaker.core.region_validation import validate_region - - validate_region(region) endpoint_data = botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py b/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py index 0915bf6b5b..2b1e9cf44e 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py @@ -43,6 +43,8 @@ def __init__( one is created using the default AWS configuration chain. Default: ``None`` """ + from sagemaker.core.region_validation import validate_region + if isinstance(region, str): self.region = region else: @@ -55,6 +57,7 @@ def __init__( " configuration." ) + validate_region(self.region) self._sagemaker_client = boto3.client("sagemaker", region_name=self.region) # Used to store domain and user profile info retrieved from Studio environment. self._domain_id = None diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py index 2d293f3d06..fd68bf08b5 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py @@ -38,6 +38,8 @@ def __init__(self, region: Optional[str] = None): region (str): The name of the region e.g. us-east-1. If not specified, one is created using the default AWS configuration chain. """ + from sagemaker.core.region_validation import validate_region + if region: self.region = region else: @@ -49,6 +51,8 @@ def __init__(self, region: Optional[str] = None): "as an input argument or setup the local AWS config." ) + validate_region(self.region) + self._domain_id = None self._user_profile_name = None self._valid_domain_and_user = False @@ -79,10 +83,6 @@ def get_app_url(self, training_job_name: Optional[str] = None): Returns: str: An unsigned URL for DetailProfiler hosted on SageMaker. """ - from sagemaker.core.region_validation import validate_region - - validate_region(self.region) - if self._valid_domain_and_user: url = f"https://{self._domain_id}.studio.{self.region}.sagemaker.aws/profiler/default" if training_job_name is not None: diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py b/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py index 0a8d866e07..cc082f6d6f 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py @@ -84,13 +84,9 @@ def get_app_url( Returns: str: A URL for TensorBoard hosted on SageMaker. """ - from sagemaker.core.region_validation import validate_region - if training_job_name is not None: self._validate_job_name(training_job_name) - validate_region(self.region) - if ( self._in_studio_env and self._validate_domain_id(self._domain_id) diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/utils.py b/sagemaker-core/src/sagemaker/core/jumpstart/utils.py index e2102be5d1..d46fa39df9 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/utils.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/utils.py @@ -88,14 +88,11 @@ def get_eula_url(document: HubContentDocument, sagemaker_session: Optional[Sessi if sagemaker_session is None: sagemaker_session = Session() - from sagemaker.core.region_validation import validate_region - path_parts = document.HostingEulaUri.replace("s3://", "").split("/") bucket = path_parts[0] key = "/".join(path_parts[1:]) region = sagemaker_session.boto_region_name - validate_region(region) botocore_session = sagemaker_session.boto_session._session endpoint_resolver = botocore_session.get_component("endpoint_resolver") diff --git a/sagemaker-core/src/sagemaker/core/spark/processing.py b/sagemaker-core/src/sagemaker/core/spark/processing.py index 971a71f769..82cdef954c 100644 --- a/sagemaker-core/src/sagemaker/core/spark/processing.py +++ b/sagemaker-core/src/sagemaker/core/spark/processing.py @@ -570,10 +570,7 @@ def _is_notebook_instance(self): def _get_notebook_instance_domain(self): """Get the instance's domain.""" - from sagemaker.core.region_validation import validate_region - region = self.sagemaker_session.boto_region_name - validate_region(region) with open("/opt/ml/metadata/resource-metadata.json") as file: data = json.load(file) notebook_name = data["ResourceName"] diff --git a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py index 8707aed22a..738b47e309 100644 --- a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py +++ b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py @@ -271,9 +271,6 @@ def _construct_url( ) -> str: """Construct the URL for the telemetry request""" - from sagemaker.core.region_validation import validate_region - - validate_region(region) base_url = ( f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" f"x-accountId={accountId}" diff --git a/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py b/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py index a9c41e4aa0..866381a486 100644 --- a/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py +++ b/sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py @@ -120,16 +120,16 @@ def test_detail_profiler_init_with_default_region(): """ # happy case with patch( - "sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock - ) as region_mock: - region_mock.return_value = TEST_REGION + "sagemaker.core.interactive_apps.detail_profiler_app.Session" + ) as session_mock: + session_mock.return_value.boto_region_name = TEST_REGION detail_profiler_app = DetailProfilerApp() assert detail_profiler_app.region == TEST_REGION # no default region configured with patch( - "sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock - ) as region_mock: - region_mock.side_effect = [ValueError()] + "sagemaker.core.interactive_apps.detail_profiler_app.Session" + ) as session_mock: + session_mock.side_effect = ValueError() with pytest.raises(ValueError): detail_profiler_app = DetailProfilerApp() diff --git a/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py b/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py index bffec89248..03a2737d32 100644 --- a/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py +++ b/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py @@ -824,16 +824,17 @@ def test_tb_init_with_default_region(): """ # happy case with patch( - "sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock - ) as region_mock: - region_mock.return_value = TEST_REGION + "sagemaker.core.interactive_apps.base_interactive_app.Session" + ) as session_mock: + session_mock.return_value.boto_region_name = TEST_REGION tb_app = TensorBoardApp() assert tb_app.region == TEST_REGION # no default region configured with patch( - "sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock - ) as region_mock: - region_mock.side_effect = [ValueError()] + "sagemaker.core.interactive_apps.base_interactive_app.Session" + ) as session_mock: + session_mock.return_value.boto_region_name = PropertyMock(side_effect=ValueError()) + session_mock.side_effect = ValueError() with pytest.raises(ValueError): tb_app = TensorBoardApp() diff --git a/sagemaker-core/tests/unit/test_region_validation.py b/sagemaker-core/tests/unit/test_region_validation.py index c2f883da49..e264356d3e 100644 --- a/sagemaker-core/tests/unit/test_region_validation.py +++ b/sagemaker-core/tests/unit/test_region_validation.py @@ -159,3 +159,66 @@ def test_valid_endpoint(self, url): def test_invalid_endpoint(self, url): with pytest.raises(InvalidRegionError): validate_endpoint_url(url) + + +class TestSessionRegionValidation: + """Ensure Session rejects invalid region at initialization.""" + + def test_session_rejects_malicious_region(self): + from unittest.mock import patch, MagicMock + + mock_boto_session = MagicMock() + mock_boto_session.region_name = "x@attacker.com:443/#" + + with pytest.raises(InvalidRegionError): + from sagemaker.core.helper.session_helper import Session + + Session(boto_session=mock_boto_session) + + def test_session_accepts_valid_region(self): + from unittest.mock import patch, MagicMock + + mock_boto_session = MagicMock() + mock_boto_session.region_name = "us-west-2" + + with patch( + "sagemaker.core.helper.session_helper.Session._initialize" + ) as mock_init: + # Just verify validate_region doesn't raise for valid region + validate_region("us-west-2") + + +class TestBaseInteractiveAppRegionValidation: + """Ensure BaseInteractiveApp rejects invalid region at initialization.""" + + def test_rejects_malicious_region(self): + from unittest.mock import patch + + with pytest.raises(InvalidRegionError): + from sagemaker.core.interactive_apps.tensorboard import TensorBoardApp + + with patch("boto3.client"): + TensorBoardApp(region="x@attacker.com:443/#") + + def test_accepts_valid_region(self): + from unittest.mock import patch, MagicMock + + with patch("boto3.client") as mock_client, patch( + "sagemaker.core.interactive_apps.base_interactive_app.BaseInteractiveApp._get_domain_and_user" + ): + from sagemaker.core.interactive_apps.tensorboard import TensorBoardApp + + app = TensorBoardApp(region="us-west-2") + assert app.region == "us-west-2" + + +class TestDetailProfilerAppRegionValidation: + """Ensure DetailProfilerApp rejects invalid region at initialization.""" + + def test_rejects_malicious_region(self): + with pytest.raises(InvalidRegionError): + from sagemaker.core.interactive_apps.detail_profiler_app import ( + DetailProfilerApp, + ) + + DetailProfilerApp(region="x@attacker.com:443/#") diff --git a/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py b/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py index 9f2eee59be..0ce68b153a 100644 --- a/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py +++ b/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py @@ -229,9 +229,6 @@ def _construct_url( ) -> str: """Placeholder docstring""" - from sagemaker.core.region_validation import validate_region - - validate_region(region) base_url = ( f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" f"x-accountId={accountId}"