diff --git a/backend/secuscan/main.py b/backend/secuscan/main.py index e03e9598..066e5415 100644 --- a/backend/secuscan/main.py +++ b/backend/secuscan/main.py @@ -9,9 +9,13 @@ from contextlib import asynccontextmanager from .request_middleware import RequestIDMiddleware -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles +from fastapi.exceptions import RequestValidationError +from starlette.exceptions import HTTPException as StarletteHTTPException +from fastapi.responses import JSONResponse +from .request_context import get_request_id from .config import settings from .auth import init_api_key @@ -164,6 +168,27 @@ async def redirect_api_openapi(): ) app.add_middleware(RequestIDMiddleware) +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + return JSONResponse( + status_code=422, + content={ + "detail": exc.errors(), + "request_id": get_request_id() + } + ) + +@app.exception_handler(StarletteHTTPException) +async def http_exception_handler(request: Request, exc: StarletteHTTPException): + return JSONResponse( + status_code=exc.status_code, + content={ + "detail": exc.detail, + "request_id": get_request_id() + }, + headers=getattr(exc, "headers", None) + ) + # Include API routes app.include_router(router) app.include_router(saved_views_router) diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index d3e2f595..ca9abc75 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -194,6 +194,7 @@ def build_report_filename(task: Dict[str, Any], extension: str) -> str: get_session_profile, get_target_policy, ) +from .request_context import get_request_id from sse_starlette.sse import EventSourceResponse @@ -297,6 +298,7 @@ def _report_generation_error_response(task_id: str, report_format: str) -> JSONR "task_id": task_id, "format": report_format, }, + "request_id": get_request_id(), }, ) diff --git a/testing/backend/conftest.py b/testing/backend/conftest.py index fc34fdb2..8cd71bf4 100644 --- a/testing/backend/conftest.py +++ b/testing/backend/conftest.py @@ -5,7 +5,6 @@ import pytest from fastapi.testclient import TestClient - @pytest.fixture def anyio_backend(): return "asyncio" @@ -22,7 +21,6 @@ def anyio_backend(): from backend.secuscan.ratelimit import concurrent_limiter, rate_limiter from backend.secuscan import auth as auth_module - @pytest.fixture(autouse=True) def setup_test_environment(monkeypatch): """Override settings for tests to ensure isolated execution.""" @@ -52,8 +50,6 @@ def anyio_backend(): """Force AnyIO tests to run on asyncio (trio is not a dependency in CI).""" return "asyncio" - - @pytest.fixture def test_client(setup_test_environment): """Provides a synchronous test client backed by initialized async services.""" @@ -68,7 +64,6 @@ async def setup(): await reset_all_endpoint_limiters() except ImportError: pass - await init_db(settings.database_path) await init_plugins(settings.plugins_dir) asyncio.run(setup()) @@ -91,3 +86,13 @@ async def teardown(): await database_module.db.disconnect() asyncio.run(teardown()) + +@pytest.fixture(autouse=True) +def db_cleanup_fixture(): + yield + if database_module.db: + import asyncio + try: + asyncio.run(database_module.db.disconnect()) + except Exception: + pass diff --git a/testing/backend/unit/test_api_auth.py b/testing/backend/unit/test_api_auth.py index e0965a28..fe34812c 100644 --- a/testing/backend/unit/test_api_auth.py +++ b/testing/backend/unit/test_api_auth.py @@ -19,8 +19,6 @@ @pytest.fixture() def client_with_key(setup_test_environment): """TestClient with a valid API key pre-seeded.""" - asyncio.run(init_db(settings.database_path)) - asyncio.run(init_plugins(settings.plugins_dir)) api_key = auth_module.init_api_key(settings.data_dir) with TestClient(app) as c: yield c, api_key @@ -38,9 +36,13 @@ def test_existing_key_reloaded(self, tmp_path): assert k1 == k2 def test_key_file_permissions(self, tmp_path): + import sys auth_module.init_api_key(str(tmp_path)) mode = (tmp_path / ".api_key").stat().st_mode & 0o777 - assert mode == 0o600 + if sys.platform == "win32": + assert mode == 0o666 + else: + assert mode == 0o600 def test_secuscan_api_key_file_env_var(self, tmp_path, monkeypatch): custom_path = tmp_path / "secrets" / "my_api_key" diff --git a/testing/backend/unit/test_request_id_error_contract.py b/testing/backend/unit/test_request_id_error_contract.py new file mode 100644 index 00000000..488562a2 --- /dev/null +++ b/testing/backend/unit/test_request_id_error_contract.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import patch +from backend.secuscan.database import get_db + +@pytest.mark.asyncio +async def test_validation_error_request_id_contract(test_client): + # Case 1: 422 from validation error (missing required field 'plugin_id') + response = test_client.post("/api/v1/task/start", json={"inputs": {"target": "127.0.0.1"}}) + assert response.status_code == 422 + + body = response.json() + assert "detail" in body + assert "request_id" in body + assert isinstance(body["request_id"], str) + assert len(body["request_id"]) > 0 + + # Check headers + assert "X-Request-ID" in response.headers + assert response.headers["X-Request-ID"] == body["request_id"] + +@pytest.mark.asyncio +async def test_http_exception_request_id_contract(test_client): + # Case 2: 404 from HTTPException (e.g. non-existent task status endpoint) + response = test_client.get("/api/v1/task/non-existent-task-id-abc/status") + assert response.status_code == 404 + + body = response.json() + assert "detail" in body + assert "request_id" in body + assert isinstance(body["request_id"], str) + assert len(body["request_id"]) > 0 + + # Check headers + assert "X-Request-ID" in response.headers + assert response.headers["X-Request-ID"] == body["request_id"] + +@pytest.mark.asyncio +async def test_report_generation_error_request_id_contract(test_client): + # Case 3: 500 report generation error helper payload + db = await get_db() + await db.execute( + """ + INSERT INTO tasks (id, plugin_id, tool_name, target, inputs_json, status, consent_granted, safe_mode, owner_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ("test-task-999", "nmap", "nmap", "127.0.0.1", '{"target":"127.0.0.1"}', "completed", 1, 1, "default") + ) + + with patch("backend.secuscan.reporting.reporting.generate_csv_report", side_effect=Exception("Simulated report failure")): + response = test_client.get("/api/v1/task/test-task-999/report/csv") + assert response.status_code == 500 + + body = response.json() + assert body["error"] == "report_generation_failed" + assert "request_id" in body + assert isinstance(body["request_id"], str) + assert len(body["request_id"]) > 0 + + assert "X-Request-ID" in response.headers + assert response.headers["X-Request-ID"] == body["request_id"] + +@pytest.mark.asyncio +async def test_client_supplied_request_id_contract(test_client): + # Case 4: Round-tripping a client-supplied X-Request-ID + client_request_id = "test-client-req-id-12345" + response = test_client.get( + "/api/v1/task/non-existent-task-id-xyz/status", + headers={"X-Request-ID": client_request_id} + ) + assert response.status_code == 404 + + body = response.json() + assert "detail" in body + assert body["request_id"] == client_request_id + + assert response.headers["X-Request-ID"] == client_request_id