Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/blaxel/core/sandbox/default/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..client.api.drive.delete_drives_mount_mount_path import (
asyncio as delete_drives_mount,
)
from ..client.api.drive.get_drives_mount import asyncio as get_drives_mount
from ..client.api.drive.get_drives_mount import asyncio_detailed as get_drives_mount
from ..client.api.drive.post_drives_mount import asyncio as post_drives_mount
from ..client.client import Client
from ..client.models import (
Expand All @@ -14,7 +14,7 @@
DriveUnmountResponse,
ErrorResponse,
)
from ..transient_retry import retry_on_transient_reset_async
from ..transient_retry import retry_on_transient_sandbox_read_async
from ..types import SandboxConfiguration
from .action import SandboxAction

Expand Down Expand Up @@ -102,18 +102,19 @@ async def list(self) -> List[DriveMountInfo]:
List of DriveMountInfo for each mounted drive
"""

async def list_once() -> List[DriveMountInfo]:
async def list_once():
client = Client(
base_url=self.url,
headers={**settings.headers, **self.sandbox_config.headers},
)

async with client:
response = await get_drives_mount(client=client)
if response is None:
raise Exception("Failed to list drives")
if isinstance(response, ErrorResponse):
raise Exception(f"List drives failed: {response.error}")
return list(response.mounts) if response.mounts else []

return await retry_on_transient_reset_async(list_once)
return await get_drives_mount(client=client)

api_response = await retry_on_transient_sandbox_read_async(list_once)
response = api_response.parsed
if response is None:
raise Exception("Failed to list drives")
if isinstance(response, ErrorResponse):
raise Exception(f"List drives failed: {response.error}")
return list(response.mounts) if response.mounts else []
8 changes: 7 additions & 1 deletion src/blaxel/core/sandbox/default/network.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import httpx

from ..transient_retry import retry_on_transient_sandbox_read_async
from ..types import SandboxConfiguration
from .action import SandboxAction

# HTTP methods that fetch() sends through retry_on_transient_sandbox_read_async;
# a request with any other method is issued exactly once.
IDEMPOTENT_READ_METHODS = {"GET", "HEAD", "OPTIONS"}


class SandboxNetwork(SandboxAction):
def __init__(self, sandbox_config: SandboxConfiguration):
Expand All @@ -24,4 +27,7 @@ async def fetch(
normalized_path = path if path.startswith("/") else f"/{path}"
url = f"/port/{port}{normalized_path}"
client = self.get_client()
return await client.request(method, url, **kwargs)
fetch_once = lambda: client.request(method, url, **kwargs)
if method.upper() in IDEMPOTENT_READ_METHODS:
return await retry_on_transient_sandbox_read_async(fetch_once)
return await fetch_once()
24 changes: 14 additions & 10 deletions src/blaxel/core/sandbox/default/system.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from ...common.settings import settings
from ..client.api.system.get_health import asyncio as get_health
from ..client.api.system.get_health import asyncio_detailed as get_health
from ..client.api.system.post_upgrade import asyncio as post_upgrade
from ..client.client import Client
from ..client.models import ErrorResponse, HealthResponse, SuccessResponse, UpgradeRequest
from ..transient_retry import retry_on_transient_sandbox_read_async
from ..types import SandboxConfiguration
from .action import SandboxAction

Expand Down Expand Up @@ -57,13 +58,16 @@ async def health(self) -> HealthResponse:
Returns:
HealthResponse with system status information
"""
client = Client(
base_url=self.url,
headers={**settings.headers, **self.sandbox_config.headers},
)
async def health_once():
client = Client(
base_url=self.url,
headers={**settings.headers, **self.sandbox_config.headers},
)

async with client:
response = await get_health(client=client)
if response is None:
raise Exception("Failed to get health status")
return response
async with client:
return await get_health(client=client)

api_response = await retry_on_transient_sandbox_read_async(health_once)
if api_response.parsed is None:
raise Exception("Failed to get health status")
return api_response.parsed
23 changes: 12 additions & 11 deletions src/blaxel/core/sandbox/sync/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..client.api.drive.delete_drives_mount_mount_path import (
sync as delete_drives_mount,
)
from ..client.api.drive.get_drives_mount import sync as get_drives_mount
from ..client.api.drive.get_drives_mount import sync_detailed as get_drives_mount
from ..client.api.drive.post_drives_mount import sync as post_drives_mount
from ..client.client import Client
from ..client.models import (
Expand All @@ -14,7 +14,7 @@
DriveUnmountResponse,
ErrorResponse,
)
from ..transient_retry import retry_on_transient_reset
from ..transient_retry import retry_on_transient_sandbox_read
from ..types import SandboxConfiguration
from .action import SyncSandboxAction

