Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions app/core/hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from app.core.manager import core_manager
from app.db import GetDB
from app.db.crud.host import get_host_by_id, get_hosts, upsert_inbounds
from app.db.crud.node import get_xray_version_by_core_id
from app.db.models import ProxyHostSecurity
from app.models.host import BaseHost, TransportSettings, WireGuardHostOverrides
from app.models.subscription import (
Expand Down Expand Up @@ -49,6 +50,7 @@ def _string_list(value) -> list[str]:
async def _prepare_subscription_inbound_data(
host: BaseHost,
down_settings: SubscriptionInboundData | None = None,
db: AsyncSession | None = None,
) -> SubscriptionInboundData:
"""
Prepare host data - creates small config instances ONCE.
Expand All @@ -57,6 +59,10 @@ async def _prepare_subscription_inbound_data(
"""
# Get inbound configuration
inbound_config = await core_manager.get_inbound_by_tag(host.inbound_tag)
core_id = await core_manager.get_core_id_by_tag(host.inbound_tag)
xray_version = None
if core_id and db:
xray_version = await get_xray_version_by_core_id(db, core_id)
protocol = inbound_config["protocol"]

ts = host.transport_settings
Expand Down Expand Up @@ -260,6 +266,16 @@ async def _prepare_subscription_inbound_data(
else inbound_config.get("session_placement")
),
session_key=xs.session_key if xs and xs.session_key is not None else inbound_config.get("session_key"),
session_id_table=(
xs.session_id_table
if xs and xs.session_id_table is not None
else inbound_config.get("session_id_table")
),
session_id_length=(
xs.session_id_length
if xs and xs.session_id_length is not None
else inbound_config.get("session_id_length")
),
seq_placement=(
xs.seq_placement if xs and xs.seq_placement is not None else inbound_config.get("seq_placement")
),
Expand All @@ -281,6 +297,7 @@ async def _prepare_subscription_inbound_data(
download_settings=down_settings if xs and down_settings else inbound_config.get("download_settings"),
http_headers=host.http_headers if host.http_headers is not None else inbound_config.get("http_headers"),
random_user_agent=host.random_user_agent,
core_version=xray_version,
)
elif network in ("grpc", "gun"):
gs = ts.grpc_settings if ts else None
Expand Down Expand Up @@ -540,8 +557,8 @@ async def _prepare_host_entry(
downstream = await get_host_by_id(db, ds_host)
if downstream:
downstream_base = BaseHost.model_validate(downstream)
downstream_data: SubscriptionInboundData = await _prepare_subscription_inbound_data(downstream_base)
subscription_data = await _prepare_subscription_inbound_data(host, downstream_data)
downstream_data = await _prepare_subscription_inbound_data(downstream_base, db=db)
subscription_data = await _prepare_subscription_inbound_data(host, downstream_data, db=db)

# Return subscription data directly
return host.id, subscription_data
Expand Down
10 changes: 9 additions & 1 deletion app/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self):
self._lock = Lock()
self._inbounds: list[str] = []
self._inbounds_by_tag = {}
self._tag_to_core_id: dict[str, int] = {}
self._nats_enabled = is_nats_enabled()
self._multi_worker = runtime_settings.role.requires_nats
self._nc: nats.NATS | None = None
Expand Down Expand Up @@ -222,9 +223,13 @@ async def initialize(self, db):
async def update_inbounds(self):
async with self._lock:
new_inbounds = {}
for core in self._cores.values():
new_tag_to_core_id = {}
for core_id, core in self._cores.items():
new_inbounds.update(core.inbounds_by_tag)
for tag in core.inbounds_by_tag:
new_tag_to_core_id[tag] = core_id

self._tag_to_core_id = new_tag_to_core_id
self._inbounds_by_tag = new_inbounds
self._inbounds = list(self._inbounds_by_tag.keys())

Expand Down Expand Up @@ -313,6 +318,9 @@ async def get_inbound_by_tag(self, tag) -> dict:
return None
return deepcopy(inbound)

async def get_core_id_by_tag(self, tag: str) -> int | None:
async with self._lock:
return self._tag_to_core_id.get(tag)

core_manager = CoreManager()

Expand Down
17 changes: 15 additions & 2 deletions app/core/xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ def _protocols_from_inbounds_by_tag(inbounds_by_tag: dict[str, dict]) -> frozens
if (protocol := ProxyProtocol.from_value(inbound["protocol"])) is not None
)

