From f7c80d871e6aa94d42cdac9ecc1232771bc0a579 Mon Sep 17 00:00:00 2001 From: LeSingh1 Date: Sun, 31 May 2026 14:43:31 -0700 Subject: [PATCH 1/2] Constrain APEX_ASP_CACHE_DIR to a safe base directory The ASP permutation cache read APEX_ASP_CACHE_DIR directly and used it as the np.save() destination with no validation, so an externally controlled env var could redirect cache writes to an arbitrary writable location (CWE-22/CWE-73), enabling cache poisoning or file overwrite. Resolve the requested cache dir and require it to stay within the default cache base, otherwise warn and fall back to the safe default. Add regression tests covering a rejected traversal attempt and a normal in-base path. Signed-off-by: LeSingh1 --- .../exhaustive_search.py | 24 +++- .../sparsity/test/test_asp_cache_path.py | 108 ++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 apex/contrib/sparsity/test/test_asp_cache_path.py diff --git a/apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py b/apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py index 1fc168650..fc92788c9 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py +++ b/apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py @@ -73,8 +73,30 @@ def generate_unique_combinations( unique_permutation_list = {} +def _resolve_cache_dir(): + # APEX_ASP_CACHE_DIR is externally controlled and is used as the write + # destination for the permutation cache. An attacker who can set the env var + # (e.g. via a CI/job wrapper) could otherwise redirect the np.save() below to + # an arbitrary writable location and poison or overwrite files there + # (CWE-22 / CWE-73). Constrain the cache dir to the default base directory and + # fall back to that safe default when the requested path escapes it. + allowed_base = path.realpath(ASP_CACHE_DIR_DEFAULT) + requested = os.getenv(ASP_CACHE_DIR_ENV_VAR) + if requested is None: + return allowed_base + + resolved = path.realpath(requested) + if resolved != allowed_base and path.commonpath([allowed_base, resolved]) != allowed_base: + print( + f"[ASP][Warning] {ASP_CACHE_DIR_ENV_VAR}={requested!r} resolves outside the " + f"allowed cache base {allowed_base!r}; falling back to the default cache dir." + ) + return allowed_base + return resolved + + def generate_all_unique_combinations(C, M, must_use_all_groups=False): - cache_dir_path = os.getenv(ASP_CACHE_DIR_ENV_VAR, ASP_CACHE_DIR_DEFAULT) + cache_dir_path = _resolve_cache_dir() cache_file_path = path.join(cache_dir_path, f"permutations_{C}_{M}.npy") global unique_permutation_list diff --git a/apex/contrib/sparsity/test/test_asp_cache_path.py b/apex/contrib/sparsity/test/test_asp_cache_path.py new file mode 100644 index 000000000..08474a462 --- /dev/null +++ b/apex/contrib/sparsity/test/test_asp_cache_path.py @@ -0,0 +1,108 @@ +"""Regression tests for the APEX_ASP_CACHE_DIR path-traversal fix (CWE-22/CWE-73). + +These exercise apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py +directly. They do not require CUDA: the permutation generation falls back to a pure +CPU/numpy path when the search kernels are not built. +""" + +import importlib +import os +import sys +import types +import unittest + +import numpy as np + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + + +def _load_exhaustive_search(): + # Importing apex.contrib.sparsity normally pulls in torchvision via asp.py, which + # is unrelated to the cache-path code under test. Register lightweight namespace + # packages so the kernels subpackage's relative imports resolve without it. + for name, rel in [ + ("apex", "apex"), + ("apex.contrib", "apex/contrib"), + ("apex.contrib.sparsity", "apex/contrib/sparsity"), + ]: + if name not in sys.modules: + mod = types.ModuleType(name) + mod.__path__ = [os.path.join(REPO_ROOT, rel)] + sys.modules[name] = mod + return importlib.import_module( + "apex.contrib.sparsity.permutation_search_kernels.exhaustive_search" + ) + + +class TestAspCachePath(unittest.TestCase): + def setUp(self): + self.mod = _load_exhaustive_search() + self._prev_env = os.environ.get(self.mod.ASP_CACHE_DIR_ENV_VAR) + self._prev_cwd = os.getcwd() + self._tmp = self._mkdtemp() + # Run from a temp dir so the default ".cache" base lives there. + os.chdir(self._tmp) + # Clear the module-level memoization so each test actually writes. + self.mod.unique_permutation_list = {} + + def tearDown(self): + os.chdir(self._prev_cwd) + if self._prev_env is None: + os.environ.pop(self.mod.ASP_CACHE_DIR_ENV_VAR, None) + else: + os.environ[self.mod.ASP_CACHE_DIR_ENV_VAR] = self._prev_env + + def _mkdtemp(self): + import tempfile + + return os.path.realpath(tempfile.mkdtemp()) + + def test_normal_path_within_base_is_used(self): + # A cache dir nested under the allowed default base is honored. + allowed_base = os.path.realpath(self.mod.ASP_CACHE_DIR_DEFAULT) + nested = os.path.join(allowed_base, "sub") + os.environ[self.mod.ASP_CACHE_DIR_ENV_VAR] = nested + + result = self.mod.generate_all_unique_combinations(4, 4) + + self.assertTrue(len(result) >= 1) + expected_file = os.path.join(nested, "permutations_4_4.npy") + self.assertTrue(os.path.exists(expected_file), f"missing {expected_file}") + + def test_traversal_attempt_is_rejected(self): + # An attacker-controlled value pointing outside the allowed base must not be + # used as the write destination; the code falls back to the safe default. + escape_dir = self._mkdtemp() # an arbitrary writable dir outside ".cache" + # ".." traversal that resolves to a fresh, unique location outside the base. + evil_target = os.path.realpath(os.path.join(escape_dir, "..", os.path.basename(escape_dir) + "_evil")) + os.environ[self.mod.ASP_CACHE_DIR_ENV_VAR] = os.path.join(escape_dir, "..", os.path.basename(escape_dir) + "_evil") + + result = self.mod.generate_all_unique_combinations(4, 4) + + self.assertTrue(len(result) >= 1) + # Nothing was written to the attacker-controlled location. + leaked = os.path.join(evil_target, "permutations_4_4.npy") + self.assertFalse(os.path.exists(leaked), f"write escaped to {leaked}") + # The file landed in the safe default base instead. + safe_file = os.path.join( + os.path.realpath(self.mod.ASP_CACHE_DIR_DEFAULT), "permutations_4_4.npy" + ) + self.assertTrue(os.path.exists(safe_file), f"missing safe write {safe_file}") + + def test_resolve_cache_dir_helper(self): + # Direct unit check on the resolver. + allowed_base = os.path.realpath(self.mod.ASP_CACHE_DIR_DEFAULT) + + os.environ.pop(self.mod.ASP_CACHE_DIR_ENV_VAR, None) + self.assertEqual(self.mod._resolve_cache_dir(), allowed_base) + + os.environ[self.mod.ASP_CACHE_DIR_ENV_VAR] = "/etc" + self.assertEqual(self.mod._resolve_cache_dir(), allowed_base) + + within = os.path.join(allowed_base, "ok") + os.environ[self.mod.ASP_CACHE_DIR_ENV_VAR] = within + self.assertEqual(self.mod._resolve_cache_dir(), within) + + +if __name__ == "__main__": + unittest.main() From 29a4faea5649b1c0b3e64a77668d460e009e4772 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 31 May 2026 23:18:29 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apex/contrib/sparsity/test/test_asp_cache_path.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/apex/contrib/sparsity/test/test_asp_cache_path.py b/apex/contrib/sparsity/test/test_asp_cache_path.py index 08474a462..5bd2d6699 100644 --- a/apex/contrib/sparsity/test/test_asp_cache_path.py +++ b/apex/contrib/sparsity/test/test_asp_cache_path.py @@ -74,8 +74,12 @@ def test_traversal_attempt_is_rejected(self): # used as the write destination; the code falls back to the safe default. escape_dir = self._mkdtemp() # an arbitrary writable dir outside ".cache" # ".." traversal that resolves to a fresh, unique location outside the base. - evil_target = os.path.realpath(os.path.join(escape_dir, "..", os.path.basename(escape_dir) + "_evil")) - os.environ[self.mod.ASP_CACHE_DIR_ENV_VAR] = os.path.join(escape_dir, "..", os.path.basename(escape_dir) + "_evil") + evil_target = os.path.realpath( + os.path.join(escape_dir, "..", os.path.basename(escape_dir) + "_evil") + ) + os.environ[self.mod.ASP_CACHE_DIR_ENV_VAR] = os.path.join( + escape_dir, "..", os.path.basename(escape_dir) + "_evil" + ) result = self.mod.generate_all_unique_combinations(4, 4)