Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,16 @@ result = await async_client.files.download(
)
```

For large async downloads, use `stream_download()` to process bytes as they arrive without buffering the full response in memory:

```python
async with async_client.files.stream_download(
url=await async_client.my_files_home() / "relative_folder/my-file.txt"
) as result:
async for bytes_chunk in result:
...
```

As a result, you will receive an object of type `FileDownloadResponse`, that you can iterate by byte chunks:

```python
Expand Down
42 changes: 31 additions & 11 deletions aidial_client/_http_client/_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from collections.abc import AsyncIterator, Callable, Mapping
from collections.abc import AsyncIterator, Mapping
from contextlib import asynccontextmanager, suppress
from http import HTTPStatus
from typing import Any
Expand All @@ -8,7 +8,7 @@

from aidial_client._auth import AsyncAuthValue, aget_combined_auth_headers
from aidial_client._exception import DialException
from aidial_client._http_client._base import BaseHTTPClient
from aidial_client._http_client._base import BaseHTTPClient, ErrorHandler
from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven
from aidial_client._internal_types._generic import ResponseT
from aidial_client._internal_types._http_request import FinalRequestOptions
Expand Down Expand Up @@ -51,8 +51,7 @@ async def request(
options: FinalRequestOptions,
cast_to: type[ResponseT],
remaining_retries: int | None = None,
on_http_error: Callable[[httpx.HTTPStatusError], DialException | None]
| None = None,
on_http_error: ErrorHandler | None = None,
) -> ResponseT:
retries = self._remaining_retries(remaining_retries, options)
auth_headers = await self.auth_headers()
Expand Down Expand Up @@ -101,16 +100,37 @@ async def request(
cast_to=cast_to,
remaining_retries=retries,
)
# Try to get a custom error from response status_code/code/message
custom_error = on_http_error(err) if on_http_error else None
# or fallback to default processing
raised_error = custom_error or self._make_dial_error_from_response(
err.response
)
raise raised_error from err
self._raise_for_status(response, on_http_error)

return process_block_response(cast_to=cast_to, response=response)

@asynccontextmanager
Comment thread
adubovik marked this conversation as resolved.
async def stream(
self,
*,
options: FinalRequestOptions,
on_http_error: ErrorHandler | None = None,
) -> AsyncIterator[httpx.Response]:
auth_headers = await self.auth_headers()
request = self._build_request(options, auth_headers)
try:
response = await self._internal_http_client.send(
request, stream=True
)
except httpx.TimeoutException as err:
raise DialException(
message="Request timed out",
status_code=HTTPStatus.REQUEST_TIMEOUT,
) from err
except httpx.HTTPError as err:
raise DialException(message=f"Request failed: {err}") from err

try:
self._raise_for_status(response, on_http_error)
yield response
finally:
await response.aclose()

@asynccontextmanager
async def stream_sse(
self,
Expand Down
19 changes: 19 additions & 0 deletions aidial_client/_http_client/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from http import HTTPStatus
from random import uniform
from typing import Generic, TypeVar
Expand All @@ -16,6 +17,8 @@
"_HttpInternalClientT", bound=httpx.Client | httpx.AsyncClient
)

ErrorHandler = Callable[[httpx.HTTPStatusError], DialException | None]


