Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions app/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,24 @@


def setup_middleware(app: FastAPI):
# Security: reject wildcard origin with credentials enabled
allowed_origins = cors_settings.allowed_origins
if "*" in allowed_origins:
import warnings

warnings.warn(
"CORS allow_origins contains '*' with allow_credentials=True is insecure. "
"Set ALLOWED_ORIGINS to explicit origins in production.",
stacklevel=2,
)
allow_credentials = False
else:
allow_credentials = True

app.add_middleware(
CORSMiddleware,
allow_origins=cors_settings.allowed_origins,
allow_credentials=True,
allow_origins=allowed_origins,
allow_credentials=allow_credentials,
allow_methods=["*"],
allow_headers=["*"],
)
Expand Down
5 changes: 3 additions & 2 deletions app/utils/jwt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hmac
import time
import jwt
from base64 import b64decode, b64encode
Expand Down Expand Up @@ -38,7 +39,7 @@ async def get_admin_payload(token: str) -> dict | None:
if admin_id is not None:
try:
admin_id = int(admin_id)
except TypeError, ValueError:
except (TypeError, ValueError):
return
if not username or access not in ("admin", "sudo"):
return
Expand Down Expand Up @@ -97,7 +98,7 @@ async def get_subscription_payload(token: str) -> dict | None:
sha256((u_token + await get_secret_key()).encode("utf-8")).digest(), altchars=b"-_"
).decode("utf-8")[:10]
u_token_hex_resign = sha256((u_token + await get_secret_key()).encode("utf-8")).hexdigest()[:10]
if u_signature in (u_token_resign, u_token_hex_resign):
if hmac.compare_digest(u_signature, u_token_resign) | hmac.compare_digest(u_signature, u_token_hex_resign):
parts = u_token_dec_str.split(",")
if len(parts) == 3 and parts[0] in ("v2", "v3"):
_, u_user_id_str, u_created_at_str = parts
Expand Down