From 814468bf93e8f7496f853e77394bdfd450b10f42 Mon Sep 17 00:00:00 2001 From: Ritwij Aryan Parmar Date: Wed, 17 Jun 2026 13:55:03 -0400 Subject: [PATCH] Harden sandbox reads during resume --- src/blaxel/core/sandbox/default/drive.py | 23 +++-- src/blaxel/core/sandbox/default/network.py | 8 +- src/blaxel/core/sandbox/default/system.py | 24 +++-- src/blaxel/core/sandbox/sync/drive.py | 23 +++-- src/blaxel/core/sandbox/sync/network.py | 8 +- src/blaxel/core/sandbox/sync/system.py | 24 +++-- src/blaxel/core/sandbox/transient_retry.py | 80 +++++++++++++++ tests/core/test_sandbox_transient_retry.py | 112 +++++++++++++++++++++ 8 files changed, 258 insertions(+), 44 deletions(-) diff --git a/src/blaxel/core/sandbox/default/drive.py b/src/blaxel/core/sandbox/default/drive.py index 357ec1b9..f7f113ca 100644 --- a/src/blaxel/core/sandbox/default/drive.py +++ b/src/blaxel/core/sandbox/default/drive.py @@ -4,7 +4,7 @@ from ..client.api.drive.delete_drives_mount_mount_path import ( asyncio as delete_drives_mount, ) -from ..client.api.drive.get_drives_mount import asyncio as get_drives_mount +from ..client.api.drive.get_drives_mount import asyncio_detailed as get_drives_mount from ..client.api.drive.post_drives_mount import asyncio as post_drives_mount from ..client.client import Client from ..client.models import ( @@ -14,7 +14,7 @@ DriveUnmountResponse, ErrorResponse, ) -from ..transient_retry import retry_on_transient_reset_async +from ..transient_retry import retry_on_transient_sandbox_read_async from ..types import SandboxConfiguration from .action import SandboxAction @@ -102,18 +102,19 @@ async def list(self) -> List[DriveMountInfo]: List of DriveMountInfo for each mounted drive """ - async def list_once() -> List[DriveMountInfo]: + async def list_once(): client = Client( base_url=self.url, headers={**settings.headers, **self.sandbox_config.headers}, ) async with client: - response = await get_drives_mount(client=client) - if response is None: - raise Exception("Failed to list drives") - if isinstance(response, ErrorResponse): - raise Exception(f"List drives failed: {response.error}") - return list(response.mounts) if response.mounts else [] - - return await retry_on_transient_reset_async(list_once) + return await get_drives_mount(client=client) + + api_response = await retry_on_transient_sandbox_read_async(list_once) + response = api_response.parsed + if response is None: + raise Exception("Failed to list drives") + if isinstance(response, ErrorResponse): + raise Exception(f"List drives failed: {response.error}") + return list(response.mounts) if response.mounts else [] diff --git a/src/blaxel/core/sandbox/default/network.py b/src/blaxel/core/sandbox/default/network.py index 90361d39..0d1b465d 100644 --- a/src/blaxel/core/sandbox/default/network.py +++ b/src/blaxel/core/sandbox/default/network.py @@ -1,8 +1,11 @@ import httpx +from ..transient_retry import retry_on_transient_sandbox_read_async from ..types import SandboxConfiguration from .action import SandboxAction +IDEMPOTENT_READ_METHODS = {"GET", "HEAD", "OPTIONS"} + class SandboxNetwork(SandboxAction): def __init__(self, sandbox_config: SandboxConfiguration): @@ -24,4 +27,7 @@ async def fetch( normalized_path = path if path.startswith("/") else f"/{path}" url = f"/port/{port}{normalized_path}" client = self.get_client() - return await client.request(method, url, **kwargs) + fetch_once = lambda: client.request(method, url, **kwargs) + if method.upper() in IDEMPOTENT_READ_METHODS: + return await retry_on_transient_sandbox_read_async(fetch_once) + return await fetch_once() diff --git a/src/blaxel/core/sandbox/default/system.py b/src/blaxel/core/sandbox/default/system.py index b63b5a50..d0df4f8a 100644 --- a/src/blaxel/core/sandbox/default/system.py +++ b/src/blaxel/core/sandbox/default/system.py @@ -1,8 +1,9 @@ from ...common.settings import settings -from ..client.api.system.get_health import asyncio as get_health +from ..client.api.system.get_health import asyncio_detailed as get_health from ..client.api.system.post_upgrade import asyncio as post_upgrade from ..client.client import Client from ..client.models import ErrorResponse, HealthResponse, SuccessResponse, UpgradeRequest +from ..transient_retry import retry_on_transient_sandbox_read_async from ..types import SandboxConfiguration from .action import SandboxAction @@ -57,13 +58,16 @@ async def health(self) -> HealthResponse: Returns: HealthResponse with system status information """ - client = Client( - base_url=self.url, - headers={**settings.headers, **self.sandbox_config.headers}, - ) + async def health_once(): + client = Client( + base_url=self.url, + headers={**settings.headers, **self.sandbox_config.headers}, + ) - async with client: - response = await get_health(client=client) - if response is None: - raise Exception("Failed to get health status") - return response + async with client: + return await get_health(client=client) + + api_response = await retry_on_transient_sandbox_read_async(health_once) + if api_response.parsed is None: + raise Exception("Failed to get health status") + return api_response.parsed diff --git a/src/blaxel/core/sandbox/sync/drive.py b/src/blaxel/core/sandbox/sync/drive.py index 748c650c..1e2fb4e6 100644 --- a/src/blaxel/core/sandbox/sync/drive.py +++ b/src/blaxel/core/sandbox/sync/drive.py @@ -4,7 +4,7 @@ from ..client.api.drive.delete_drives_mount_mount_path import ( sync as delete_drives_mount, ) -from ..client.api.drive.get_drives_mount import sync as get_drives_mount +from ..client.api.drive.get_drives_mount import sync_detailed as get_drives_mount from ..client.api.drive.post_drives_mount import sync as post_drives_mount from ..client.client import Client from ..client.models import ( @@ -14,7 +14,7 @@ DriveUnmountResponse, ErrorResponse, ) -from ..transient_retry import retry_on_transient_reset +from ..transient_retry import retry_on_transient_sandbox_read from ..types import SandboxConfiguration from .action import SyncSandboxAction @@ -102,18 +102,19 @@ def list(self) -> List[DriveMountInfo]: List of DriveMountInfo for each mounted drive """ - def list_once() -> List[DriveMountInfo]: + def list_once(): client = Client( base_url=self.url, headers={**settings.headers, **self.sandbox_config.headers}, ) with client: - response = get_drives_mount(client=client) - if response is None: - raise Exception("Failed to list drives") - if isinstance(response, ErrorResponse): - raise Exception(f"List drives failed: {response.error}") - return list(response.mounts) if response.mounts else [] - - return retry_on_transient_reset(list_once) + return get_drives_mount(client=client) + + api_response = retry_on_transient_sandbox_read(list_once) + response = api_response.parsed + if response is None: + raise Exception("Failed to list drives") + if isinstance(response, ErrorResponse): + raise Exception(f"List drives failed: {response.error}") + return list(response.mounts) if response.mounts else [] diff --git a/src/blaxel/core/sandbox/sync/network.py b/src/blaxel/core/sandbox/sync/network.py index a3f6f801..e3385d92 100644 --- a/src/blaxel/core/sandbox/sync/network.py +++ b/src/blaxel/core/sandbox/sync/network.py @@ -1,8 +1,11 @@ import httpx +from ..transient_retry import retry_on_transient_sandbox_read from ..types import SandboxConfiguration from .action import SyncSandboxAction +IDEMPOTENT_READ_METHODS = {"GET", "HEAD", "OPTIONS"} + class SyncSandboxNetwork(SyncSandboxAction): def __init__(self, sandbox_config: SandboxConfiguration): @@ -22,4 +25,7 @@ def fetch(self, port: int, path: str = "/", method: str = "GET", **kwargs) -> ht normalized_path = path if path.startswith("/") else f"/{path}" url = f"/port/{port}{normalized_path}" with self.get_client() as client: - return client.request(method, url, **kwargs) + fetch_once = lambda: client.request(method, url, **kwargs) + if method.upper() in IDEMPOTENT_READ_METHODS: + return retry_on_transient_sandbox_read(fetch_once) + return fetch_once() diff --git a/src/blaxel/core/sandbox/sync/system.py b/src/blaxel/core/sandbox/sync/system.py index e8caf510..bbd01d90 100644 --- a/src/blaxel/core/sandbox/sync/system.py +++ b/src/blaxel/core/sandbox/sync/system.py @@ -1,8 +1,9 @@ from ...common.settings import settings -from ..client.api.system.get_health import sync as get_health +from ..client.api.system.get_health import sync_detailed as get_health from ..client.api.system.post_upgrade import sync as post_upgrade from ..client.client import Client from ..client.models import ErrorResponse, HealthResponse, SuccessResponse, UpgradeRequest +from ..transient_retry import retry_on_transient_sandbox_read from ..types import SandboxConfiguration from .action import SyncSandboxAction @@ -57,13 +58,16 @@ def health(self) -> HealthResponse: Returns: HealthResponse with system status information """ - client = Client( - base_url=self.url, - headers={**settings.headers, **self.sandbox_config.headers}, - ) + def health_once(): + client = Client( + base_url=self.url, + headers={**settings.headers, **self.sandbox_config.headers}, + ) - with client: - response = get_health(client=client) - if response is None: - raise Exception("Failed to get health status") - return response + with client: + return get_health(client=client) + + api_response = retry_on_transient_sandbox_read(health_once) + if api_response.parsed is None: + raise Exception("Failed to get health status") + return api_response.parsed diff --git a/src/blaxel/core/sandbox/transient_retry.py b/src/blaxel/core/sandbox/transient_retry.py index 70a7c3f7..7bf68c35 100644 --- a/src/blaxel/core/sandbox/transient_retry.py +++ b/src/blaxel/core/sandbox/transient_retry.py @@ -2,6 +2,7 @@ import random import time from collections.abc import Awaitable, Callable, Iterator +from http import HTTPStatus from typing import TypeVar import httpx @@ -33,6 +34,13 @@ DEFAULT_BASE_DELAY_SECONDS = 0.2 DEFAULT_MAX_DELAY_SECONDS = 2.0 +TRANSIENT_SANDBOX_READ_STATUSES = { + 425, + 429, + 502, + 503, + 504, +} def _walk_error_chain(error: BaseException) -> Iterator[BaseException]: @@ -87,6 +95,22 @@ def is_transient_reset_error(error: BaseException) -> bool: return any(marker in message for message in messages for marker in TRANSIENT_RESET_MARKERS) +def _coerce_status_code(status: object) -> int | None: + if isinstance(status, HTTPStatus): + return int(status.value) + if isinstance(status, int): + return status + return None + + +def is_transient_sandbox_read_response(response: object) -> bool: + """True for gateway responses that can happen while a sandbox resumes.""" + status_code = _coerce_status_code(getattr(response, "status_code", None)) + if status_code is None: + return False + return status_code in TRANSIENT_SANDBOX_READ_STATUSES + + def _backoff_delay_seconds( attempt: int, base_delay_seconds: float, @@ -120,6 +144,34 @@ async def retry_on_transient_reset_async( await asyncio.sleep(delay) +async def retry_on_transient_sandbox_read_async( + fn: Callable[[], Awaitable[T]], + *, + retries: int | None = None, + base_delay_seconds: float = DEFAULT_BASE_DELAY_SECONDS, + max_delay_seconds: float = DEFAULT_MAX_DELAY_SECONDS, +) -> T: + retry_budget = settings.sandbox_read_retries if retries is None else retries + attempt = 0 + while True: + try: + result = await fn() + except Exception as error: + attempt += 1 + if retry_budget <= 0 or attempt > retry_budget or not is_transient_reset_error(error): + raise + else: + if not is_transient_sandbox_read_response(result): + return result + attempt += 1 + if retry_budget <= 0 or attempt > retry_budget: + return result + + delay = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds) + if delay: + await asyncio.sleep(delay) + + def retry_on_transient_reset( fn: Callable[[], T], *, @@ -139,3 +191,31 @@ def retry_on_transient_reset( delay = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds) if delay: time.sleep(delay) + + +def retry_on_transient_sandbox_read( + fn: Callable[[], T], + *, + retries: int | None = None, + base_delay_seconds: float = DEFAULT_BASE_DELAY_SECONDS, + max_delay_seconds: float = DEFAULT_MAX_DELAY_SECONDS, +) -> T: + retry_budget = settings.sandbox_read_retries if retries is None else retries + attempt = 0 + while True: + try: + result = fn() + except Exception as error: + attempt += 1 + if retry_budget <= 0 or attempt > retry_budget or not is_transient_reset_error(error): + raise + else: + if not is_transient_sandbox_read_response(result): + return result + attempt += 1 + if retry_budget <= 0 or attempt > retry_budget: + return result + + delay = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds) + if delay: + time.sleep(delay) diff --git a/tests/core/test_sandbox_transient_retry.py b/tests/core/test_sandbox_transient_retry.py index 638f9208..501e333e 100644 --- a/tests/core/test_sandbox_transient_retry.py +++ b/tests/core/test_sandbox_transient_retry.py @@ -1,17 +1,24 @@ import asyncio +from http import HTTPStatus from typing import Any, cast import httpx import pytest from blaxel.core.common.settings import settings +from blaxel.core.sandbox.client.types import Response from blaxel.core.sandbox.default.filesystem import SandboxFileSystem +from blaxel.core.sandbox.default.network import SandboxNetwork from blaxel.core.sandbox.default.process import SandboxProcess from blaxel.core.sandbox.sync.filesystem import SyncSandboxFileSystem +from blaxel.core.sandbox.sync.network import SyncSandboxNetwork from blaxel.core.sandbox.transient_retry import ( is_transient_reset_error, + is_transient_sandbox_read_response, retry_on_transient_reset, retry_on_transient_reset_async, + retry_on_transient_sandbox_read, + retry_on_transient_sandbox_read_async, ) from blaxel.core.sandbox.types import ResponseError @@ -65,6 +72,13 @@ async def post(self, *args, **kwargs): raise result return result + async def request(self, *args, **kwargs): + self.calls += 1 + result = self.results.pop(0) + if isinstance(result, BaseException): + raise result + return result + class SyncSequenceClient: def __init__(self, *results): @@ -84,6 +98,13 @@ def get(self, *args, **kwargs): raise result return result + def request(self, *args, **kwargs): + self.calls += 1 + result = self.results.pop(0) + if isinstance(result, BaseException): + raise result + return result + def ok_json_response(data): return httpx.Response( @@ -93,6 +114,13 @@ def ok_json_response(data): ) +def status_response(status_code: int) -> httpx.Response: + return httpx.Response( + status_code, + request=httpx.Request("GET", "https://sandbox.test"), + ) + + def app_error_response() -> ResponseError: response = httpx.Response( 500, @@ -158,6 +186,19 @@ def test_classifier_rejects_application_responses(): assert not is_transient_reset_error(app_error_response()) +def test_read_response_classifier_accepts_resume_gateway_statuses(): + assert is_transient_sandbox_read_response(status_response(502)) + assert is_transient_sandbox_read_response(status_response(503)) + assert is_transient_sandbox_read_response( + Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE, content=b"", headers={}, parsed=None) + ) + + +def test_read_response_classifier_rejects_application_statuses(): + assert not is_transient_sandbox_read_response(status_response(500)) + assert not is_transient_sandbox_read_response(status_response(404)) + + @pytest.mark.asyncio async def test_real_httpx_transport_drop_is_classified_transient(): async with LoopbackFaultServer(close_without_response) as server: @@ -211,6 +252,23 @@ async def flaky(): assert calls == 2 +@pytest.mark.asyncio +async def test_async_sandbox_read_retry_recovers_from_resume_status(): + calls = 0 + + async def flaky_gateway(): + nonlocal calls + calls += 1 + if calls == 1: + return status_response(503) + return status_response(200) + + response = await retry_on_transient_sandbox_read_async(flaky_gateway, retries=1) + + assert response.status_code == 200 + assert calls == 2 + + def test_sync_retry_recovers_once(): calls = 0 @@ -225,6 +283,22 @@ def flaky(): assert calls == 2 +def test_sync_sandbox_read_retry_recovers_from_resume_status(): + calls = 0 + + def flaky_gateway(): + nonlocal calls + calls += 1 + if calls == 1: + return status_response(502) + return status_response(200) + + response = retry_on_transient_sandbox_read(flaky_gateway, retries=1) + + assert response.status_code == 200 + assert calls == 2 + + def test_sync_retry_does_not_retry_application_response(): calls = 0 @@ -252,6 +326,32 @@ async def test_async_filesystem_read_retries_transport_reset(monkeypatch): assert client.calls == 2 +@pytest.mark.asyncio +async def test_async_network_fetch_retries_resume_gateway_status(monkeypatch): + monkeypatch.setenv("BL_SANDBOX_READ_RETRIES", "1") + client = AsyncSequenceClient(status_response(503), status_response(200)) + network = cast(Any, object.__new__(SandboxNetwork)) + network.get_client = lambda: client + + response = await network.fetch(8080, "/health") + + assert response.status_code == 200 + assert client.calls == 2 + + +@pytest.mark.asyncio +async def test_async_network_fetch_does_not_retry_post_status(monkeypatch): + monkeypatch.setenv("BL_SANDBOX_READ_RETRIES", "1") + client = AsyncSequenceClient(status_response(503), status_response(200)) + network = cast(Any, object.__new__(SandboxNetwork)) + network.get_client = lambda: client + + response = await network.fetch(8080, "/mutate", method="POST") + + assert response.status_code == 503 + assert client.calls == 1 + + def test_sync_filesystem_read_retries_transport_reset(monkeypatch): monkeypatch.setenv("BL_SANDBOX_READ_RETRIES", "1") client = SyncSequenceClient( @@ -265,6 +365,18 @@ def test_sync_filesystem_read_retries_transport_reset(monkeypatch): assert client.calls == 2 +def test_sync_network_fetch_retries_resume_gateway_status(monkeypatch): + monkeypatch.setenv("BL_SANDBOX_READ_RETRIES", "1") + client = SyncSequenceClient(status_response(502), status_response(200)) + network = cast(Any, object.__new__(SyncSandboxNetwork)) + network.get_client = lambda: client + + response = network.fetch(8080, "/health") + + assert response.status_code == 200 + assert client.calls == 2 + + @pytest.mark.asyncio async def test_process_exec_is_not_retried_on_transport_reset(): client = AsyncSequenceClient(httpx.ConnectError("All connection attempts failed"))