diff --git a/run_benchmark.py b/run_benchmark.py index 13c5e645e..1ab0f7108 100644 --- a/run_benchmark.py +++ b/run_benchmark.py @@ -38,10 +38,12 @@ def run(): benchmark = importlib.import_module( f"userbenchmark.{available_benchmarks[args.bm_name]}.run" ) - benchmark.run(bm_args) except ImportError as e: print(f"Failed to import user benchmark module {args.bm_name}, error: {str(e)}") traceback.print_exc() + raise SystemExit(1) + + benchmark.run(bm_args) if __name__ == "__main__": diff --git a/test_run_benchmark.py b/test_run_benchmark.py new file mode 100644 index 000000000..eb24d1d09 --- /dev/null +++ b/test_run_benchmark.py @@ -0,0 +1,32 @@ +import importlib +import sys +import types + +import pytest + +import run_benchmark + + +def test_import_error_in_run_benchmark_exits_nonzero(monkeypatch): + monkeypatch.setattr(run_benchmark, "list_benchmarks", lambda: {"fake": "fake"}) + monkeypatch.setattr(run_benchmark.importlib, "import_module", lambda _: (_ for _ in ()).throw(ImportError("boom"))) + + monkeypatch.setattr(sys, "argv", ["run_benchmark.py", "fake"]) + with pytest.raises(SystemExit) as exc: + run_benchmark.run() + + assert exc.value.code == 1 + + +def test_run_exception_propagates_outside_import_handling(monkeypatch): + def fake_run(_): + raise RuntimeError("runtime failure") + + fake_module = types.SimpleNamespace(run=fake_run) + + monkeypatch.setattr(run_benchmark, "list_benchmarks", lambda: {"fake": "fake"}) + monkeypatch.setattr(run_benchmark.importlib, "import_module", lambda _: fake_module) + + monkeypatch.setattr(sys, "argv", ["run_benchmark.py", "fake"]) + with pytest.raises(RuntimeError): + run_benchmark.run()