From 750ffc1de194898fde9c3b15913ed01a86d89afa Mon Sep 17 00:00:00 2001 From: rtmalikian Date: Tue, 26 May 2026 06:53:05 -0700 Subject: [PATCH] fix: restrict web server CORS defaults --- README.md | 1 + tests/test_web_server_cors.py | 19 +++++++++++++++++++ web_server/README.md | 1 + web_server/app.py | 18 ++++++++++++++---- 4 files changed, 35 insertions(+), 4 deletions(-) create mode 100644 tests/test_web_server_cors.py diff --git a/README.md b/README.md index 50abd31..d2b149e 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,7 @@ print(response.text) | `N_PREDICT` | Max tokens to generate | `512` | | `TEMPERATURE` | Sampling temperature (0.0-2.0) | `0.8` | | `THREADS` | Number of CPU threads | `4` | +| `CORS_ALLOW_ORIGINS` | Comma-separated trusted browser origins allowed to call the API | `http://localhost:8080,http://127.0.0.1:8080` | ### Example with Custom Settings diff --git a/tests/test_web_server_cors.py b/tests/test_web_server_cors.py new file mode 100644 index 0000000..d8d30ba --- /dev/null +++ b/tests/test_web_server_cors.py @@ -0,0 +1,19 @@ +import importlib + + +def test_web_server_cors_defaults_to_local_origins_without_credentials(monkeypatch): + monkeypatch.delenv("CORS_ALLOW_ORIGINS", raising=False) + + app_module = importlib.import_module("web_server.app") + + cors_middleware = next( + middleware + for middleware in app_module.app.user_middleware + if middleware.cls.__name__ == "CORSMiddleware" + ) + + assert cors_middleware.kwargs["allow_origins"] == [ + "http://localhost:8080", + "http://127.0.0.1:8080", + ] + assert cors_middleware.kwargs["allow_credentials"] is False diff --git a/web_server/README.md b/web_server/README.md index 450fb16..7ade4c7 100644 --- a/web_server/README.md +++ b/web_server/README.md @@ -303,6 +303,7 @@ console.log(data.choices[0].message.content); | `N_PREDICT` | Max tokens to generate | `512` | | `TEMPERATURE` | Sampling temperature (0.0-2.0) | `0.8` | | `THREADS` | Number of CPU threads | `4` | +| `CORS_ALLOW_ORIGINS` | Comma-separated trusted browser origins allowed to call the API | `http://localhost:8080,http://127.0.0.1:8080` | ### Example with Custom Settings diff --git a/web_server/app.py b/web_server/app.py index a0dc5c7..b7051f8 100644 --- a/web_server/app.py +++ b/web_server/app.py @@ -19,16 +19,26 @@ app = FastAPI(title="BitNet API", version="1.0.0") -# Enable CORS +# Configuration +CORS_ALLOW_ORIGINS = [ + origin.strip() + for origin in os.environ.get( + "CORS_ALLOW_ORIGINS", + "http://localhost:8080,http://127.0.0.1:8080", + ).split(",") + if origin.strip() +] + +# Enable CORS for the local web UI by default. Operators can opt in to +# additional trusted origins with CORS_ALLOW_ORIGINS when exposing the API. app.add_middleware( CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, + allow_origins=CORS_ALLOW_ORIGINS, + allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) -# Configuration MODEL_PATH = os.environ.get("MODEL_PATH", "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf") CTX_SIZE = int(os.environ.get("CTX_SIZE", "2048")) N_PREDICT = int(os.environ.get("N_PREDICT", "512"))