Expand Down Expand Up @@ -102,18 +102,19 @@ def list(self) -> List[DriveMountInfo]:
List of DriveMountInfo for each mounted drive
"""

def list_once() -> List[DriveMountInfo]:
def list_once():
client = Client(
base_url=self.url,
headers={**settings.headers, **self.sandbox_config.headers},
)

with client:
response = get_drives_mount(client=client)
if response is None:
raise Exception("Failed to list drives")
if isinstance(response, ErrorResponse):
raise Exception(f"List drives failed: {response.error}")
return list(response.mounts) if response.mounts else []

return retry_on_transient_reset(list_once)
return get_drives_mount(client=client)

api_response = retry_on_transient_sandbox_read(list_once)
response = api_response.parsed
if response is None:
raise Exception("Failed to list drives")
if isinstance(response, ErrorResponse):
raise Exception(f"List drives failed: {response.error}")
return list(response.mounts) if response.mounts else []
8 changes: 7 additions & 1 deletion src/blaxel/core/sandbox/sync/network.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import httpx

from ..transient_retry import retry_on_transient_sandbox_read
from ..types import SandboxConfiguration
from .action import SyncSandboxAction

# HTTP methods that fetch() sends through retry_on_transient_sandbox_read;
# a request with any other method is issued exactly once.
IDEMPOTENT_READ_METHODS = {"GET", "HEAD", "OPTIONS"}


class SyncSandboxNetwork(SyncSandboxAction):
def __init__(self, sandbox_config: SandboxConfiguration):
Expand All @@ -22,4 +25,7 @@ def fetch(self, port: int, path: str = "/", method: str = "GET", **kwargs) -> ht
normalized_path = path if path.startswith("/") else f"/{path}"
url = f"/port/{port}{normalized_path}"
with self.get_client() as client:
return client.request(method, url, **kwargs)
fetch_once = lambda: client.request(method, url, **kwargs)
if method.upper() in IDEMPOTENT_READ_METHODS:
return retry_on_transient_sandbox_read(fetch_once)
return fetch_once()
24 changes: 14 additions & 10 deletions src/blaxel/core/sandbox/sync/system.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from ...common.settings import settings
from ..client.api.system.get_health import sync as get_health
from ..client.api.system.get_health import sync_detailed as get_health
from ..client.api.system.post_upgrade import sync as post_upgrade
from ..client.client import Client
from ..client.models import ErrorResponse, HealthResponse, SuccessResponse, UpgradeRequest
from ..transient_retry import retry_on_transient_sandbox_read
from ..types import SandboxConfiguration
from .action import SyncSandboxAction

Expand Down Expand Up @@ -57,13 +58,16 @@ def health(self) -> HealthResponse:
Returns:
HealthResponse with system status information
"""
client = Client(
base_url=self.url,
headers={**settings.headers, **self.sandbox_config.headers},
)
def health_once():
client = Client(
base_url=self.url,
headers={**settings.headers, **self.sandbox_config.headers},
)

with client:
response = get_health(client=client)
if response is None:
raise Exception("Failed to get health status")
return response
with client:
return get_health(client=client)

api_response = retry_on_transient_sandbox_read(health_once)
if api_response.parsed is None:
raise Exception("Failed to get health status")
return api_response.parsed
80 changes: 80 additions & 0 deletions src/blaxel/core/sandbox/transient_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
import time
from collections.abc import Awaitable, Callable, Iterator
from http import HTTPStatus
from typing import TypeVar

import httpx
Expand Down Expand Up @@ -33,6 +34,13 @@

DEFAULT_BASE_DELAY_SECONDS = 0.2
DEFAULT_MAX_DELAY_SECONDS = 2.0
# HTTP status codes that is_transient_sandbox_read_response() reports as
# transient; the sandbox-read retry helpers retry a call whose result carries
# one of these codes.
TRANSIENT_SANDBOX_READ_STATUSES = {
    425,  # Too Early
    429,  # Too Many Requests
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
}


def _walk_error_chain(error: BaseException) -> Iterator[BaseException]:
Expand Down Expand Up @@ -87,6 +95,22 @@ def is_transient_reset_error(error: BaseException) -> bool:
return any(marker in message for message in messages for marker in TRANSIENT_RESET_MARKERS)


def _coerce_status_code(status: object) -> int | None:
if isinstance(status, HTTPStatus):
return int(status.value)
if isinstance(status, int):
return status
return None


def is_transient_sandbox_read_response(response: object) -> bool:
    """True for gateway responses that can happen while a sandbox resumes.

    Args:
        response: Any object; its ``status_code`` attribute is inspected when
            present.

    Returns:
        ``True`` when the coerced status code is one of
        ``TRANSIENT_SANDBOX_READ_STATUSES``; ``False`` when the attribute is
        missing, cannot be coerced, or holds any other code.
    """
    raw_status = getattr(response, "status_code", None)
    code = _coerce_status_code(raw_status)
    return code is not None and code in TRANSIENT_SANDBOX_READ_STATUSES


