Skip to content
15 changes: 7 additions & 8 deletions lite_bootstrap/bootstrappers/fastapi_bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class FastAPIConfig(
prometheus_expose_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)

def __post_init__(self) -> None:
# @dataclass(slots=True) replaces the class object, breaking bare super().
super(FastAPIConfig, self).__post_init__()
if not import_checker.is_fastapi_installed:
msg = "fastapi is not installed"
raise ConfigurationError(msg)
Expand All @@ -65,14 +67,11 @@ def __post_init__(self) -> None:
# FastAPIConfig stays frozen for user-facing immutability; __post_init__ needs
# to set application after construction, so we bypass the freeze here.
object.__setattr__(self, "application", application)
else:
application = self.application
if self.application_kwargs:
warnings.warn("application_kwargs must be used without application", stacklevel=2)

application.title = self.service_name
application.debug = self.service_debug
application.version = self.service_version
application.title = self.service_name
application.debug = self.service_debug
application.version = self.service_version
elif self.application_kwargs:
warnings.warn("application_kwargs must be used without application", stacklevel=2)


def _narrow_app(config: "FastAPIConfig") -> "fastapi.FastAPI":
Expand Down
10 changes: 7 additions & 3 deletions lite_bootstrap/bootstrappers/faststream_bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ class FastStreamConfig(
):
application: "AsgiFastStream" = dataclasses.field(default_factory=_make_asgi_faststream)
opentelemetry_middleware_cls: type[FastStreamTelemetryMiddlewareProtocol] | None = None
opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list)
prometheus_middleware_cls: type[FastStreamPrometheusMiddlewareProtocol] | None = None
prometheus_collector_registry: "prometheus_client.CollectorRegistry | None" = None
faststream_log_level: int = logging.WARNING
faststream_health_check_broker_timeout: float = 5.0

Expand Down Expand Up @@ -156,12 +158,14 @@ def _make_collector_registry() -> "prometheus_client.CollectorRegistry":
@dataclasses.dataclass(kw_only=True)
class FastStreamPrometheusInstrument(PrometheusInstrument):
bootstrap_config: FastStreamConfig
collector_registry: "prometheus_client.CollectorRegistry" = dataclasses.field(
default_factory=_make_collector_registry, init=False
)
collector_registry: "prometheus_client.CollectorRegistry" = dataclasses.field(init=False)
not_ready_message = PrometheusInstrument.not_ready_message + " or prometheus_middleware_cls is missing"
missing_dependency_message = "prometheus_client is not installed"

def __post_init__(self) -> None:
injected = self.bootstrap_config.prometheus_collector_registry
self.collector_registry = injected if injected is not None else _make_collector_registry()

@classmethod
def is_configured(cls, bootstrap_config: "FastStreamConfig") -> bool: # ty: ignore[invalid-method-override]
return super().is_configured(bootstrap_config) and bool(bootstrap_config.prometheus_middleware_cls)
Expand Down
26 changes: 24 additions & 2 deletions lite_bootstrap/helpers/fastapi_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pathlib
import typing
import warnings

from lite_bootstrap import import_checker
from lite_bootstrap.exceptions import ConfigurationError
from lite_bootstrap.helpers.path import is_valid_path


if import_checker.is_fastapi_installed:
Expand All @@ -13,6 +15,26 @@
from starlette.routing import Route


def _safe_root_path(scope_root_path: str) -> str:
"""Strip trailing slash and validate against the project's path allowlist.

An empty ``root_path`` is the normal case (no proxy prefix) and is allowed without warning.
Any other path that fails the ``is_valid_path`` regex is rejected (falls back to empty) so
that proxy-header-derived root paths can't inject HTML into the offline-docs response.
"""
candidate = scope_root_path.rstrip("/")
if not candidate:
return ""
if not is_valid_path(candidate):
warnings.warn(
f"root_path {candidate!r} contains characters outside the valid-path allowlist; "
"falling back to empty root_path to prevent HTML injection in offline docs.",
stacklevel=3,
)
return ""
return candidate