class BaseHTTPClient(ABC, Generic[_HttpInternalClientT, AuthValueT]):
_internal_http_client: _HttpInternalClientT
Expand Down Expand Up @@ -106,6 +109,22 @@ def _calculate_retry_sleep_seconds(
timeout = sleep_seconds + uniform(-0.5, 0.5) # noqa: S311
return max(0, timeout)

def _raise_for_status(
self,
response: httpx.Response,
on_http_error: ErrorHandler | None,
) -> None:
try:
response.raise_for_status()
except httpx.HTTPStatusError as err:
# Try to get a custom error from response status_code/code/message
custom_error = on_http_error(err) if on_http_error else None
# or fallback to default processing
raised_error = custom_error or self._make_dial_error_from_response(
err.response
)
raise raised_error from err

def _make_dial_error_from_response(
self,
response: httpx.Response,
Expand Down
15 changes: 4 additions & 11 deletions aidial_client/_http_client/_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from collections.abc import Callable, Iterator, Mapping
from collections.abc import Iterator, Mapping
from contextlib import contextmanager, suppress
from http import HTTPStatus
from typing import Any
Expand All @@ -8,7 +8,7 @@

from aidial_client._auth import SyncAuthValue, get_combined_auth_headers
from aidial_client._exception import DialException
from aidial_client._http_client._base import BaseHTTPClient
from aidial_client._http_client._base import BaseHTTPClient, ErrorHandler
from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven
from aidial_client._internal_types._generic import ResponseT
from aidial_client._internal_types._http_request import FinalRequestOptions
Expand Down Expand Up @@ -50,8 +50,7 @@ def request(
cast_to: type[ResponseT],
options: FinalRequestOptions,
remaining_retries: int | None = None,
on_http_error: Callable[[httpx.HTTPStatusError], DialException | None]
| None = None,
on_http_error: ErrorHandler | None = None,
) -> ResponseT:
retries = self._remaining_retries(remaining_retries, options)
auth_headers = self.auth_headers()
Expand Down Expand Up @@ -101,13 +100,7 @@ def request(
cast_to=cast_to,
remaining_retries=retries,
)
# Try to get a custom error from response status_code/code/message
custom_error = on_http_error(err) if on_http_error else None
# or fallback to default processing
raised_error = custom_error or self._make_dial_error_from_response(
err.response
)
raise raised_error from err
self._raise_for_status(response, on_http_error)

return process_block_response(cast_to=cast_to, response=response)

Expand Down
24 changes: 24 additions & 0 deletions aidial_client/helpers/storage_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from aidial_client._compatibility.pydantic_v1 import BaseModel
from aidial_client._constants import API_PREFIX
from aidial_client._exception import InvalidDialURLError, NotDialURLError
from aidial_client._internal_types._http_request import FinalRequestOptions
from aidial_client._utils._dict import remove_none
from aidial_client.helpers._url import enforce_trailing_slash

StorageResourceType = Literal["files", "conversations", "prompts"]
Expand Down Expand Up @@ -156,3 +158,25 @@ def get_display_name(self, url: str) -> str:
Get the display name of the resource from the URL
"""
return self.get_storage_resource(url).bucket_path

def _prepare_download_request(
self,
url: str | PurePosixPath,
etag_if_match: str | None,
) -> tuple[FinalRequestOptions, str]:
storage_resource = self.get_storage_resource(str(url))

if storage_resource.filename is None:
raise InvalidDialURLError("URL points to a directory, not a file")

options = FinalRequestOptions(
method="GET",
url=urljoin(API_PREFIX, storage_resource.api_path),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
)

return options, storage_resource.filename
52 changes: 21 additions & 31 deletions aidial_client/resources/files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import PurePosixPath
from typing import Literal
from urllib.parse import urljoin
Expand All @@ -8,7 +10,6 @@
from aidial_client._exception import (
DialException,
EtagMismatchError,
InvalidDialURLError,
ResourceNotFoundError,
)
from aidial_client._internal_types._generic import NoneType
Expand Down Expand Up @@ -70,25 +71,13 @@ def download(
url: str | PurePosixPath,
etag_if_match: str | None = None,
) -> FileDownloadResponse:
storage_resource = self.get_storage_resource(str(url))
if storage_resource.filename is None:
raise InvalidDialURLError("URL points to a directory, not a file")
options, filename = self._prepare_download_request(url, etag_if_match)
response = self.http_client.request(
cast_to=httpx.Response,
options=FinalRequestOptions(
method="GET",
url=urljoin(API_PREFIX, storage_resource.api_path),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
options=options,
on_http_error=_files_error_processor,
)
return FileDownloadResponse(
response=response, filename=storage_resource.filename
)
return FileDownloadResponse(response=response, filename=filename)

def delete(
self,
Expand Down Expand Up @@ -188,25 +177,26 @@ async def download(
url: str | PurePosixPath,
etag_if_match: str | None = None,
) -> FileDownloadResponse:
storage_resource = self.get_storage_resource(str(url))
if storage_resource.filename is None:
raise InvalidDialURLError("URL points to a directory, not a file")
options, filename = self._prepare_download_request(url, etag_if_match)
response = await self.http_client.request(
cast_to=httpx.Response,
options=FinalRequestOptions(
method="GET",
url=urljoin(API_PREFIX, storage_resource.api_path),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
options=options,
on_http_error=_files_error_processor,
)
return FileDownloadResponse(
response=response, filename=storage_resource.filename
)
return FileDownloadResponse(response=response, filename=filename)

@asynccontextmanager
async def stream_download(
self,
url: str | PurePosixPath,
etag_if_match: str | None = None,
) -> AsyncIterator[FileDownloadResponse]:
options, filename = self._prepare_download_request(url, etag_if_match)
async with self.http_client.stream(
options=options,
on_http_error=_files_error_processor,
) as response:
yield FileDownloadResponse(response=response, filename=filename)

async def delete(
self,
Expand Down
63 changes: 63 additions & 0 deletions tests/resources/files/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Any, cast
from unittest.mock import AsyncMock

import httpx
import pytest

from aidial_client._client import AsyncDial
from aidial_client._exception import InvalidDialURLError
from tests.client_mock import MockStreamIterator


@pytest.mark.asyncio
async def test_stream_download_async_streams_and_closes_response():
captured_requests: list[httpx.Request] = []
captured_kwargs: list[dict[str, Any]] = []
captured_responses: list[httpx.Response] = []
client = AsyncDial(api_key="dummy", base_url="http://dial.core")
client._get_my_bucket = cast(Any, AsyncMock(return_value="test-bucket"))

async def send_mock(
request: httpx.Request, *, stream: bool = False, **kwargs: Any
) -> httpx.Response:
captured_requests.append(request)
captured_kwargs.append({"stream": stream, **kwargs})
response = httpx.Response(
status_code=200,
request=request,
stream=MockStreamIterator(mock_chunks=[b"hello ", b"world"]),
)
captured_responses.append(response)
return response

client._http_client._internal_http_client.send = cast(Any, send_mock)

async with client.files.stream_download(
url=await client.my_files_home() / "folder/file.txt"
) as response:
assert response.filename == "file.txt"
assert b"".join([chunk async for chunk in response]) == b"hello world"

assert (
captured_requests[0].url.path == "/v1/files/test-bucket/folder/file.txt"
)
assert captured_kwargs == [{"stream": True}]
assert captured_responses[0].is_closed is True


@pytest.mark.asyncio
async def test_stream_download_async_rejects_directory_url():
client = AsyncDial(api_key="dummy", base_url="http://dial.core")
client._get_my_bucket = cast(Any, AsyncMock(return_value="test-bucket"))
send_mock = AsyncMock()
client._http_client._internal_http_client.send = cast(Any, send_mock)

with pytest.raises(
InvalidDialURLError, match="URL points to a directory, not a file"
):
async with client.files.stream_download(
url="files/test-bucket/folder/"
):
pass

send_mock.assert_not_called()
Loading