def _backoff_delay_seconds(
attempt: int,
base_delay_seconds: float,
Expand Down Expand Up @@ -120,6 +144,34 @@ async def retry_on_transient_reset_async(
await asyncio.sleep(delay)


async def retry_on_transient_sandbox_read_async(
    fn: Callable[[], Awaitable[T]],
    *,
    retries: int | None = None,
    base_delay_seconds: float = DEFAULT_BASE_DELAY_SECONDS,
    max_delay_seconds: float = DEFAULT_MAX_DELAY_SECONDS,
) -> T:
    """Await ``fn`` repeatedly until it stops failing transiently or the budget runs out.

    An attempt is transient when ``fn`` raises an exception accepted by
    ``is_transient_reset_error`` or returns an object flagged by
    ``is_transient_sandbox_read_response``. Every failed attempt, of either
    kind, consumes one unit of the same retry budget.

    Args:
        fn: Zero-argument callable returning an awaitable; invoked once per attempt.
        retries: Maximum number of retries. ``None`` reads
            ``settings.sandbox_read_retries``; zero or a negative value
            disables retrying.
        base_delay_seconds: Forwarded to ``_backoff_delay_seconds``.
        max_delay_seconds: Forwarded to ``_backoff_delay_seconds``.

    Returns:
        The first non-transient result, or the last transient response once
        the budget is exhausted (it is returned, not raised).

    Raises:
        Exception: Whatever ``fn`` raised, when the error is not transient or
            the budget is exhausted.
    """
    if retries is None:
        retry_budget = settings.sandbox_read_retries
    else:
        retry_budget = retries

    attempt = 0
    while True:
        try:
            result = await fn()
        except Exception as error:
            attempt += 1
            out_of_budget = retry_budget <= 0 or attempt > retry_budget
            # Only classify the error while budget remains, as the original
            # short-circuit did.
            if out_of_budget or not is_transient_reset_error(error):
                raise
        else:
            if not is_transient_sandbox_read_response(result):
                return result
            attempt += 1
            out_of_budget = retry_budget <= 0 or attempt > retry_budget
            if out_of_budget:
                # Hand the transient response back so the caller can inspect it.
                return result

        pause = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds)
        if pause:
            await asyncio.sleep(pause)


def retry_on_transient_reset(
fn: Callable[[], T],
*,
Expand All @@ -139,3 +191,31 @@ def retry_on_transient_reset(
delay = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds)
if delay:
time.sleep(delay)


def retry_on_transient_sandbox_read(
    fn: Callable[[], T],
    *,
    retries: int | None = None,
    base_delay_seconds: float = DEFAULT_BASE_DELAY_SECONDS,
    max_delay_seconds: float = DEFAULT_MAX_DELAY_SECONDS,
) -> T:
    """Call ``fn`` repeatedly until it stops failing transiently or the budget runs out.

    An attempt is transient when ``fn`` raises an exception accepted by
    ``is_transient_reset_error`` or returns an object flagged by
    ``is_transient_sandbox_read_response``. Every failed attempt, of either
    kind, consumes one unit of the same retry budget.

    Args:
        fn: Zero-argument callable; invoked once per attempt.
        retries: Maximum number of retries. ``None`` reads
            ``settings.sandbox_read_retries``; zero or a negative value
            disables retrying.
        base_delay_seconds: Forwarded to ``_backoff_delay_seconds``.
        max_delay_seconds: Forwarded to ``_backoff_delay_seconds``.

    Returns:
        The first non-transient result, or the last transient response once
        the budget is exhausted (it is returned, not raised).

    Raises:
        Exception: Whatever ``fn`` raised, when the error is not transient or
            the budget is exhausted.
    """
    if retries is None:
        retry_budget = settings.sandbox_read_retries
    else:
        retry_budget = retries

    attempt = 0
    while True:
        try:
            result = fn()
        except Exception as error:
            attempt += 1
            out_of_budget = retry_budget <= 0 or attempt > retry_budget
            # Only classify the error while budget remains, as the original
            # short-circuit did.
            if out_of_budget or not is_transient_reset_error(error):
                raise
        else:
            if not is_transient_sandbox_read_response(result):
                return result
            attempt += 1
            out_of_budget = retry_budget <= 0 or attempt > retry_budget
            if out_of_budget:
                # Hand the transient response back so the caller can inspect it.
                return result

        pause = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds)
        if pause:
            time.sleep(pause)
Loading