Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ jobs:
- name: Check formatting (black)
working-directory: ladybug/tools/python_api
run: |
uv pip install black
uv pip install black==26.3.0
.venv/bin/black --check src_py test

- name: Run ruff check
Expand All @@ -140,7 +140,7 @@ jobs:

- name: Update submodules
working-directory: ladybug
run: git submodule update --init --recursive dataset
run: git submodule update --init --recursive dataset extension

- name: Checkout ladybug-python into ladybug/tools/python_api
uses: actions/checkout@v4
Expand Down Expand Up @@ -177,7 +177,7 @@ jobs:
- name: Check formatting (black)
working-directory: ladybug/tools/python_api
run: |
uv pip install black
uv pip install black==26.3.0
.venv/bin/black --check src_py test

- name: Run ruff check
Expand All @@ -191,6 +191,7 @@ jobs:
GEN: Ninja
CMAKE_C_COMPILER_LAUNCHER: ccache
CMAKE_CXX_COMPILER_LAUNCHER: ccache
EXTRA_CMAKE_FLAGS: -DBUILD_EXTENSIONS=json -DEXTENSION_STATIC_LINK_LIST=json
run: |
make python
cp tools/python_api/src_py/*.py tools/python_api/build/ladybug/
Expand Down
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Pre-commit hooks for the Python sources: formatting (black) and linting (ruff).
# NOTE: keep the black rev in sync with the `uv pip install black==26.3.0`
# pin in .github/workflows/ci.yml so local hooks and CI agree.
repos:
  - repo: https://github.com/psf/black
    rev: 26.3.0
    hooks:
      - id: black

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.12
    hooks:
      - id: ruff-check
12 changes: 11 additions & 1 deletion src_cpp/include/cached_import/py_cached_modules.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ class PolarsCachedItem : public PythonCachedItem {
};

class PyarrowCachedItem : public PythonCachedItem {
// Cached lookup of the `pyarrow.lib.Array` attribute and its
// `_import_from_c` member, so repeated access avoids going through the
// Python attribute machinery each time (see py_query_result.cpp, which
// calls importCache->pyarrow.lib.Array._import_from_c()).
class ArrayCachedItem : public PythonCachedItem {
public:
    explicit ArrayCachedItem(PythonCachedItem* parent)
        : PythonCachedItem("Array", parent), _import_from_c("_import_from_c", this) {}

    // pyarrow.lib.Array._import_from_c — adopts Arrow C-Data-Interface structs.
    PythonCachedItem _import_from_c;
};

class RecordBatchCachedItem : public PythonCachedItem {
public:
explicit RecordBatchCachedItem(PythonCachedItem* parent)
Expand Down Expand Up @@ -132,8 +140,10 @@ class PyarrowCachedItem : public PythonCachedItem {
class LibCachedItem : public PythonCachedItem {
public:
explicit LibCachedItem(PythonCachedItem* parent)
: PythonCachedItem("lib", parent), RecordBatch(this), Schema(this), Table(this) {}
: PythonCachedItem("lib", parent), Array(this), RecordBatch(this), Schema(this),
Table(this) {}

ArrayCachedItem Array;
RecordBatchCachedItem RecordBatch;
SchemaCachedItem Schema;
TableCachedItem Table;
Expand Down
2 changes: 2 additions & 0 deletions src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class PyConnection {
const py::dict& params);

std::unique_ptr<PyQueryResult> query(const std::string& statement);
std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement,
int64_t chunkSize);

void setMaxNumThreadForExec(uint64_t numThreads);

Expand Down
1 change: 1 addition & 0 deletions src_cpp/include/py_query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class PyQueryResult {
py::object getAsDF();

lbug::pyarrow::Table getAsArrow(std::int64_t chunkSize, bool fallbackExtensionTypes);
py::dict getCSR();

py::list getColumnDataTypes();

Expand Down
10 changes: 10 additions & 0 deletions src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ void PyConnection::initialize(py::handle& m) {
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::dict())
.def("query", &PyConnection::query, py::arg("statement"))
.def("query_as_arrow", &PyConnection::queryAsArrow, py::arg("statement"),
py::arg("chunk_size"))
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("prepare", &PyConnection::prepare, py::arg("query"),
Expand Down Expand Up @@ -175,6 +177,14 @@ std::unique_ptr<PyQueryResult> PyConnection::query(const std::string& statement)
return checkAndWrapQueryResult(queryResult);
}

// Execute `statement` through the engine's native Arrow collector path.
// `chunkSize` is forwarded to Connection::queryAsArrow (presumably rows per
// Arrow chunk — confirm against the engine's queryAsArrow contract).
std::unique_ptr<PyQueryResult> PyConnection::queryAsArrow(const std::string& statement,
    int64_t chunkSize) {
    // Drop the GIL while the (potentially long-running) native query executes
    // so other Python threads can run; re-acquire it before wrapping the
    // result, since checkAndWrapQueryResult may touch Python state.
    py::gil_scoped_release release;
    auto queryResult = conn->queryAsArrow(statement, chunkSize);
    py::gil_scoped_acquire acquire;
    return checkAndWrapQueryResult(queryResult);
}

// Thin forwarder: apply the execution thread limit on the native connection.
void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
    conn->setMaxNumThreadForExec(numThreads);
}
Expand Down
50 changes: 50 additions & 0 deletions src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
#include "common/arrow/arrow_row_batch.h"
#include "common/constants.h"
#include "common/exception/not_implemented.h"
#include "common/exception/runtime.h"
#include "common/types/uuid.h"
#include "common/types/value/nested.h"
#include "common/types/value/node.h"
#include "common/types/value/rel.h"
#include "datetime.h" // python lib
#include "include/py_query_result_converter.h"
#include "main/query_result/arrow_query_result.h"

using namespace lbug::common;
using lbug::importCache;
Expand All @@ -30,6 +32,7 @@ void PyQueryResult::initialize(py::handle& m) {
.def("close", &PyQueryResult::close)
.def("getAsDF", &PyQueryResult::getAsDF)
.def("getAsArrow", &PyQueryResult::getAsArrow)
.def("getCSR", &PyQueryResult::getCSR)
.def("getColumnNames", &PyQueryResult::getColumnNames)
.def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes)
.def("resetIterator", &PyQueryResult::resetIterator)
Expand Down Expand Up @@ -85,6 +88,27 @@ void PyQueryResult::close() {
}
}

namespace {

// Wrap one C-Data-Interface (ArrowArray, ArrowSchema) pair as a pyarrow.Array.
// pyarrow's Array._import_from_c takes the struct *addresses* as Python
// integers; per the Arrow C Data Interface, importing takes ownership of the
// structs' contents, so `array` must not be consumed again afterwards.
py::object importCSRArrowArray(lbug::main::ArrowQueryResult::CSRArrowArray& array) {
    auto arrayImportFunc = importCache->pyarrow.lib.Array._import_from_c();
    return arrayImportFunc((std::uint64_t)&array.array, (std::uint64_t)&array.schema);
}

// Convert the engine-provided CSR arrays into the dict shape consumed by the
// Python side (query_result.py::ArrowQueryResult.csr): keys "indptr",
// "indices", and "edge_ids" (py::none() when edge IDs were not collected).
// `arrays` is taken by value: the imports above move the Arrow buffers out of
// it, so the local copy going out of scope afterwards is safe.
py::dict buildCSRResult(lbug::main::ArrowQueryResult::CSRArrowArrays arrays) {
    py::dict result;
    result["indptr"] = importCSRArrowArray(arrays.indptr);
    result["indices"] = importCSRArrowArray(arrays.indices);
    if (arrays.edgeIDs.has_value()) {
        result["edge_ids"] = importCSRArrowArray(*arrays.edgeIDs);
    } else {
        result["edge_ids"] = py::none();
    }
    return result;
}

} // namespace

static py::object converTimestampToPyObject(timestamp_t& timestamp) {
int32_t year = 0, month = 0, day = 0, hour = 0, min = 0, sec = 0, micros = 0;
date_t date;
Expand Down Expand Up @@ -320,6 +344,23 @@ py::object PyQueryResult::getArrowChunks(const std::vector<LogicalType>& types,

lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
bool fallbackExtensionTypes) {
if (queryResult->getType() == QueryResultType::ARROW) {
auto types = queryResult->getColumnDataTypes();
auto names = queryResult->getColumnNames();
py::list batches;
auto batchImportFunc = importCache->pyarrow.lib.RecordBatch._import_from_c();
while (queryResult->hasNextArrowChunk()) {
auto data = queryResult->getNextArrowChunk(chunkSize);
auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
batches.append(
batchImportFunc((std::uint64_t)data.get(), (std::uint64_t)schema.get()));
}
auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches();
auto schemaImportFunc = importCache->pyarrow.lib.Schema._import_from_c();
auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());
return py::cast<lbug::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
}
auto types = queryResult->getColumnDataTypes();
auto names = queryResult->getColumnNames();
py::list batches = getArrowChunks(types, names, chunkSize, fallbackExtensionTypes);
Expand All @@ -330,6 +371,15 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
return py::cast<lbug::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
}

// Export native CSR arrays ("indptr", "indices", "edge_ids") as a Python dict.
// Only available when the underlying result is the Arrow-native type
// (produced via Connection.query_as_arrow) AND the engine attached CSR
// metadata to it; any other result type raises.
py::dict PyQueryResult::getCSR() {
    if (auto* arrowQueryResult = dynamic_cast<lbug::main::ArrowQueryResult*>(queryResult);
        arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata()) {
        return buildCSRResult(arrowQueryResult->getCSRArrowArrays());
    }
    throw RuntimeException(
        "CSR export is only supported for Arrow query results with native CSR metadata.");
}

py::list PyQueryResult::getColumnDataTypes() {
auto columnDataTypes = queryResult->getColumnDataTypes();
py::tuple result(columnDataTypes.size());
Expand Down
4 changes: 3 additions & 1 deletion src_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from .connection import Connection # noqa: E402
from .database import Database # noqa: E402
from .prepared_statement import PreparedStatement # noqa: E402
from .query_result import QueryResult # noqa: E402
from .query_result import ArrowQueryResult, CSRResult, QueryResult # noqa: E402
from .types import Type # noqa: E402

_VERSION_INFO: tuple[str, int] | None = None
Expand All @@ -80,7 +80,9 @@ def __getattr__(name: str) -> str | int:

__all__ = [
"AsyncConnection",
"ArrowQueryResult",
"Connection",
"CSRResult",
"Database",
"PreparedStatement",
"QueryResult",
Expand Down
3 changes: 3 additions & 0 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,9 @@ def getAsArrow(self, *_args: Any, **_kwargs: Any) -> Any:
"Arrow export is not yet implemented in C-API backend"
)

def getCSR(self, *_args: Any, **_kwargs: Any) -> Any:
    # CSR export relies on the pybind backend's Arrow-native result path;
    # the C-API backend has no equivalent yet, so fail loudly instead of
    # returning wrong data. Mirrors the getAsArrow stub above it in style.
    raise NotImplementedError("CSR export is not yet implemented in C-API backend")

def getAsDF(self) -> Any:
raise NotImplementedError(
"DataFrame export is not yet implemented in C-API backend"
Expand Down
23 changes: 22 additions & 1 deletion src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ._backend import get_capi_module, get_pybind_module
from .prepared_statement import PreparedStatement
from .query_result import QueryResult
from .query_result import ArrowQueryResult, QueryResult

if TYPE_CHECKING:
import sys
Expand Down Expand Up @@ -369,6 +369,27 @@ def execute(
all_query_results.append(next_query_result)
return all_query_results

def query_as_arrow(self, query: str, chunk_size: int) -> ArrowQueryResult:
    """
    Execute a query with the native Arrow collector path.

    This is the efficient path for CSR-aware Arrow export.

    Args:
        query: statement to execute.
        chunk_size: forwarded to the native Arrow collector and recorded on
            the returned result as its native chunk size (presumably rows per
            Arrow chunk -- confirm against the pybind queryAsArrow binding).

    Returns:
        An ArrowQueryResult wrapping the successful native result.

    Raises:
        NotImplementedError: if the active backend is not pybind (the C-API
            backend has no Arrow collector).
        RuntimeError: if the native query reports failure.
    """
    self.init_connection()
    if not self._using_pybind_backend():
        msg = "query_as_arrow requires the pybind backend"
        raise NotImplementedError(msg)
    query_result_internal = self._get_pybind_connection().query_as_arrow(
        query, chunk_size
    )
    if not query_result_internal.isSuccess():
        raise RuntimeError(query_result_internal.getErrorMessage())
    # Remember the execution-time chunk size so ArrowQueryResult.get_as_arrow
    # can reuse it when callers don't override it.
    current_query_result = ArrowQueryResult(
        self, query_result_internal, native_chunk_size=chunk_size
    )
    # Track the result so the connection can close it when it is torn down.
    self._register_query_result(current_query_result)
    return current_query_result

def _prepare(
self,
query: str,
Expand Down
52 changes: 52 additions & 0 deletions src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from .constants import DST, ID, LABEL, NODES, RELS, SRC
Expand Down Expand Up @@ -525,6 +526,57 @@ def rows_as_dict(self, state=True) -> Self:
return self


class ArrowQueryResult(QueryResult):
    """QueryResult produced by the native Arrow collector path."""

    def __init__(
        self, connection: Any, query_result: Any, native_chunk_size: int
    ) -> None:
        super().__init__(connection, query_result)
        # Chunk size chosen at execution time by Connection.query_as_arrow;
        # reused whenever callers ask for the "default" chunking below.
        self._native_chunk_size = native_chunk_size

    def get_as_arrow(
        self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False
    ) -> pa.Table:
        """
        Get the query result as a PyArrow Table.

        Arrow-native results keep the chunking chosen at execution time by
        `Connection.query_as_arrow(...)`; passing `None` or any non-positive
        chunk size reuses that native chunk size rather than rechunking.
        """
        wants_native = chunk_size is None or chunk_size <= 0
        effective_size = self._native_chunk_size if wants_native else chunk_size
        return super().get_as_arrow(
            effective_size, fallbackExtensionTypes=fallbackExtensionTypes
        )

    def csr(self) -> CSRResult:
        """
        Get native CSR arrays from an Arrow query result.

        Only Arrow results carrying CSR metadata support this — typically
        those returned by `Connection.query_as_arrow(...)` over
        relationship-shaped projections.
        """
        self.check_for_query_result_close()

        arrays = self._query_result.getCSR()
        indptr = arrays["indptr"]
        indices = arrays["indices"]
        edge_ids = arrays["edge_ids"]
        return CSRResult(indptr=indptr, indices=indices, edge_ids=edge_ids)


@dataclass(frozen=True)
class CSRResult:
    """Native CSR arrays returned by an Arrow query result."""

    # Row-pointer array: entries indptr[i]..indptr[i+1] delimit row i's span
    # in `indices`/`edge_ids` (standard CSR layout — confirm exact node/offset
    # semantics against the engine's CSR metadata docs).
    indptr: pa.Array
    # Column indices of the stored entries, aligned with the indptr spans.
    indices: pa.Array
    # Optional per-edge identifiers aligned with `indices`; None when the
    # result was produced without relationship IDs (see buildCSRResult in
    # py_query_result.cpp, which emits py::none() in that case).
    edge_ids: pa.Array | None = None


def _row_to_dict(columns: list[str], row: list[Any]) -> dict[str, Any]:
if len(columns) != len(row):
msg = "Number of columns in output row does not match number of columns"
Expand Down
6 changes: 6 additions & 0 deletions test/capi_xfails.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
"test/test_arrow.py::test_to_arrow_map",
"test/test_arrow.py::test_to_arrow_array",
"test/test_arrow.py::test_to_arrow_complex",
"test/test_arrow.py::test_query_as_arrow_csr_with_rel_ids",
"test/test_arrow.py::test_query_as_arrow_csr_with_extra_columns",
"test/test_arrow.py::test_query_as_arrow_csr_without_rel_ids",
"test/test_arrow.py::test_query_as_arrow_csr_rejects_non_csr_shape",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_basic",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_filtering",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pandas",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pyarrow",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_empty_result",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_count",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_arrow_node_and_rel_table",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_native_node_and_arrow_rel_table",
"test/test_async_connection.py::test_async_scan_df",
"test/test_blob_parameter.py::test_bytes_param_udf",
"test/test_df.py::test_to_df",
Expand Down
Loading
Loading