Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,22 @@ safety-check pickle files:
fickling --check-safety -p pickled.data
```

### Scanning a directory or glob

When `PICKLE_FILE` is a directory or a glob pattern, fickling safety-scans every pickle
file it finds (`.pkl`, `.pickle`, `.bin`, and the `.pt`/`.pth` zip archives) and aggregates
the per-file verdicts into a single ClamAV-compatible exit code. Use `-R`/`--recursive`
to walk a directory tree — useful for triaging a downloaded model cache in one command:

```console
fickling ./model_cache --recursive
fickling "./downloads/**/*.pkl" --recursive
```

The exit code is `1` if any file is unsafe, `2` if scans completed but errored on some file
(and none were unsafe), or `0` if everything is clean. Add `-p`/`--print-results` for a
per-file breakdown.

## Advanced usage

### Trace pickle execution
Expand Down
156 changes: 155 additions & 1 deletion fickling/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from argparse import ArgumentParser
from ast import unparse
from pathlib import PurePosixPath
from pathlib import Path, PurePosixPath

from . import __version__, fickle, tracing
from .analysis import Severity, check_safety
Expand Down Expand Up @@ -149,6 +149,130 @@ def _scan_huggingface(
return EXIT_CLEAN if overall_safe else EXIT_UNSAFE


def _is_glob(target: str) -> bool:
"""Return True if `target` looks like a shell glob pattern."""
return any(ch in target for ch in "*?[")


def _collect_scan_targets(target: str, recursive: bool) -> tuple[list[Path], bool]:
"""Expand a path, directory, or glob into a sorted list of candidate files.

Args:
target: A file path, directory path, or glob pattern. Patterns express
recursion explicitly with ``**`` (e.g. ``models/**/*.pkl``); the
`recursive` flag governs bare-directory inputs only.
recursive: If True, a directory input is walked depth-first.

Returns:
A tuple ``(files, resolved)``. ``files`` is the sorted list of existing
regular files. ``resolved`` is False only when `target` matched nothing
on disk (no such file, directory, or glob match) so the caller can emit
EXIT_ERROR rather than a misleading "clean" result.
"""
path = Path(target)
if path.is_dir():
pattern = "**/*" if recursive else "*"
return sorted(p for p in path.glob(pattern) if p.is_file()), True
# Check for an existing regular file before treating `target` as a glob, so a
# literal filename containing glob metacharacters (e.g. ``data[v1].pkl``) is
# scanned as itself rather than mis-parsed as a pattern that matches nothing.
if path.is_file():
return [path], True
if _is_glob(target):
# Path.glob requires a relative pattern, so split off any anchor (e.g.
# the leading "/") and glob from there. ``**`` recurses on its own.
anchor = Path(path.anchor) if path.anchor else Path()
pattern = str(path.relative_to(path.anchor)) if path.anchor else target
matches = list(anchor.glob(pattern))
files = sorted(m for m in matches if m.is_file())
return files, bool(matches)
return [], False


def _scan_paths(
target: str,
recursive: bool = False,
json_output_path: str | None = None,
print_results: bool = False,
) -> int:
"""Safety-scan every pickle file under a directory or glob pattern.

Reuses :func:`scan_file` and :func:`scan_zip_archive` and the
``HF_PICKLE_EXTENSIONS`` classification to triage a tree of files in one
invocation. Per-file verdicts are aggregated into a single ClamAV-compatible
exit code with the following precedence: any unsafe file -> EXIT_UNSAFE,
otherwise any scan error -> EXIT_ERROR, otherwise EXIT_CLEAN.

Args:
target: Directory path or glob pattern to scan.
recursive: Walk directories recursively when True.
json_output_path: Optional path to write JSON analysis results.
print_results: Whether to print results to console.

Returns:
EXIT_CLEAN (0), EXIT_UNSAFE (1), or EXIT_ERROR (2)
"""
files, resolved = _collect_scan_targets(target, recursive)
if not resolved:
sys.stderr.write(f"Error: no such file, directory, or glob match: '{target}'\n")
return EXIT_ERROR

candidates = [f for f in files if f.suffix.lower() in HF_PICKLE_EXTENSIONS]
if not candidates:
if print_results:
print(f"No scannable pickle files found in '{target}'")
return EXIT_CLEAN

json_output = json_output_path or DEFAULT_JSON_OUTPUT_FILE
if print_results:
print(f"Scanning {len(candidates)} file(s) under '{target}'...")

any_unsafe = False
any_error = False

for filepath in candidates:
path_str = str(filepath)
if filepath.suffix.lower() in HF_ZIP_PICKLE_EXTENSIONS:
member_results = scan_zip_archive(path_str, graceful=True, json_output_path=json_output)
file_results = list(member_results.values())
else:
file_results = [scan_file(path_str, graceful=True, json_output_path=json_output)]

if print_results:
print(f"\n Scanning: {path_str}")

for result in file_results:
if print_results:
for ar in result.results:
result_str = ar.to_string()
if result_str:
print(f" {result_str}")

if not result.is_safe:
any_unsafe = True
if print_results:
sys.stderr.write(f" WARNING: {path_str} may contain unsafe content!\n")

if result.errors:
any_error = True
for err in result.errors:
sys.stderr.write(f" {err}\n")

if print_results:
if any_unsafe:
print(f"\n{target}: Potentially unsafe content detected!")
elif any_error:
print(f"\n{target}: Completed with scan errors; review warnings above")
else:
print(f"\n{target}: No obvious safety issues detected")

if any_unsafe:
return EXIT_UNSAFE
if any_error:
return EXIT_ERROR
return EXIT_CLEAN


def main(argv: list[str] | None = None) -> int:
if argv is None:
argv = sys.argv
Expand Down Expand Up @@ -234,6 +358,15 @@ def main(argv: list[str] | None = None) -> int:
action="store_true",
help="print a runtime trace while interpreting the input pickle file",
)
parser.add_argument(
"--recursive",
"-R",
action="store_true",
help="when PICKLE_FILE is a directory, scan it recursively. Directory and "
"glob inputs are always safety-scanned (reusing the --check-safety analysis), "
"aggregating per-file verdicts into a single ClamAV-compatible exit code "
"(unsafe takes precedence over scan errors).",
)
parser.add_argument("--version", "-v", action="store_true", help="print the version and exit")
options.add_argument(
"--huggingface",
Expand Down Expand Up @@ -278,6 +411,27 @@ def main(argv: list[str] | None = None) -> int:
print_results=args.print_results,
)

# Directory / glob scanning mode (bulk triage). Triggered when PICKLE_FILE is
# a directory, a glob pattern, or --recursive is set, as long as no single-file
# operation (inject/create/trace) was requested.
if (
args.create is None
and args.inject is None
and not args.trace
and args.PICKLE_FILE != "-"
and (
args.recursive
or (_is_glob(args.PICKLE_FILE) and not Path(args.PICKLE_FILE).is_file())
or Path(args.PICKLE_FILE).is_dir()
)
):
return _scan_paths(
args.PICKLE_FILE,
recursive=args.recursive,
json_output_path=args.json_output,
print_results=args.print_results,
)

if args.create is None:
if args.PICKLE_FILE == "-":
if hasattr(sys.stdin, "buffer") and sys.stdin.buffer is not None:
Expand Down
159 changes: 159 additions & 0 deletions test/test_cli_recursive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import os
import pickle
import tempfile
import unittest
import zipfile
from contextlib import redirect_stdout
from io import StringIO
from pathlib import Path

from fickling.cli import main
from fickling.constants import EXIT_CLEAN, EXIT_ERROR, EXIT_UNSAFE


class Payload:
"""Malicious payload for testing (executes os.system on unpickle)."""

def __reduce__(self):
return (os.system, ("echo pwned",))


def _write_pickle(path: Path, obj) -> None:
with open(path, "wb") as f:
pickle.dump(obj, f)


def _run(*argv) -> int:
"""Invoke the CLI and swallow stdout to keep test output clean."""
with redirect_stdout(StringIO()):
return main(["fickling", *argv])


class TestRecursiveDirectoryScan(unittest.TestCase):
def test_directory_with_unsafe_file_returns_exit_unsafe(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
_write_pickle(root / "safe.pkl", [1, 2, 3])
_write_pickle(root / "evil.pkl", Payload())
self.assertEqual(_run(str(root)), EXIT_UNSAFE)

def test_all_safe_directory_returns_exit_clean(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
_write_pickle(root / "a.pkl", [1, 2, 3])
_write_pickle(root / "b.pickle", {"k": "v"})
_write_pickle(root / "c.bin", 42)
self.assertEqual(_run(str(root)), EXIT_CLEAN)

def test_directory_without_pickle_files_returns_exit_clean(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
(root / "README.md").write_text("not a pickle")
(root / "config.json").write_text("{}")
self.assertEqual(_run(str(root)), EXIT_CLEAN)

def test_nonrecursive_does_not_descend_into_subdirs(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
_write_pickle(root / "safe.pkl", [1, 2, 3])
nested = root / "nested"
nested.mkdir()
_write_pickle(nested / "evil.pkl", Payload())
# Default (non-recursive) must ignore the nested unsafe file.
self.assertEqual(_run(str(root)), EXIT_CLEAN)

def test_recursive_descends_into_subdirs(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
_write_pickle(root / "safe.pkl", [1, 2, 3])
nested = root / "nested" / "deeper"
nested.mkdir(parents=True)
_write_pickle(nested / "evil.pkl", Payload())
self.assertEqual(_run(str(root), "--recursive"), EXIT_UNSAFE)

def test_recursive_short_flag(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
nested = root / "sub"
nested.mkdir()
_write_pickle(nested / "evil.pkl", Payload())
self.assertEqual(_run(str(root), "-R"), EXIT_UNSAFE)

def test_nonexistent_target_returns_exit_error(self):
# A directory/glob scan of a path that resolves to nothing reports
# EXIT_ERROR rather than a misleading "clean" result.
self.assertEqual(_run("/nonexistent/directory/xyz", "--recursive"), EXIT_ERROR)


class TestGlobScan(unittest.TestCase):
def test_glob_matches_unsafe_file(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
_write_pickle(root / "safe.pkl", [1, 2, 3])
_write_pickle(root / "evil.pkl", Payload())
self.assertEqual(_run(str(root / "*.pkl")), EXIT_UNSAFE)

def test_glob_only_safe_matches(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
_write_pickle(root / "a.pkl", [1, 2, 3])
_write_pickle(root / "evil.bin", Payload())
# Glob restricted to *.pkl must not pick up the unsafe .bin file.
self.assertEqual(_run(str(root / "*.pkl")), EXIT_CLEAN)

def test_glob_no_matches_returns_exit_error(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
self.assertEqual(_run(str(root / "*.pkl")), EXIT_ERROR)

def test_recursive_glob(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
nested = root / "a" / "b"
nested.mkdir(parents=True)
_write_pickle(nested / "evil.pkl", Payload())
pattern = str(root / "**" / "*.pkl")
self.assertEqual(_run(pattern, "--recursive"), EXIT_UNSAFE)


class TestZipMemberScan(unittest.TestCase):
def test_pt_archive_with_unsafe_member(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
archive = root / "model.pt"
with zipfile.ZipFile(archive, "w") as zf:
zf.writestr("data.pkl", pickle.dumps(Payload()))
self.assertEqual(_run(str(root)), EXIT_UNSAFE)

def test_pth_archive_all_safe(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
archive = root / "model.pth"
with zipfile.ZipFile(archive, "w") as zf:
zf.writestr("weights.pkl", pickle.dumps([1, 2, 3]))
self.assertEqual(_run(str(root)), EXIT_CLEAN)


class TestSingleFileUnaffected(unittest.TestCase):
def test_recursive_flag_on_single_file(self):
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
pickle.dump(Payload(), f)
path = f.name
try:
self.assertEqual(_run(path, "--recursive"), EXIT_UNSAFE)
finally:
Path(path).unlink()

def test_print_results_emits_summary(self):
with tempfile.TemporaryDirectory() as d:
root = Path(d)
_write_pickle(root / "evil.pkl", Payload())
buf = StringIO()
with redirect_stdout(buf):
code = main(["fickling", str(root), "--print-results"])
self.assertEqual(code, EXIT_UNSAFE)
self.assertIn("Potentially unsafe content detected", buf.getvalue())


if __name__ == "__main__":
unittest.main()