Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ print(response.text)
| `N_PREDICT` | Max tokens to generate | `512` |
| `TEMPERATURE` | Sampling temperature (0.0-2.0) | `0.8` |
| `THREADS` | Number of CPU threads | `4` |
| `CORS_ALLOW_ORIGINS` | Comma-separated trusted browser origins allowed to call the API | `http://localhost:8080,http://127.0.0.1:8080` |

### Example with Custom Settings

Expand Down
19 changes: 19 additions & 0 deletions tests/test_web_server_cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import importlib


def test_web_server_cors_defaults_to_local_origins_without_credentials(monkeypatch):
monkeypatch.delenv("CORS_ALLOW_ORIGINS", raising=False)

app_module = importlib.import_module("web_server.app")

cors_middleware = next(
middleware
for middleware in app_module.app.user_middleware
if middleware.cls.__name__ == "CORSMiddleware"
)

assert cors_middleware.kwargs["allow_origins"] == [
"http://localhost:8080",
"http://127.0.0.1:8080",
]
assert cors_middleware.kwargs["allow_credentials"] is False
1 change: 1 addition & 0 deletions web_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ console.log(data.choices[0].message.content);
| `N_PREDICT` | Max tokens to generate | `512` |
| `TEMPERATURE` | Sampling temperature (0.0-2.0) | `0.8` |
| `THREADS` | Number of CPU threads | `4` |
| `CORS_ALLOW_ORIGINS` | Comma-separated trusted browser origins allowed to call the API | `http://localhost:8080,http://127.0.0.1:8080` |

### Example with Custom Settings

Expand Down
18 changes: 14 additions & 4 deletions web_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,26 @@

app = FastAPI(title="BitNet API", version="1.0.0")

# Enable CORS
# Configuration
CORS_ALLOW_ORIGINS = [
origin.strip()
for origin in os.environ.get(
"CORS_ALLOW_ORIGINS",
"http://localhost:8080,http://127.0.0.1:8080",
).split(",")
if origin.strip()
]

# Enable CORS for the local web UI by default. Operators can opt in to
# additional trusted origins with CORS_ALLOW_ORIGINS when exposing the API.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_origins=CORS_ALLOW_ORIGINS,
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)

# Configuration
MODEL_PATH = os.environ.get("MODEL_PATH", "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf")
CTX_SIZE = int(os.environ.get("CTX_SIZE", "2048"))
N_PREDICT = int(os.environ.get("N_PREDICT", "512"))
Expand Down