def enable_offline_docs(
app: "FastAPI",
static_path: str,
Expand All @@ -36,7 +58,7 @@ def enable_offline_docs(

@app.get(docs_url, include_in_schema=False)
async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
root_path = request.scope.get("root_path", "").rstrip("/")
root_path = _safe_root_path(request.scope.get("root_path", ""))
return get_swagger_ui_html(
openapi_url=f"{root_path}{app_openapi_url}",
title=f"{app.title} - Swagger UI",
Expand All @@ -51,7 +73,7 @@ async def swagger_ui_redirect() -> HTMLResponse:

@app.get(redoc_url, include_in_schema=False)
async def redoc_html(request: Request) -> HTMLResponse:
root_path = request.scope.get("root_path", "").rstrip("/")
root_path = _safe_root_path(request.scope.get("root_path", ""))
return get_redoc_html(
openapi_url=f"{root_path}{app_openapi_url}",
title=f"{app.title} - ReDoc",
Expand Down
8 changes: 8 additions & 0 deletions lite_bootstrap/instruments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ class BaseConfig:
service_environment: str | None = None
service_debug: bool = True

def __post_init__(self) -> None:
"""Terminate the MRO __post_init__ cascade safely.

Subclasses call super().__post_init__() to propagate through multiple-inheritance
chains (e.g. FastAPIConfig → CorsConfig → OpenTelemetryConfig → BaseConfig).
Without this no-op, the chain would raise AttributeError on object.
"""

@classmethod
def from_dict(cls, data: dict[str, typing.Any]) -> typing_extensions.Self:
"""Build a config from a dict; unknown keys are silently dropped, explicit None overrides defaults."""
Expand Down
18 changes: 18 additions & 0 deletions lite_bootstrap/instruments/cors_instrument.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import dataclasses
import typing

from lite_bootstrap.exceptions import ConfigurationError
from lite_bootstrap.instruments.base import BaseConfig, BaseInstrument


_PERMISSIVE_ORIGIN_REGEX: typing.Final[frozenset[str]] = frozenset({".*", r".+"})


@dataclasses.dataclass(kw_only=True, frozen=True)
class CorsConfig(BaseConfig):
cors_allowed_origins: list[str] = dataclasses.field(default_factory=list)
Expand All @@ -13,6 +18,19 @@ class CorsConfig(BaseConfig):
cors_allowed_origin_regex: str | None = None
cors_max_age: int = 600

def __post_init__(self) -> None:
if self.cors_allowed_credentials:
wildcard_in_origins = "*" in self.cors_allowed_origins
permissive_regex = self.cors_allowed_origin_regex in _PERMISSIVE_ORIGIN_REGEX
if wildcard_in_origins or permissive_regex:
msg = (
"Unsafe CORS configuration: cors_allowed_credentials=True combined with a "
"wildcard origin is rejected by browsers and is a security misconfiguration. "
"Use an explicit list of allowed origins (or a narrow regex)."
)
raise ConfigurationError(msg)
super().__post_init__()


@dataclasses.dataclass(kw_only=True, slots=True)
class CorsInstrument(BaseInstrument[CorsConfig]):
Expand Down
32 changes: 32 additions & 0 deletions lite_bootstrap/instruments/opentelemetry_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import os
import typing
import urllib.parse
import warnings

from lite_bootstrap import import_checker
from lite_bootstrap.instruments.base import BaseConfig, BaseInstrument
Expand Down Expand Up @@ -38,6 +40,9 @@ class OpenTelemetryServiceFieldsConfig(BaseConfig):
opentelemetry_namespace: str | None = None


_LOCAL_HOSTS: typing.Final[frozenset[str]] = frozenset({"localhost", "127.0.0.1", "::1", ""})


@dataclasses.dataclass(kw_only=True, frozen=True)
class OpenTelemetryConfig(OpenTelemetryServiceFieldsConfig):
opentelemetry_container_name: str | None = dataclasses.field(
Expand All @@ -51,6 +56,33 @@ class OpenTelemetryConfig(OpenTelemetryServiceFieldsConfig):
opentelemetry_log_traces: bool = False
opentelemetry_generate_health_check_spans: bool = True

def __post_init__(self) -> None:
host = self._parse_remote_insecure_host()
if host is not None:
warnings.warn(
f"OTLP exporter sending traces unencrypted to non-local host {host!r}; "
"set opentelemetry_insecure=False or use a localhost/unix endpoint.",
stacklevel=2,
)
super().__post_init__()

def _parse_remote_insecure_host(self) -> str | None:
"""Return the host name if the endpoint is insecure AND non-local; else None."""
if not self.opentelemetry_endpoint or not self.opentelemetry_insecure:
return None
if self.opentelemetry_endpoint.startswith("unix://"):
return None
# urlparse treats schemeless input as `path`, misparsing `host:port` forms.
# Prepend `//` so urlparse always sees a network-location-style input.
raw = self.opentelemetry_endpoint
if "://" not in raw:
raw = f"//{raw}"
parsed = urllib.parse.urlparse(raw)
host = (parsed.hostname or "").lower()
if host in _LOCAL_HOSTS:
return None
return host


if import_checker.is_opentelemetry_installed and import_checker.is_pyroscope_installed:
_OTEL_PROFILE_ID_KEY: typing.Final = "pyroscope.profile.id"
Expand Down
Loading
Loading