Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ UVICORN_PORT = 8000
## Database pool settings are per worker, not global. Four workers with pool size 10 can open up to 40 base connections.
# SQLALCHEMY_POOL_SIZE = 25
# SQLALCHEMY_MAX_OVERFLOW = 60
# SQLALCHEMY_POOL_RECYCLE = 300
# SQLALCHEMY_POOL_RECYCLE = 1800
# SQLALCHEMY_POOL_TIMEOUT = 15
# ECHO_SQL_QUERIES = False

## NATS connection and subjects for multi-process/node coordination.
Expand Down
8 changes: 4 additions & 4 deletions app/core/xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def _handle_tls_settings(self, tls_settings: dict, settings: dict, inbound_tag:
if sni := tls_settings.get("serverName"):
settings["sni"].append(sni)
for certificate in tls_settings.get("certificates", []):
serve_on_node = certificate.pop("serveOnNode", False)
if serve_on_node:
if certificate.get("serveOnNode", False):
# prevent error on parse by xray core
del certificate["serveOnNode"]
continue
if certificate.get("certificateFile", None):
with open(certificate["certificateFile"], "rb") as file:
Expand Down Expand Up @@ -372,7 +372,7 @@ def _hysteria_finalmask_from_stream(stream: dict, net_settings: dict) -> dict |
"""Normalize Hysteria Salamander masks into finalmask for client generation."""
finalmask = stream.get("finalmask") or stream.get("finalMask")
if isinstance(finalmask, dict):
finalmask = deepcopy(finalmask)
finalmask = {k: v for k, v in finalmask.items()}
else:
finalmask = {}

Expand All @@ -383,7 +383,7 @@ def _hysteria_finalmask_from_stream(stream: dict, net_settings: dict) -> dict |
udpmasks = stream.get("udpmasks")

if isinstance(udpmasks, list) and udpmasks and not finalmask.get("udp"):
finalmask["udp"] = deepcopy(udpmasks)
finalmask["udp"] = list(udpmasks)

return finalmask or None

Expand Down
1 change: 0 additions & 1 deletion app/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .base import Base, GetDB, get_db # noqa


from .models import JWT, System, User # noqa

__all__ = [
Expand Down
14 changes: 5 additions & 9 deletions app/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@

IS_SQLITE = database_settings.is_sqlite

connect_args = {}
if IS_SQLITE:
connect_args["check_same_thread"] = False
elif database_settings.is_mysql:
connect_args["connect_timeout"] = database_settings.connect_timeout

if IS_SQLITE:
engine = create_async_engine(database_settings.url, connect_args=connect_args, echo=database_settings.echo_queries)
engine = create_async_engine(
database_settings.url, connect_args={"check_same_thread": False}, echo=database_settings.echo_queries
)
else:
engine = create_async_engine(
database_settings.url,
connect_args=connect_args,
pool_size=database_settings.pool_size,
max_overflow=database_settings.max_overflow,
pool_recycle=database_settings.pool_recycle,
pool_timeout=5,
pool_timeout=database_settings.pool_timeout,
pool_pre_ping=True,
pool_use_lifo=True,
echo=database_settings.echo_queries,
)

Expand Down
54 changes: 54 additions & 0 deletions app/db/migrations/versions/118c2a5eaa7d_add_performance_indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""add performance indexes

Revision ID: 118c2a5eaa7d
Revises: f9c69a49f544
Create Date: 2026-06-15 20:15:44.076888

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '118c2a5eaa7d'
down_revision = 'f9c69a49f544'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('hosts', schema=None) as batch_op:
batch_op.create_index('idx_hosts_inbound_tag', ['inbound_tag'], unique=False)
batch_op.create_index('idx_hosts_is_disabled', ['is_disabled'], unique=False)

with op.batch_alter_table('nodes', schema=None) as batch_op:
batch_op.create_index('idx_nodes_status', ['status'], unique=False)

with op.batch_alter_table('notification_reminders', schema=None) as batch_op:
batch_op.create_index('idx_notification_reminders_user_id', ['user_id'], unique=False)

with op.batch_alter_table('users', schema=None) as batch_op:
batch_op.create_index('idx_users_status_expire', ['status', 'expire'], unique=False)
batch_op.create_index('idx_users_status_used_traffic', ['status', 'used_traffic'], unique=False)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('users', schema=None) as batch_op:
batch_op.drop_index('idx_users_status_used_traffic')
batch_op.drop_index('idx_users_status_expire')

with op.batch_alter_table('notification_reminders', schema=None) as batch_op:
batch_op.drop_index('idx_notification_reminders_user_id')

with op.batch_alter_table('nodes', schema=None) as batch_op:
batch_op.drop_index('idx_nodes_status')

with op.batch_alter_table('hosts', schema=None) as batch_op:
batch_op.drop_index('idx_hosts_is_disabled')
batch_op.drop_index('idx_hosts_inbound_tag')

# ### end Alembic commands ###
8 changes: 8 additions & 0 deletions app/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ class User(Base, CreatedAtUTCMixin):
Index("idx_users_admin_online", "admin_id", "online_at"),
Index("idx_users_admin_status", "admin_id", "status"),
Index("idx_users_admin_created", "admin_id", "created_at"),
Index("idx_users_status_expire", "status", "expire"),
Index("idx_users_status_used_traffic", "status", "used_traffic"),
)
username: Mapped[str] = mapped_column(CaseSensitiveString(128), unique=True, index=True)
node_usages: Mapped[List["NodeUserUsage"]] = relationship(
Expand Down Expand Up @@ -519,6 +521,10 @@ class ProxyHostALPN(str, Enum):

class ProxyHost(Base, IdMixin):
__tablename__ = "hosts"
__table_args__ = (
Index("idx_hosts_inbound_tag", "inbound_tag"),
Index("idx_hosts_is_disabled", "is_disabled"),
)
remark: Mapped[str] = mapped_column(String(256), unique=False, nullable=False)
port: Mapped[Optional[int]] = mapped_column(nullable=True)
path: Mapped[Optional[str]] = mapped_column(String(256), unique=False, nullable=True)
Expand Down Expand Up @@ -595,6 +601,7 @@ class NodeStatus(str, Enum):

class Node(Base, CreatedAtUTCMixin):
__tablename__ = "nodes"
__table_args__ = (Index("idx_nodes_status", "status"),)
name: Mapped[str] = mapped_column(CaseSensitiveString(256), unique=True)
address: Mapped[str] = mapped_column(String(256), unique=False, nullable=False)
port: Mapped[int] = mapped_column(unique=False, nullable=False)
Expand Down Expand Up @@ -739,6 +746,7 @@ class NodeUsageResetLogs(Base, CreatedAtUTCMixin):

class NotificationReminder(Base, CreatedAtUTCMixin):
__tablename__ = "notification_reminders"
__table_args__ = (Index("idx_notification_reminders_user_id", "user_id"),)
user_id: Mapped[int] = fk_id_column("users.id", ondelete="CASCADE")
user: Mapped["User"] = relationship(back_populates="notification_reminders", init=False)
type: Mapped[ReminderType] = mapped_column(SQLEnum(ReminderType))
Expand Down
64 changes: 30 additions & 34 deletions app/jobs/record_usages.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,21 @@ def _chunked(items: list, size: int):
yield items[index : index + size]


_dialect_cache: str | None = None
_dialect_lock = asyncio.Lock()


async def get_dialect() -> str:
"""Get the database dialect name without holding the session open."""
async with GetDB() as db:
return db.bind.dialect.name
"""Get the database dialect name, cached after first lookup."""
global _dialect_cache
if _dialect_cache is not None:
return _dialect_cache
async with _dialect_lock:
if _dialect_cache is not None:
return _dialect_cache
async with GetDB() as db:
_dialect_cache = db.bind.dialect.name
return _dialect_cache


def build_node_user_usage_upsert(dialect: str, upsert_params: list[dict]):
Expand Down Expand Up @@ -436,48 +447,33 @@ async def record_user_stats_batched(all_node_params: dict, usage_coefficients: d

async def record_node_stats_batched(all_node_params: dict):
"""
Record node-level statistics for ALL nodes in batched operations.
This reduces write amplification and lock contention.

Args:
all_node_params: Dict mapping node_id -> list of node stat params
Record node-level statistics for ALL nodes using safe_execute per node.
Each node's upsert is executed independently to avoid batch transaction issues.
"""
if not all_node_params:
return

created_at = _get_time_bucket()
dialect = await get_dialect()

# Process each node's stats with concurrency control
async def _record_single_node(node_id: int, params: list[dict]):
if not params:
return

# Aggregate uplink and downlink from params
total_up = sum(p.get("up", 0) for p in params)
total_down = sum(p.get("down", 0) for p in params)

if not (total_up or total_down):
return

upsert_param = {
"node_id": node_id,
"created_at": created_at,
"up": total_up,
"down": total_down,
}

# Execute with concurrency control
async with JOB_SEM:
async with JOB_SEM:
for node_id, params in all_node_params.items():
if not params:
continue
total_up = sum(p.get("up", 0) for p in params)
total_down = sum(p.get("down", 0) for p in params)
if not (total_up or total_down):
continue
upsert_param = {
"node_id": node_id,
"created_at": created_at,
"up": total_up,
"down": total_down,
}
queries = build_node_usage_upsert(dialect, upsert_param)
for stmt, stmt_params in queries:
await safe_execute(stmt, stmt_params)

# Execute all node stats with limited concurrency
tasks = [_record_single_node(node_id, params) for node_id, params in all_node_params.items()]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)


def _process_users_stats_response(stats_response):
"""
Expand Down
37 changes: 22 additions & 15 deletions app/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,33 @@ async def get_nodes(self) -> dict[int, PasarGuardNode]:
async with self._lock.reader_lock:
return self._nodes

async def get_healthy_nodes(self) -> list[tuple[int, PasarGuardNode]]:
async def _get_nodes_by_health(self, expected: Health) -> list[tuple[int, PasarGuardNode]]:
async with self._lock.reader_lock:
nodes: list[tuple[int, PasarGuardNode]] = [
(id, node) for id, node in self._nodes.items() if (await node.get_health() == Health.HEALTHY)
]
return nodes
items = list(self._nodes.items())

health_results = await asyncio.gather(
*(node.get_health() for _, node in items),
return_exceptions=True,
)

async def get_broken_nodes(self) -> list[tuple[int, PasarGuardNode]]:
async with self._lock.reader_lock:
nodes: list[tuple[int, PasarGuardNode]] = [
(id, node) for id, node in self._nodes.items() if (await node.get_health() == Health.BROKEN)
]
return nodes
matched = []
for (node_id, node), health in zip(items, health_results):
if isinstance(health, Exception):
self.logger.warning("Failed to get health for node %s: %s", node_id, health)
continue
if health == expected and self._nodes.get(node_id) is node:
matched.append((node_id, node))
return matched

async def get_healthy_nodes(self) -> list[tuple[int, PasarGuardNode]]:
return await self._get_nodes_by_health(Health.HEALTHY)

async def get_broken_nodes(self) -> list[tuple[int, PasarGuardNode]]:
return await self._get_nodes_by_health(Health.BROKEN)

async def get_not_connected_nodes(self) -> list[tuple[int, PasarGuardNode]]:
async with self._lock.reader_lock:
nodes: list[tuple[int, PasarGuardNode]] = [
(id, node) for id, node in self._nodes.items() if (await node.get_health() == Health.NOT_CONNECTED)
]
return nodes
return await self._get_nodes_by_health(Health.NOT_CONNECTED)

async def _snapshot_nodes(self) -> list[PasarGuardNode]:
async with self._lock.reader_lock:
Expand Down
8 changes: 2 additions & 6 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,14 @@ class DatabaseSettings(EnvSettings):
url: str = Field(default="sqlite+aiosqlite:///db.sqlite3", validation_alias="SQLALCHEMY_DATABASE_URL")
pool_size: int = Field(default=25, validation_alias="SQLALCHEMY_POOL_SIZE")
max_overflow: int = Field(default=60, validation_alias="SQLALCHEMY_MAX_OVERFLOW")
pool_recycle: int = Field(default=300, validation_alias="SQLALCHEMY_POOL_RECYCLE")
connect_timeout: int = Field(default=5, gt=0, validation_alias="SQLALCHEMY_CONNECT_TIMEOUT")
pool_recycle: int = Field(default=1800, ge=1, validation_alias="SQLALCHEMY_POOL_RECYCLE")
pool_timeout: int = Field(default=15, ge=1, validation_alias="SQLALCHEMY_POOL_TIMEOUT")
echo_queries: bool = Field(default=False, validation_alias="ECHO_SQL_QUERIES")

@cached_property
def is_postgresql(self) -> bool:
return self.url.startswith("postgresql")

@cached_property
def is_mysql(self) -> bool:
return self.url.startswith(("mysql", "mariadb"))

@cached_property
def is_sqlite(self) -> bool:
return self.url.startswith("sqlite")
Expand Down