def rename_xhttp_session_keys(obj):
if isinstance(obj, dict):
if "sessionPlacement" in obj:
obj["sessionIDPlacement"] = obj.pop("sessionPlacement")
if "sessionKey" in obj:
obj["sessionIDKey"] = obj.pop("sessionKey")
for value in obj.values():
rename_xhttp_session_keys(value)
elif isinstance(obj, list):
for item in obj:
rename_xhttp_session_keys(item)

class XRayConfig(dict):
def __init__(
Expand Down Expand Up @@ -329,8 +340,10 @@ def get_xhttp_value(key: str):
settings["x_padding_placement"] = get_xhttp_value("xPaddingPlacement")
settings["x_padding_method"] = get_xhttp_value("xPaddingMethod")
settings["uplink_http_method"] = get_xhttp_value("uplinkHTTPMethod")
settings["session_placement"] = get_xhttp_value("sessionPlacement")
settings["session_key"] = get_xhttp_value("sessionKey")
settings["session_placement"] = get_xhttp_value("sessionIDPlacement") or get_xhttp_value("sessionPlacement")
settings["session_key"] = get_xhttp_value("sessionIDKey") or get_xhttp_value("sessionKey")
settings["session_id_table"] = get_xhttp_value("sessionIDTable")
settings["session_id_length"] = get_xhttp_value("sessionIDLength")
settings["seq_placement"] = get_xhttp_value("seqPlacement")
settings["seq_key"] = get_xhttp_value("seqKey")
settings["uplink_data_placement"] = get_xhttp_value("uplinkDataPlacement")
Expand Down
64 changes: 49 additions & 15 deletions app/db/crud/node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime, timezone
from typing import Optional

from packaging.version import InvalidVersion, Version
from sqlalchemy import and_, bindparam, case, delete, func, literal_column, or_, select, update
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -86,6 +87,25 @@ async def get_node_by_id(db: AsyncSession, node_id: int) -> Optional[Node]:
await load_node_attrs(node)
return node

async def get_xray_version_by_core_id(db: AsyncSession, core_config_id: int) -> str | None:
"""Returns the highest reported xray_version among nodes on this core config."""
versions = (
await db.execute(
select(Node.xray_version).where(Node.core_config_id == core_config_id).where(Node.xray_version.isnot(None))
)
).scalars().all()
if not versions:
return None

parsed: list[tuple[Version, str]] = []
for raw in versions:
try:
parsed.append((Version(raw.lstrip("v")), raw))
except InvalidVersion:
continue
if not parsed:
return versions[0]
return max(parsed, key=lambda item: item[0])[1]

async def get_nodes(
db: AsyncSession,
Expand Down Expand Up @@ -427,6 +447,9 @@ async def remove_node(db: AsyncSession, db_node: Node) -> None:
await db.commit()


CONNECTION_IDENTITY_FIELDS = ("address", "port", "server_ca", "connection_type", "api_key")


async def modify_node(db: AsyncSession, db_node: Node, modify: NodeModify) -> Node:
"""
modify an existing node with new information.
Expand All @@ -444,12 +467,17 @@ async def modify_node(db: AsyncSession, db_node: Node, modify: NodeModify) -> No
if "proxy_url" in modify.model_fields_set and modify.proxy_url is None:
node_data["proxy_url"] = None

connection_identity_changed = any(
field in node_data and getattr(db_node, field) != node_data[field] for field in CONNECTION_IDENTITY_FIELDS
)

for key, value in node_data.items():
setattr(db_node, key, value)

db_node.xray_version = None
db_node.message = None
db_node.node_version = None
if connection_identity_changed:
db_node.xray_version = None
db_node.node_version = None

if db_node.is_limited:
db_node.status = NodeStatus.limited
Expand Down Expand Up @@ -485,17 +513,17 @@ async def update_node_status(
Returns:
Node: The updated Node object.
"""
stmt = (
update(Node)
.where(Node.id == db_node.id)
.values(
status=status,
message=message,
xray_version=xray_version,
node_version=node_version,
last_status_change=datetime.now(timezone.utc),
)
)
values: dict = {
"status": status,
"message": message,
"last_status_change": datetime.now(timezone.utc),
}
if xray_version:
values["xray_version"] = xray_version
if node_version:
values["node_version"] = node_version

stmt = update(Node).where(Node.id == db_node.id).values(**values)
await db.execute(stmt)
await db.commit()

Expand Down Expand Up @@ -544,8 +572,14 @@ async def bulk_update_node_status(
.values(
status=bindparam("status"),
message=bindparam("message"),
xray_version=bindparam("xray_version"),
node_version=bindparam("node_version"),
xray_version=case(
(bindparam("xray_version") != "", bindparam("xray_version")),
else_=Node.xray_version,
),
node_version=case(
(bindparam("node_version") != "", bindparam("node_version")),
else_=Node.node_version,
),
last_status_change=bindparam("now"),
)
)
Expand Down
1 change: 1 addition & 0 deletions app/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def validate_sets(cls, v: set):
class CoreResponse(CoreBase):
id: int
created_at: dt
xray_version: str | None = Field(default=None)

model_config = ConfigDict(from_attributes=True)

Expand Down
2 changes: 2 additions & 0 deletions app/models/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class XHttpSettings(BaseModel):
uplink_http_method: str | None = Field(default=None)
session_placement: str | None = Field(default=None, pattern=r"^$|^(path|cookie|header|query)$")
session_key: str | None = Field(default=None)
session_id_table: str | None = Field(default=None, pattern=r"^[\x20-\x7E]*$")
session_id_length: str | None = Field(default=None, pattern=r"^\d{1,16}(-\d{1,16})?$")
seq_placement: str | None = Field(default=None, pattern=r"^$|^(path|cookie|header|query)$")
seq_key: str | None = Field(default=None)
uplink_data_placement: str | None = Field(default=None, pattern=r"^$|^(body|cookie|header)$")
Expand Down
2 changes: 1 addition & 1 deletion app/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ class UserIPListAll(BaseModel):

class NodeCoreUpdate(BaseModel):
core_version: str = Field(default="latest", pattern=r"^(latest|v?\d+\.\d+\.\d+)$", examples=["v25.8.31"])

confirm: bool = Field(default=False)

class NodeGeoFilesUpdate(BaseModel):
region: GeoFilseRegion = Field(default=GeoFilseRegion.iran, examples=["iran"])
Expand Down
5 changes: 4 additions & 1 deletion app/models/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class XHTTPTransportConfig(BaseTransportConfig):
uplink_http_method: str | None = Field(None, serialization_alias="uplinkHTTPMethod")
session_placement: str | None = Field(None, serialization_alias="sessionPlacement")
session_key: str | None = Field(None, serialization_alias="sessionKey")
session_id_table: str | None = Field(None, serialization_alias="sessionIDTable")
session_id_length: str | None = Field(None, serialization_alias="sessionIDLength")
seq_placement: str | None = Field(None, serialization_alias="seqPlacement")
seq_key: str | None = Field(None, serialization_alias="seqKey")
uplink_data_placement: str | None = Field(None, serialization_alias="uplinkDataPlacement")
Expand All @@ -124,7 +126,8 @@ class XHTTPTransportConfig(BaseTransportConfig):
download_settings: SubscriptionInboundData | dict | None = Field(None, serialization_alias="downloadSettings")
http_headers: dict[str, str] | None = Field(None)
random_user_agent: bool = Field(False)

core_version: str | None = Field(None, exclude=True)

@field_validator(
"sc_max_each_post_bytes",
"sc_min_posts_interval_ms",
Expand Down
25 changes: 22 additions & 3 deletions app/operation/node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import json
from copy import deepcopy
from typing import AsyncIterator, Callable

from fastapi import HTTPException
Expand Down Expand Up @@ -65,6 +67,8 @@
from app.operation import BaseOperation, OperatorType
from app.utils.logger import get_logger
from config import runtime_settings
from app.core.xray import rename_xhttp_session_keys
from app.subscription.base import is_new_xray

MAX_MESSAGE_LENGTH = 128

Expand Down Expand Up @@ -240,14 +244,19 @@ async def connect_node(db_node: Node, core, users: list) -> dict | None:
return None
if core is None:
return None

if core.type == CoreType.xray and (is_new_xray(db_node.xray_version)):
copied = deepcopy(dict(core))
rename_xhttp_session_keys(copied)
config_str = json.dumps(copied)
else:
config_str = core.to_str()
old_status = db_node.status
logger.info(f'Connecting to "{db_node.name}" node')
type = service.BackendType.WIREGUARD if core.type == CoreType.wg else service.BackendType.XRAY

try:
start_kwargs = {
"config": core.to_str(),
"config": config_str,
"backend_type": type,
"users": users,
"keep_alive": db_node.keep_alive,
Expand Down Expand Up @@ -544,9 +553,19 @@ async def update_node(self, db: AsyncSession, node_id: int) -> dict:
return await self._update_node_api_impl(node_id)

async def update_core(self, db: AsyncSession, node_id: int, node_core_update: NodeCoreUpdate) -> dict:
await self.get_validated_node(db, node_id)
db_node = await self.get_validated_node(db, node_id)

current = db_node.xray_version
target = node_core_update.core_version
if current and not is_new_xray(current) and is_new_xray(target) and not node_core_update.confirm:
raise HTTPException(
status_code=400,
detail="Upgrading to Xray 26.6.22+ renames xhttp session parameters (breaking change). Set confirm=true to proceed.",
)

return await self._update_core_impl(node_id, node_core_update)


async def update_geofiles(self, db: AsyncSession, node_id: int, node_geofiles_update: NodeGeoFilesUpdate) -> dict:
await self.get_validated_node(db, node_id)
return await self._update_geofiles_impl(node_id, node_geofiles_update)
Expand Down
8 changes: 6 additions & 2 deletions app/routers/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends, status

from app.db import AsyncSession, get_db
from app.db.crud.node import get_xray_version_by_core_id
from app.models.admin import AdminDetails
from app.models.core import (
BulkCoreSelection,
Expand Down Expand Up @@ -36,9 +37,12 @@ async def create_core_config(
@router.get("/{core_id}", response_model=CoreResponse)
async def get_core_config(
core_id: int, _: AdminDetails = Depends(require_permission("cores", "read")), db: AsyncSession = Depends(get_db)
) -> dict:
) -> CoreResponse:
"""Get a core configuration by its ID."""
return await core_operator.get_validated_core_config(db, core_id)
db_core_config = await core_operator.get_validated_core_config(db, core_id)
xray_version = await get_xray_version_by_core_id(db, core_id)
core = CoreResponse.model_validate(db_core_config)
return core.model_copy(update={"xray_version": xray_version})


@router.put("/{core_id}", response_model=CoreResponse)
Expand Down
10 changes: 9 additions & 1 deletion app/subscription/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
from urllib.parse import quote, urlencode

from app.models.subscription import SubscriptionInboundData

from packaging.version import Version

def is_new_xray(version: str | None) -> bool:
if not version:
return False
try:
return Version(version.lstrip("v")) >= Version("26.6.22")
except Exception:
return False

class BaseSubscription:
def __init__(
Expand Down
Loading