diff --git a/README.md b/README.md
index d092cf3..f07c83b 100644
--- a/README.md
+++ b/README.md
@@ -164,6 +164,22 @@ safety-check pickle files:
 fickling --check-safety -p pickled.data
 ```
 
+### Scanning a directory or glob
+
+When `PICKLE_FILE` is a directory or a glob pattern, fickling safety-scans every pickle
+file it finds (`.pkl`, `.pickle`, `.bin`, and the `.pt`/`.pth` zip archives) and aggregates
+the per-file verdicts into a single ClamAV-compatible exit code. Use `-R`/`--recursive`
+to walk a directory tree — useful for triaging a downloaded model cache in one command:
+
+```console
+fickling ./model_cache --recursive
+fickling "./downloads/**/*.pkl" --recursive
+```
+
+The exit code is `1` if any file is unsafe, `2` if a scan errored or the target matched
+nothing (and no file was unsafe), or `0` if everything is clean. Add `-p`/`--print-results`
+for a per-file breakdown.
+
 ## Advanced usage
 
 ### Trace pickle execution
diff --git a/fickling/cli.py b/fickling/cli.py
index b4e50ab..5f0dccf 100644
--- a/fickling/cli.py
+++ b/fickling/cli.py
@@ -3,7 +3,7 @@
 import sys
 from argparse import ArgumentParser
 from ast import unparse
-from pathlib import PurePosixPath
+from pathlib import Path, PurePosixPath
 
 from . import __version__, fickle, tracing
 from .analysis import Severity, check_safety
@@ -149,6 +149,143 @@ def _scan_huggingface(
     return EXIT_CLEAN if overall_safe else EXIT_UNSAFE
 
 
+def _is_glob(target: str) -> bool:
+    """Return True if `target` looks like a shell glob pattern."""
+    return any(ch in target for ch in "*?[")
+
+
+def _collect_scan_targets(target: str, recursive: bool) -> tuple[list[Path], bool]:
+    """Expand a path, directory, or glob into a sorted list of candidate files.
+
+    Args:
+        target: A file path, directory path, or glob pattern. Patterns express
+            recursion explicitly with ``**`` (e.g. ``models/**/*.pkl``); the
+            `recursive` flag governs bare-directory inputs only.
+        recursive: If True, a directory input is walked recursively.
+
+    Returns:
+        A tuple ``(files, resolved)``. ``files`` is the sorted list of existing
+        regular files. ``resolved`` is False only when `target` matched nothing
+        on disk (no such file, directory, or glob match) so the caller can emit
+        EXIT_ERROR rather than a misleading "clean" result.
+    """
+    path = Path(target)
+    if path.is_dir():
+        pattern = "**/*" if recursive else "*"
+        return sorted(p for p in path.glob(pattern) if p.is_file()), True
+    # Check for an existing regular file before treating `target` as a glob, so a
+    # literal filename containing glob metacharacters (e.g. ``data[v1].pkl``) is
+    # scanned as itself rather than mis-parsed as a pattern that matches nothing.
+    if path.is_file():
+        return [path], True
+    if _is_glob(target):
+        # Path.glob requires a relative pattern, so split off any anchor (e.g.
+        # the leading "/") and glob from there. ``**`` recurses on its own.
+        anchor = Path(path.anchor) if path.anchor else Path()
+        pattern = str(path.relative_to(path.anchor)) if path.anchor else target
+        matches = list(anchor.glob(pattern))
+        files = sorted(m for m in matches if m.is_file())
+        return files, bool(matches)
+    return [], False
+
+
+def _scan_paths(
+    target: str,
+    recursive: bool = False,
+    json_output_path: str | None = None,
+    print_results: bool = False,
+) -> int:
+    """Safety-scan every pickle file under a file, directory, or glob target.
+
+    Reuses :func:`scan_file` and :func:`scan_zip_archive` and the
+    ``HF_PICKLE_EXTENSIONS`` classification to triage a tree of files in one
+    invocation. Per-file verdicts are aggregated into a single ClamAV-compatible
+    exit code with the following precedence: any unsafe file -> EXIT_UNSAFE,
+    otherwise any scan error -> EXIT_ERROR, otherwise EXIT_CLEAN.
+
+    Files found by walking a directory or expanding a glob are narrowed to the
+    known pickle extensions. A target that names an existing regular file is
+    always scanned, whatever its extension, so it is never reported clean
+    without having been analyzed.
+
+    Args:
+        target: File path, directory path, or glob pattern to scan.
+        recursive: Walk directories recursively when True.
+        json_output_path: Optional path to write JSON analysis results.
+        print_results: Whether to print results to console.
+
+    Returns:
+        EXIT_CLEAN (0), EXIT_UNSAFE (1), or EXIT_ERROR (2). EXIT_ERROR is also
+        returned when `target` matches nothing on disk.
+    """
+    files, resolved = _collect_scan_targets(target, recursive)
+    if not resolved:
+        sys.stderr.write(f"Error: no such file, directory, or glob match: '{target}'\n")
+        return EXIT_ERROR
+
+    # An explicitly named regular file bypasses the extension filter: dropping it
+    # would return EXIT_CLEAN for a file that was never scanned.
+    if Path(target).is_file():
+        candidates = files
+    else:
+        candidates = [f for f in files if f.suffix.lower() in HF_PICKLE_EXTENSIONS]
+    if not candidates:
+        if print_results:
+            print(f"No scannable pickle files found in '{target}'")
+        return EXIT_CLEAN
+
+    json_output = json_output_path or DEFAULT_JSON_OUTPUT_FILE
+    if print_results:
+        print(f"Scanning {len(candidates)} file(s) under '{target}'...")
+
+    any_unsafe = False
+    any_error = False
+
+    for filepath in candidates:
+        path_str = str(filepath)
+        # Announce the file before scanning it so that anything the scanners
+        # themselves print is grouped under the right header.
+        if print_results:
+            print(f"\n Scanning: {path_str}")
+
+        if filepath.suffix.lower() in HF_ZIP_PICKLE_EXTENSIONS:
+            member_results = scan_zip_archive(path_str, graceful=True, json_output_path=json_output)
+            file_results = list(member_results.values())
+        else:
+            file_results = [scan_file(path_str, graceful=True, json_output_path=json_output)]
+
+        for result in file_results:
+            if print_results:
+                for ar in result.results:
+                    result_str = ar.to_string()
+                    if result_str:
+                        print(f" {result_str}")
+
+            if not result.is_safe:
+                any_unsafe = True
+                if print_results:
+                    sys.stderr.write(f" WARNING: {path_str} may contain unsafe content!\n")
+
+            if result.errors:
+                any_error = True
+                for err in result.errors:
+                    sys.stderr.write(f" {err}\n")
+
+    if print_results:
+        if any_unsafe:
+            print(f"\n{target}: Potentially unsafe content detected!")
+        elif any_error:
+            print(f"\n{target}: Completed with scan errors; review warnings above")
+        else:
+            print(f"\n{target}: No obvious safety issues detected")
+
+    if any_unsafe:
+        return EXIT_UNSAFE
+    if any_error:
+        return EXIT_ERROR
+    return EXIT_CLEAN
+
+
 def main(argv: list[str] | None = None) -> int:
     if argv is None:
         argv = sys.argv
@@ -234,6 +371,15 @@ def main(argv: list[str] | None = None) -> int:
         action="store_true",
         help="print a runtime trace while interpreting the input pickle file",
     )
+    parser.add_argument(
+        "--recursive",
+        "-R",
+        action="store_true",
+        help="when PICKLE_FILE is a directory, scan it recursively. Directory and "
+        "glob inputs are always safety-scanned (reusing the --check-safety analysis), "
+        "aggregating per-file verdicts into a single ClamAV-compatible exit code "
+        "(unsafe takes precedence over scan errors).",
+    )
     parser.add_argument("--version", "-v", action="store_true", help="print the version and exit")
     options.add_argument(
         "--huggingface",
@@ -278,6 +424,27 @@ def main(argv: list[str] | None = None) -> int:
             print_results=args.print_results,
         )
 
+    # Directory / glob scanning mode (bulk triage). Triggered when PICKLE_FILE is
+    # a directory, a glob pattern, or --recursive is set, as long as no single-file
+    # operation (inject/create/trace) was requested.
+    if (
+        args.create is None
+        and args.inject is None
+        and not args.trace
+        and args.PICKLE_FILE != "-"
+        and (
+            args.recursive
+            or (_is_glob(args.PICKLE_FILE) and not Path(args.PICKLE_FILE).is_file())
+            or Path(args.PICKLE_FILE).is_dir()
+        )
+    ):
+        return _scan_paths(
+            args.PICKLE_FILE,
+            recursive=args.recursive,
+            json_output_path=args.json_output,
+            print_results=args.print_results,
+        )
+
     if args.create is None:
         if args.PICKLE_FILE == "-":
             if hasattr(sys.stdin, "buffer") and sys.stdin.buffer is not None:
diff --git a/test/test_cli_recursive.py b/test/test_cli_recursive.py
new file mode 100644
index 0000000..0c9d4d9
--- /dev/null
+++ b/test/test_cli_recursive.py
@@ -0,0 +1,170 @@
+import os
+import pickle
+import tempfile
+import unittest
+import zipfile
+from contextlib import redirect_stdout
+from io import StringIO
+from pathlib import Path
+
+from fickling.cli import main
+from fickling.constants import EXIT_CLEAN, EXIT_ERROR, EXIT_UNSAFE
+
+
+class Payload:
+    """Malicious payload for testing (executes os.system on unpickle)."""
+
+    def __reduce__(self):
+        return (os.system, ("echo pwned",))
+
+
+def _write_pickle(path: Path, obj) -> None:
+    with open(path, "wb") as f:
+        pickle.dump(obj, f)
+
+
+def _run(*argv) -> int:
+    """Invoke the CLI and swallow stdout to keep test output clean."""
+    with redirect_stdout(StringIO()):
+        return main(["fickling", *argv])
+
+
+class TestRecursiveDirectoryScan(unittest.TestCase):
+    def test_directory_with_unsafe_file_returns_exit_unsafe(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            _write_pickle(root / "safe.pkl", [1, 2, 3])
+            _write_pickle(root / "evil.pkl", Payload())
+            self.assertEqual(_run(str(root)), EXIT_UNSAFE)
+
+    def test_all_safe_directory_returns_exit_clean(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            _write_pickle(root / "a.pkl", [1, 2, 3])
+            _write_pickle(root / "b.pickle", {"k": "v"})
+            _write_pickle(root / "c.bin", 42)
+            self.assertEqual(_run(str(root)), EXIT_CLEAN)
+
+    def test_directory_without_pickle_files_returns_exit_clean(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            (root / "README.md").write_text("not a pickle")
+            (root / "config.json").write_text("{}")
+            self.assertEqual(_run(str(root)), EXIT_CLEAN)
+
+    def test_nonrecursive_does_not_descend_into_subdirs(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            _write_pickle(root / "safe.pkl", [1, 2, 3])
+            nested = root / "nested"
+            nested.mkdir()
+            _write_pickle(nested / "evil.pkl", Payload())
+            # Default (non-recursive) must ignore the nested unsafe file.
+            self.assertEqual(_run(str(root)), EXIT_CLEAN)
+
+    def test_recursive_descends_into_subdirs(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            _write_pickle(root / "safe.pkl", [1, 2, 3])
+            nested = root / "nested" / "deeper"
+            nested.mkdir(parents=True)
+            _write_pickle(nested / "evil.pkl", Payload())
+            self.assertEqual(_run(str(root), "--recursive"), EXIT_UNSAFE)
+
+    def test_recursive_short_flag(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            nested = root / "sub"
+            nested.mkdir()
+            _write_pickle(nested / "evil.pkl", Payload())
+            self.assertEqual(_run(str(root), "-R"), EXIT_UNSAFE)
+
+    def test_nonexistent_target_returns_exit_error(self):
+        # A directory/glob scan of a path that resolves to nothing reports
+        # EXIT_ERROR rather than a misleading "clean" result.
+        self.assertEqual(_run("/nonexistent/directory/xyz", "--recursive"), EXIT_ERROR)
+
+
+class TestGlobScan(unittest.TestCase):
+    def test_glob_matches_unsafe_file(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            _write_pickle(root / "safe.pkl", [1, 2, 3])
+            _write_pickle(root / "evil.pkl", Payload())
+            self.assertEqual(_run(str(root / "*.pkl")), EXIT_UNSAFE)
+
+    def test_glob_only_safe_matches(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            _write_pickle(root / "a.pkl", [1, 2, 3])
+            _write_pickle(root / "evil.bin", Payload())
+            # Glob restricted to *.pkl must not pick up the unsafe .bin file.
+            self.assertEqual(_run(str(root / "*.pkl")), EXIT_CLEAN)
+
+    def test_glob_no_matches_returns_exit_error(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            self.assertEqual(_run(str(root / "*.pkl")), EXIT_ERROR)
+
+    def test_recursive_glob(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            nested = root / "a" / "b"
+            nested.mkdir(parents=True)
+            _write_pickle(nested / "evil.pkl", Payload())
+            pattern = str(root / "**" / "*.pkl")
+            self.assertEqual(_run(pattern, "--recursive"), EXIT_UNSAFE)
+
+
+class TestZipMemberScan(unittest.TestCase):
+    def test_pt_archive_with_unsafe_member(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            archive = root / "model.pt"
+            with zipfile.ZipFile(archive, "w") as zf:
+                zf.writestr("data.pkl", pickle.dumps(Payload()))
+            self.assertEqual(_run(str(root)), EXIT_UNSAFE)
+
+    def test_pth_archive_all_safe(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            archive = root / "model.pth"
+            with zipfile.ZipFile(archive, "w") as zf:
+                zf.writestr("weights.pkl", pickle.dumps([1, 2, 3]))
+            self.assertEqual(_run(str(root)), EXIT_CLEAN)
+
+
+class TestSingleFileUnaffected(unittest.TestCase):
+    def test_recursive_flag_on_single_file(self):
+        with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
+            pickle.dump(Payload(), f)
+            path = f.name
+        try:
+            self.assertEqual(_run(path, "--recursive"), EXIT_UNSAFE)
+        finally:
+            Path(path).unlink()
+
+    def test_recursive_flag_on_single_file_with_unknown_extension(self):
+        # An explicitly named file is scanned even when its suffix is not a
+        # known pickle extension; it must never be reported clean unscanned.
+        with tempfile.NamedTemporaryFile(suffix=".data", delete=False) as f:
+            pickle.dump(Payload(), f)
+            path = f.name
+        try:
+            self.assertNotEqual(_run(path, "--recursive"), EXIT_CLEAN)
+        finally:
+            Path(path).unlink()
+
+    def test_print_results_emits_summary(self):
+        with tempfile.TemporaryDirectory() as d:
+            root = Path(d)
+            _write_pickle(root / "evil.pkl", Payload())
+            buf = StringIO()
+            with redirect_stdout(buf):
+                code = main(["fickling", str(root), "--print-results"])
+            self.assertEqual(code, EXIT_UNSAFE)
+            self.assertIn("Potentially unsafe content detected", buf.getvalue())
+
+
+if __name__ == "__main__":
+    unittest.main()