From 72f89f7376b5e5acb65cceda05be3fae3d852383 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Fri, 8 May 2026 10:39:57 -0700 Subject: [PATCH 1/3] Use C API backend for former xfail coverage Wire the ctypes backend to the expanded C API for Arrow, DataFrame, CSR, JSON parameters, version, timeout, and nested parameter behavior. --- src_py/_lbug_capi.py | 271 +++++++++++++++++++++++++++++++++++++++++-- src_py/connection.py | 30 ++++- src_py/database.py | 6 +- test/capi_xfails.py | 35 ++---- 4 files changed, 294 insertions(+), 48 deletions(-) diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py index 423dd04..d9f1af4 100644 --- a/src_py/_lbug_capi.py +++ b/src_py/_lbug_capi.py @@ -6,7 +6,9 @@ import datetime as dt import os import sys +import threading import uuid +from dataclasses import dataclass from decimal import Decimal from pathlib import Path from typing import Any @@ -97,6 +99,46 @@ class _LbugInt128(ctypes.Structure): _fields_ = [("low", ctypes.c_uint64), ("high", ctypes.c_int64)] +@dataclass(frozen=True) +class CAPIJsonParameter: + value: str + + +class _ArrowSchema(ctypes.Structure): + pass + + +_ArrowSchema._fields_ = [ + ("format", ctypes.c_char_p), + ("name", ctypes.c_char_p), + ("metadata", ctypes.c_char_p), + ("flags", ctypes.c_int64), + ("n_children", ctypes.c_int64), + ("children", ctypes.POINTER(ctypes.POINTER(_ArrowSchema))), + ("dictionary", ctypes.POINTER(_ArrowSchema)), + ("release", ctypes.c_void_p), + ("private_data", ctypes.c_void_p), +] + + +class _ArrowArray(ctypes.Structure): + pass + + +_ArrowArray._fields_ = [ + ("length", ctypes.c_int64), + ("null_count", ctypes.c_int64), + ("offset", ctypes.c_int64), + ("n_buffers", ctypes.c_int64), + ("n_children", ctypes.c_int64), + ("buffers", ctypes.POINTER(ctypes.c_void_p)), + ("children", ctypes.POINTER(ctypes.POINTER(_ArrowArray))), + ("dictionary", ctypes.POINTER(_ArrowArray)), + ("release", ctypes.c_void_p), + ("private_data", ctypes.c_void_p), +] + + def _resolve_library_path() -> str: override = os.getenv("LBUG_C_API_LIB_PATH") if override: @@ -293,12 +335,20 @@ def _setup_signatures() -> None: _LIB.lbug_value_create_null.restype = ctypes.POINTER(_LbugValue) _LIB.lbug_value_create_bool.argtypes = [ctypes.c_bool] _LIB.lbug_value_create_bool.restype = ctypes.POINTER(_LbugValue) + _LIB.lbug_value_create_int8.argtypes = [ctypes.c_int8] + _LIB.lbug_value_create_int8.restype = ctypes.POINTER(_LbugValue) + _LIB.lbug_value_create_int16.argtypes = [ctypes.c_int16] + _LIB.lbug_value_create_int16.restype = ctypes.POINTER(_LbugValue) + _LIB.lbug_value_create_int32.argtypes = [ctypes.c_int32] + _LIB.lbug_value_create_int32.restype = ctypes.POINTER(_LbugValue) _LIB.lbug_value_create_int64.argtypes = [ctypes.c_int64] _LIB.lbug_value_create_int64.restype = ctypes.POINTER(_LbugValue) _LIB.lbug_value_create_double.argtypes = [ctypes.c_double] _LIB.lbug_value_create_double.restype = ctypes.POINTER(_LbugValue) _LIB.lbug_value_create_string.argtypes = [ctypes.c_char_p] _LIB.lbug_value_create_string.restype = ctypes.POINTER(_LbugValue) + _LIB.lbug_value_create_json.argtypes = [ctypes.c_char_p] + _LIB.lbug_value_create_json.restype = ctypes.POINTER(_LbugValue) _LIB.lbug_value_create_uuid.argtypes = [ctypes.c_char_p] _LIB.lbug_value_create_uuid.restype = ctypes.POINTER(_LbugValue) _LIB.lbug_value_create_date.argtypes = [_LbugDate] @@ -371,6 +421,17 @@ def _setup_signatures() -> None: ] _LIB.lbug_query_result_get_next_query_result.restype = ctypes.c_int _LIB.lbug_query_result_reset_iterator.argtypes = 
[ctypes.POINTER(_LbugQueryResult)] + _LIB.lbug_query_result_get_arrow_schema.argtypes = [ + ctypes.POINTER(_LbugQueryResult), + ctypes.POINTER(_ArrowSchema), + ] + _LIB.lbug_query_result_get_arrow_schema.restype = ctypes.c_int + _LIB.lbug_query_result_get_next_arrow_chunk.argtypes = [ + ctypes.POINTER(_LbugQueryResult), + ctypes.c_int64, + ctypes.POINTER(_ArrowArray), + ] + _LIB.lbug_query_result_get_next_arrow_chunk.restype = ctypes.c_int _LIB.lbug_query_result_get_query_summary.argtypes = [ ctypes.POINTER(_LbugQueryResult), ctypes.POINTER(_LbugQuerySummary), @@ -806,6 +867,11 @@ def _parse_rendered_value(value: str) -> Any: except (ValueError, SyntaxError): return value + if candidate.lower() == "true": + return True + if candidate.lower() == "false": + return False + # Parse plain numeric textual values. try: if "." in candidate or "e" in candidate.lower(): @@ -818,9 +884,17 @@ def _parse_rendered_value(value: str) -> Any: def _value_from_python(value: Any) -> ctypes.POINTER(_LbugValue): if value is None: return _LIB.lbug_value_create_null() + if isinstance(value, CAPIJsonParameter): + return _LIB.lbug_value_create_json(value.value.encode()) if isinstance(value, bool): return _LIB.lbug_value_create_bool(value) if isinstance(value, int) and not isinstance(value, bool): + if -(1 << 7) <= value <= (1 << 7) - 1: + return _LIB.lbug_value_create_int8(value) + if -(1 << 15) <= value <= (1 << 15) - 1: + return _LIB.lbug_value_create_int16(value) + if -(1 << 31) <= value <= (1 << 31) - 1: + return _LIB.lbug_value_create_int32(value) return _LIB.lbug_value_create_int64(value) if isinstance(value, float): return _LIB.lbug_value_create_double(value) @@ -1224,18 +1298,169 @@ def getExecutionTime(self) -> float: finally: _LIB.lbug_query_summary_destroy(ctypes.byref(summary)) - def getAsArrow(self, *_args: Any, **_kwargs: Any) -> Any: - raise NotImplementedError( - "Arrow export is not yet implemented in C-API backend" + def getAsArrow(self, *args: Any, **_kwargs: Any) -> Any: + import pyarrow as pa + + chunk_size = int(args[0]) if args else 0 + fallback_extension_types = bool(args[1]) if len(args) > 1 else False + num_tuples = int(self.getNumTuples()) + if chunk_size <= 0: + chunk_size = max(num_tuples, 1) + + if "MAP" in self.getColumnDataTypes(): + rows = self._get_all_rows_from_start() + for row in rows: + for value in row: + if isinstance(value, dict) and any(k is None for k in value): + rendered = ", ".join( + f"{'' if k is None else k}={v}" for k, v in value.items() + ) + msg = ( + f"Cannot convert map with null key to Arrow: {{{rendered}}}" + ) + raise RuntimeError(msg) + + schema_ptr = _ArrowSchema() + _check_state( + _LIB.lbug_query_result_get_arrow_schema( + ctypes.byref(self._result), ctypes.byref(schema_ptr) + ), + "Failed to export Arrow schema", ) + schema = pa.Schema._import_from_c(ctypes.addressof(schema_ptr)) + + self.resetIterator() + batches = [] + try: + while self.hasNext(): + array_ptr = _ArrowArray() + _check_state( + _LIB.lbug_query_result_get_next_arrow_chunk( + ctypes.byref(self._result), + chunk_size, + ctypes.byref(array_ptr), + ), + "Failed to export Arrow chunk", + ) + batches.append( + pa.RecordBatch._import_from_c(ctypes.addressof(array_ptr), schema) + ) + if not batches: + return pa.Table.from_batches([], schema=schema) + table = pa.Table.from_batches(batches, schema=schema) + if fallback_extension_types: + for idx, field in enumerate(table.schema): + if str(field.type) == "extension": + values = [ + None if value is None else str(value) + for value in 
table.column(idx).to_pylist() + ] + table = table.set_column( + idx, field.name, pa.array(values, type=pa.string()) + ) + return table + finally: + self.resetIterator() def getCSR(self, *_args: Any, **_kwargs: Any) -> Any: - raise NotImplementedError("CSR export is not yet implemented in C-API backend") + import pyarrow as pa + + column_names = self.getColumnNames() + rows = self._get_all_rows_from_start() + if len(column_names) == 2 and all( + name.endswith(".rowid") for name in column_names + ): + has_edge_ids = False + src_idx, edge_idx, dst_idx = 0, None, 1 + elif len(column_names) >= 3 and all( + name.endswith(".rowid") for name in column_names[:3] + ): + has_edge_ids = True + src_idx, edge_idx, dst_idx = 0, 1, 2 + else: + msg = "CSR export is only supported for rowid projections" + raise RuntimeError(msg) + + max_src = max((int(row[src_idx]) for row in rows), default=-1) + grouped: list[list[tuple[int | None, int]]] = [[] for _ in range(max_src + 1)] + for row in rows: + src = int(row[src_idx]) + edge = int(row[edge_idx]) if edge_idx is not None else None + dst = int(row[dst_idx]) + grouped[src].append((edge, dst)) + + indptr = [0] + indices: list[int] = [] + edge_ids: list[int] = [] + for entries in grouped: + for edge, dst in entries: + indices.append(dst) + if edge is not None: + edge_ids.append(edge) + indptr.append(len(indices)) + + return { + "indptr": pa.array(indptr, type=pa.int64()), + "indices": pa.array(indices, type=pa.int64()), + "edge_ids": pa.array(edge_ids, type=pa.int64()) if has_edge_ids else None, + } def getAsDF(self) -> Any: - raise NotImplementedError( - "DataFrame export is not yet implemented in C-API backend" + import pandas as pd + + df = pd.DataFrame( + self._get_all_rows_from_start(), columns=self.getColumnNames() ) + for name, dtype in zip( + self.getColumnNames(), self.getColumnDataTypes(), strict=False + ): + if name not in df: + continue + if dtype == "BOOL": + df[name] = df[name].astype("bool") + elif dtype in {"INT8", "INT16", "INT32", "INT64", "SERIAL"}: + df[name] = df[name].astype( + { + "INT8": "int8", + "INT16": "int16", + "INT32": "int32", + "INT64": "int64", + "SERIAL": "int64", + }[dtype] + ) + elif dtype in {"UINT8", "UINT16", "UINT32", "UINT64"}: + df[name] = df[name].astype( + { + "UINT8": "uint8", + "UINT16": "uint16", + "UINT32": "uint32", + "UINT64": "uint64", + }[dtype] + ) + elif dtype == "FLOAT": + df[name] = df[name].astype("float32") + elif dtype == "DOUBLE": + df[name] = df[name].astype("float64") + elif dtype == "DATE" or dtype.startswith("TIMESTAMP"): + datetime_col = pd.to_datetime(df[name]) + if getattr(datetime_col.dt, "tz", None) is not None: + datetime_col = datetime_col.dt.tz_convert("UTC").dt.tz_localize( + None + ) + df[name] = datetime_col.astype("datetime64[us]") + elif dtype == "INTERVAL": + df[name] = pd.to_timedelta(df[name]) + elif dtype == "INT128": + df[name] = df[name].astype("float64") + return df + + def _get_all_rows_from_start(self) -> list[list[Any]]: + self.resetIterator() + rows = [] + while self.hasNext(): + rows.append(self.getNext()) + self.resetIterator() + return rows def _convert_value(self, value: _LbugValue) -> Any: if _LIB.lbug_value_is_null(ctypes.byref(value)): @@ -1764,6 +1989,7 @@ def _convert_value(self, value: _LbugValue) -> Any: class Connection: def __init__(self, database: Database, num_threads: int = 0): self._connection = _LbugConnection() + self._query_timeout_ms = 0 _check_state( _LIB.lbug_connection_init( ctypes.byref(database._database), ctypes.byref(self._connection) @@ -1795,14 
+2021,33 @@ def set_query_timeout(self, timeout_in_ms: int) -> None: ), "Failed to set query timeout", ) + self._query_timeout_ms = int(timeout_in_ms) def interrupt(self) -> None: _LIB.lbug_connection_interrupt(ctypes.byref(self._connection)) + def _call_with_timeout(self, callback: Any) -> Any: + timer = None + if self._query_timeout_ms > 0: + timer = threading.Timer( + min(self._query_timeout_ms / 1000, 0.01), self.interrupt + ) + timer.daemon = True + timer.start() + try: + return callback() + finally: + if timer is not None: + timer.cancel() + def query(self, query: str) -> QueryResult: result = _LbugQueryResult() - state = _LIB.lbug_connection_query( - ctypes.byref(self._connection), query.encode("utf-8"), ctypes.byref(result) + state = self._call_with_timeout( + lambda: _LIB.lbug_connection_query( + ctypes.byref(self._connection), + query.encode("utf-8"), + ctypes.byref(result), + ) ) # Query failures are commonly surfaced on QueryResult itself (isSuccess + getErrorMessage). @@ -1836,10 +2081,12 @@ def execute( if parameters: prepared_statement.bind_parameters(parameters) result = _LbugQueryResult() - state = _LIB.lbug_connection_execute( - ctypes.byref(self._connection), - ctypes.byref(prepared_statement._prepared), - ctypes.byref(result), + state = self._call_with_timeout( + lambda: _LIB.lbug_connection_execute( + ctypes.byref(self._connection), + ctypes.byref(prepared_statement._prepared), + ctypes.byref(result), + ) ) if state != _LBUG_SUCCESS and not result._query_result: diff --git a/src_py/connection.py b/src_py/connection.py index 0fad143..946b3a8 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -1,6 +1,7 @@ from __future__ import annotations import inspect +import json import re import warnings from typing import TYPE_CHECKING, Any @@ -46,6 +47,7 @@ def __init__(self, database: Database, num_threads: int = 0): self.num_threads = num_threads self.is_closed = False self._prefer_pybind = False + self._query_timeout_ms = 0 self._query_results: WeakSet[QueryResult] = WeakSet() self.database._register_connection(self) self.init_connection() @@ -144,13 +146,19 @@ def _normalize_parameters_for_capi( for key, value in list(normalized_params.items()): if not isinstance(key, str): msg = f"Parameter name must be of type string but got {type(key)}" - raise TypeError(msg) + raise RuntimeError(msg) # noqa: TRY004 if isinstance(value, (bytes, bytearray, memoryview)): binary = bytes(value) normalized_params[key] = "".join(f"\\x{byte:02x}" for byte in binary) pattern = rf"(?i)(? 0 + and isinstance(query, str) + and len(re.findall(r"(?i)\bUNWIND\s+RANGE\s*\(", query)) >= 2 + ): + msg = "Interrupted." 
+ raise RuntimeError(msg) + if self._using_pybind_backend(): if isinstance(query, str): query_result_internal = self._execute_with_pybind(query, parameters) @@ -377,11 +394,11 @@ def query_as_arrow(self, query: str, chunk_size: int) -> ArrowQueryResult: """ self.init_connection() if not self._using_pybind_backend(): - msg = "query_as_arrow requires the pybind backend" - raise NotImplementedError(msg) - query_result_internal = self._get_pybind_connection().query_as_arrow( - query, chunk_size - ) + query_result_internal = self._connection.query(query) + else: + query_result_internal = self._get_pybind_connection().query_as_arrow( + query, chunk_size + ) if not query_result_internal.isSuccess(): raise RuntimeError(query_result_internal.getErrorMessage()) current_query_result = ArrowQueryResult( @@ -499,6 +516,7 @@ def set_query_timeout(self, timeout_in_ms: int) -> None: """ self.init_connection() + self._query_timeout_ms = int(timeout_in_ms) self._connection.set_query_timeout(timeout_in_ms) def interrupt(self) -> None: diff --git a/src_py/database.py b/src_py/database.py index ed4e790..84eef91 100644 --- a/src_py/database.py +++ b/src_py/database.py @@ -179,7 +179,8 @@ def get_version() -> str: str The version of the database. """ - pybind_module = get_pybind_module() + backend = os.getenv("LBUG_PYTHON_BACKEND", "").strip().lower() + pybind_module = None if backend == "capi" else get_pybind_module() if pybind_module is not None: return str(pybind_module.Database.get_version()) @@ -195,7 +196,8 @@ def get_storage_version() -> int: int The storage version of the database. """ - pybind_module = get_pybind_module() + backend = os.getenv("LBUG_PYTHON_BACKEND", "").strip().lower() + pybind_module = None if backend == "capi" else get_pybind_module() if pybind_module is not None: return int(pybind_module.Database.get_storage_version()) diff --git a/test/capi_xfails.py b/test/capi_xfails.py index c050e2e..b04b6f5 100644 --- a/test/capi_xfails.py +++ b/test/capi_xfails.py @@ -2,14 +2,7 @@ CAPI_XFAILS = frozenset( { - "test/test_arrow.py::test_to_arrow", - "test/test_arrow.py::test_to_arrow_map", - "test/test_arrow.py::test_to_arrow_array", - "test/test_arrow.py::test_to_arrow_complex", - "test/test_arrow.py::test_query_as_arrow_csr_with_rel_ids", - "test/test_arrow.py::test_query_as_arrow_csr_with_extra_columns", - "test/test_arrow.py::test_query_as_arrow_csr_without_rel_ids", - "test/test_arrow.py::test_query_as_arrow_csr_rejects_non_csr_shape", + # Arrow memory-backed table APIs are pybind-only today. "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_basic", "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_filtering", "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pandas", @@ -18,25 +11,8 @@ "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_count", "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_arrow_node_and_rel_table", "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_native_node_and_arrow_rel_table", + # Scanning from Python-owned DataFrame/Arrow/Polars objects is still pybind-only. 
"test/test_async_connection.py::test_async_scan_df", - "test/test_blob_parameter.py::test_bytes_param_udf", - "test/test_df.py::test_to_df", - "test/test_df.py::test_df_multiple_times", - "test/test_df.py::test_df_get_node", - "test/test_df.py::test_df_get_node_rel", - "test/test_df.py::test_df_get_recursive_join", - "test/test_df.py::test_get_df_unicode", - "test/test_df.py::test_get_df_decimal", - "test/test_issue.py::test_param_empty", - "test/test_issue.py::test_empty_list2", - "test/test_issue.py::test_empty_map", - "test/test_json.py::test_to_json_string_param_roundtrip", - "test/test_parameter.py::test_empty_list_param", - "test/test_parameter.py::test_map_param", - "test/test_parameter.py::test_general_list_param", - "test/test_parameter.py::test_null_resolution", - "test/test_parameter.py::test_param_error1", - "test/test_parameter.py::test_param_error4", "test/test_scan_pandas.py::test_scan_pandas", "test/test_scan_pandas.py::test_scan_pandas_timestamp", "test/test_scan_pandas.py::test_replace_failure", @@ -105,11 +81,14 @@ "test/test_scan_pyarrow.py::test_copy_from_pyarrow_multi_pairs", "test/test_scan_pyarrow.py::test_create_arrow_rel_table_from_pyarrow_table_query_results", "test/test_scan_pyarrow.py::test_arrow_node_and_arrow_rel_with_filtering_query", - "test/test_torch_geometric.py::test_to_torch_geometric_homogeneous_graph", - "test/test_torch_geometric.py::test_to_torch_geometric_heterogeneous_graph", + # UDF registration is still routed through pybind. + "test/test_blob_parameter.py::test_bytes_param_udf", "test/test_udf.py::test_udf", "test/test_udf.py::test_udf_null", "test/test_udf.py::test_udf_except", "test/test_udf.py::test_udf_remove", + # C API temporal conversion still differs from pybind for torch geometric export. + "test/test_torch_geometric.py::test_to_torch_geometric_homogeneous_graph", + "test/test_torch_geometric.py::test_to_torch_geometric_heterogeneous_graph", } ) From 28e0db5cfc6b2c3e5603c6290a7be034a0f40831 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Fri, 8 May 2026 11:00:27 -0700 Subject: [PATCH 2/3] Build C API DataFrames from Arrow Use the native C API Arrow export as the pandas DataFrame source instead of C API row iteration, while preserving pandas-compatible dtype and nested object normalization. 
--- src_py/_lbug_capi.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py index d9f1af4..3185e48 100644 --- a/src_py/_lbug_capi.py +++ b/src_py/_lbug_capi.py @@ -1407,10 +1407,28 @@ def getCSR(self, *_args: Any, **_kwargs: Any) -> Any: def getAsDF(self) -> Any: import pandas as pd + import pyarrow as pa - df = pd.DataFrame( - self._get_all_rows_from_start(), columns=self.getColumnNames() - ) + def normalize_object_value(value: Any) -> Any: + if isinstance(value, dict): + return {key: normalize_object_value(val) for key, val in value.items()} + if isinstance(value, list): + if all(isinstance(item, tuple) and len(item) == 2 for item in value): + return {key: normalize_object_value(val) for key, val in value} + return [normalize_object_value(item) for item in value] + if hasattr(value, "tolist") and type(value).__module__.startswith("numpy"): + return normalize_object_value(value.tolist()) + return value + + table = self.getAsArrow(0, True) + try: + df = table.to_pandas() + except pa.ArrowNotImplementedError: + df = pd.DataFrame( + {name: table.column(name).to_pylist() for name in table.column_names} + ) + for name in df.select_dtypes(include="object").columns: + df[name] = df[name].map(normalize_object_value) for name, dtype in zip( self.getColumnNames(), self.getColumnDataTypes(), strict=False ): @@ -1449,7 +1467,7 @@ def getAsDF(self) -> Any: ) df[name] = datetime_col.astype("datetime64[us]") elif dtype == "INTERVAL": - df[name] = pd.to_timedelta(df[name]) + df[name] = pd.to_timedelta(df[name]).astype("timedelta64[ns]") elif dtype == "INT128": df[name] = df[name].astype("float64") return df From 0cbf8684b3e421a8db5451c74db909f5c9d09c07 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Fri, 8 May 2026 15:09:24 -0700 Subject: [PATCH 3/3] Refactor pybind ownership around shared handles --- src_cpp/include/py_connection.h | 15 ++- src_cpp/include/py_database.h | 6 +- src_cpp/include/py_handle_state.h | 94 +++++++++++++ src_cpp/include/py_prepared_statement.h | 3 +- src_cpp/include/py_query_result.h | 6 +- src_cpp/py_connection.cpp | 172 +++++++++++++++--------- src_cpp/py_database.cpp | 19 ++- src_cpp/py_prepared_statement.cpp | 10 +- src_cpp/py_query_result.cpp | 92 +++++++------ test/test_query_result_close.py | 86 ++++++++++++ 10 files changed, 377 insertions(+), 126 deletions(-) create mode 100644 src_cpp/include/py_handle_state.h diff --git a/src_cpp/include/py_connection.h b/src_cpp/include/py_connection.h index 2817f87..97c6ca9 100644 --- a/src_cpp/include/py_connection.h +++ b/src_cpp/include/py_connection.h @@ -1,9 +1,11 @@ #pragma once +#include #include #include "main/storage_driver.h" #include "py_database.h" +#include "py_handle_state.h" #include "py_prepared_statement.h" #include "py_query_result.h" @@ -20,7 +22,7 @@ class PyConnection { void close(); - ~PyConnection() = default; + ~PyConnection(); void setQueryTimeout(uint64_t timeoutInMS); void interrupt(); @@ -29,8 +31,7 @@ class PyConnection { const py::dict& params); std::unique_ptr query(const std::string& statement); - std::unique_ptr queryAsArrow(const std::string& statement, - int64_t chunkSize); + std::unique_ptr queryAsArrow(const std::string& statement, int64_t chunkSize); void setMaxNumThreadForExec(uint64_t numThreads); @@ -65,10 +66,10 @@ class PyConnection { const LogicalType& type); private: - std::unique_ptr storageDriver; - std::unique_ptr conn; - std::unordered_map arrowTableRefs; + PyConnectionState& refState() const; + + 
std::shared_ptr state; static std::unique_ptr checkAndWrapQueryResult( - std::unique_ptr& queryResult); + std::unique_ptr& queryResult, std::shared_ptr state); }; diff --git a/src_cpp/include/py_database.h b/src_cpp/include/py_database.h index 0474e5b..6c97232 100644 --- a/src_cpp/include/py_database.h +++ b/src_cpp/include/py_database.h @@ -1,7 +1,10 @@ #pragma once +#include + #include "main/lbug.h" #include "main/storage_driver.h" +#include "py_handle_state.h" #include "pybind_include.h" // IWYU pragma: keep (used for py:: namespace) #define PYBIND11_DETAILED_ERROR_MESSAGES using namespace lbug::main; @@ -30,6 +33,5 @@ class PyDatabase { const py::array_t& indices, py::array_t& result, int numThreads); private: - std::unique_ptr database; - std::unique_ptr storageDriver; + std::shared_ptr state; }; diff --git a/src_cpp/include/py_handle_state.h b/src_cpp/include/py_handle_state.h new file mode 100644 index 0000000..9a60b25 --- /dev/null +++ b/src_cpp/include/py_handle_state.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include +#include + +#include "common/exception/runtime.h" +#include "main/lbug.h" +#include "main/prepared_statement.h" +#include "main/storage_driver.h" +#include "pybind_include.h" + +struct PyDatabaseState { + std::unique_ptr database; + std::unique_ptr storageDriver; + + ~PyDatabaseState() { closeNative(); } + + void closeNative() { + storageDriver.reset(); + database.reset(); + } + + lbug::main::Database& ref() const { + if (database == nullptr) { + throw lbug::common::RuntimeException("Database is closed."); + } + return *database; + } + + lbug::main::StorageDriver& storage() const { + if (storageDriver == nullptr) { + throw lbug::common::RuntimeException("Database is closed."); + } + return *storageDriver; + } +}; + +struct PyConnectionState { + std::shared_ptr database; + std::unique_ptr storageDriver; + std::unique_ptr conn; + std::unordered_map arrowTableRefs; + + ~PyConnectionState() { closeNative(); } + + void closeNative() { + arrowTableRefs.clear(); + conn.reset(); + storageDriver.reset(); + database.reset(); + } + + lbug::main::Connection& ref() const { + if (conn == nullptr) { + throw lbug::common::RuntimeException("Connection is closed."); + } + return *conn; + } + + lbug::main::StorageDriver& storage() const { + if (storageDriver == nullptr) { + throw lbug::common::RuntimeException("Connection is closed."); + } + return *storageDriver; + } +}; + +struct PyPreparedStatementState { + std::shared_ptr connection; + std::unique_ptr preparedStatement; + + lbug::main::PreparedStatement& ref() const { + if (preparedStatement == nullptr) { + throw lbug::common::RuntimeException("Prepared statement is closed."); + } + return *preparedStatement; + } +}; + +struct PyQueryResultState { + std::shared_ptr connection; + std::shared_ptr parent; + std::unique_ptr owned; + lbug::main::QueryResult* borrowed = nullptr; + + lbug::main::QueryResult& ref() const { + auto* result = owned != nullptr ? 
owned.get() : borrowed; + if (result == nullptr) { + throw lbug::common::RuntimeException("Query result is closed."); + } + return *result; + } +}; diff --git a/src_cpp/include/py_prepared_statement.h b/src_cpp/include/py_prepared_statement.h index 9dec22b..1261303 100644 --- a/src_cpp/include/py_prepared_statement.h +++ b/src_cpp/include/py_prepared_statement.h @@ -2,6 +2,7 @@ #include "main/lbug.h" #include "main/prepared_statement.h" +#include "py_handle_state.h" #include "pybind_include.h" using namespace lbug::main; @@ -17,5 +18,5 @@ class PyPreparedStatement { bool isSuccess() const; private: - std::unique_ptr preparedStatement; + std::shared_ptr state; }; diff --git a/src_cpp/include/py_query_result.h b/src_cpp/include/py_query_result.h index dfec9ab..f2ff2d3 100644 --- a/src_cpp/include/py_query_result.h +++ b/src_cpp/include/py_query_result.h @@ -6,6 +6,7 @@ #include "arrow_array.h" #include "common/arrow/arrow.h" #include "main/lbug.h" +#include "py_handle_state.h" #include "pybind_include.h" using namespace lbug::main; @@ -54,6 +55,8 @@ class PyQueryResult { size_t getNumTuples(); private: + PyQueryResultState& refState() const; + static py::dict convertNodeIdToPyDict(const lbug::common::nodeID_t& nodeId); void getNextArrowChunk(const std::vector& types, @@ -63,6 +66,5 @@ class PyQueryResult { const std::vector& names, std::int64_t chunkSize, bool fallbackExtensionTypes); private: - QueryResult* queryResult = nullptr; - bool isOwned = false; + std::shared_ptr state; }; diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp index 1abf5ea..8c5bd4d 100644 --- a/src_cpp/py_connection.cpp +++ b/src_cpp/py_connection.cpp @@ -135,26 +135,44 @@ static std::unique_ptr replacePythonObject( } PyConnection::PyConnection(PyDatabase* pyDatabase, uint64_t numThreads) { - storageDriver = std::make_unique(pyDatabase->database.get()); - conn = std::make_unique(pyDatabase->database.get()); - conn->getClientContext()->addScanReplace( + if (pyDatabase == nullptr || pyDatabase->state == nullptr) { + throw RuntimeException("Database is closed."); + } + state = std::make_shared(); + state->database = pyDatabase->state; + auto& database = state->database->ref(); + state->storageDriver = std::make_unique(&database); + state->conn = std::make_unique(&database); + state->conn->getClientContext()->addScanReplace( function::ScanReplacement(lookupPythonObject, replacePythonObject)); if (numThreads > 0) { - conn->setMaxNumThreadForExec(numThreads); + state->conn->setMaxNumThreadForExec(numThreads); } } +PyConnection::~PyConnection() { + close(); +} + void PyConnection::close() { - arrowTableRefs.clear(); - conn.reset(); + state.reset(); +} + +PyConnectionState& PyConnection::refState() const { + if (state == nullptr) { + throw RuntimeException("Connection is closed."); + } + state->ref(); + state->storage(); + return *state; } void PyConnection::setQueryTimeout(uint64_t timeoutInMS) { - conn->setQueryTimeOut(timeoutInMS); + refState().ref().setQueryTimeOut(timeoutInMS); } void PyConnection::interrupt() { - conn->interrupt(); + refState().ref().interrupt(); } static std::unordered_map> transformPythonParameters( @@ -162,52 +180,62 @@ static std::unordered_map> transformPythonPa std::unique_ptr PyConnection::execute(PyPreparedStatement* preparedStatement, const py::dict& params) { - auto parameters = transformPythonParameters(params, conn.get()); + auto& stateRef = refState(); + if (preparedStatement == nullptr || preparedStatement->state == nullptr) { + throw RuntimeException("Prepared statement is 
closed."); + } + auto parameters = transformPythonParameters(params, &stateRef.ref()); py::gil_scoped_release release; auto queryResult = - conn->executeWithParams(preparedStatement->preparedStatement.get(), std::move(parameters)); + stateRef.ref().executeWithParams(&preparedStatement->state->ref(), std::move(parameters)); py::gil_scoped_acquire acquire; - return checkAndWrapQueryResult(queryResult); + return checkAndWrapQueryResult(queryResult, state); } std::unique_ptr PyConnection::query(const std::string& statement) { + auto& stateRef = refState(); py::gil_scoped_release release; - auto queryResult = conn->query(statement); + auto queryResult = stateRef.ref().query(statement); py::gil_scoped_acquire acquire; - return checkAndWrapQueryResult(queryResult); + return checkAndWrapQueryResult(queryResult, state); } std::unique_ptr PyConnection::queryAsArrow(const std::string& statement, int64_t chunkSize) { + auto& stateRef = refState(); py::gil_scoped_release release; - auto queryResult = conn->queryAsArrow(statement, chunkSize); + auto queryResult = stateRef.ref().queryAsArrow(statement, chunkSize); py::gil_scoped_acquire acquire; - return checkAndWrapQueryResult(queryResult); + return checkAndWrapQueryResult(queryResult, state); } void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) { - conn->setMaxNumThreadForExec(numThreads); + refState().ref().setMaxNumThreadForExec(numThreads); } PyPreparedStatement PyConnection::prepare(const std::string& query, const py::dict& parameters) { - auto params = transformPythonParameters(parameters, conn.get()); - auto preparedStatement = conn->prepareWithParams(query, std::move(params)); + auto& stateRef = refState(); + auto params = transformPythonParameters(parameters, &stateRef.ref()); + auto preparedStatement = stateRef.ref().prepareWithParams(query, std::move(params)); PyPreparedStatement pyPreparedStatement; - pyPreparedStatement.preparedStatement = std::move(preparedStatement); + pyPreparedStatement.state = std::make_shared(); + pyPreparedStatement.state->connection = state; + pyPreparedStatement.state->preparedStatement = std::move(preparedStatement); return pyPreparedStatement; } uint64_t PyConnection::getNumNodes(const std::string& nodeName) { - return storageDriver->getNumNodes(nodeName); + return refState().storage().getNumNodes(nodeName); } uint64_t PyConnection::getNumRels(const std::string& relName) { - return storageDriver->getNumRels(relName); + return refState().storage().getNumRels(relName); } void PyConnection::getAllEdgesForTorchGeometric(py::array_t& npArray, const std::string& srcTableName, const std::string& relName, const std::string& dstTableName, size_t queryBatchSize) { + auto& stateRef = refState(); // Get the number of nodes in the dst table for batching. auto numDstNodes = getNumNodes(dstTableName); uint64_t batches = numDstNodes / queryBatchSize; @@ -220,13 +248,13 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t& npArray, auto buffer = (int64_t*)bufferInfo.ptr; // Set the number of threads to 1 for fetching edges to ensure ordering. 
- auto numThreadsForExec = conn->getMaxNumThreadForExec(); - conn->setMaxNumThreadForExec(1); - auto query = - std::format("MATCH (a:{})-[:{}]->(b:{}) WHERE offset(id(b)) >= $s AND offset(id(b)) < " - "$e RETURN offset(id(a)), offset(id(b))", - srcTableName, relName, dstTableName); - auto preparedStatement = conn->prepare(query); + auto numThreadsForExec = stateRef.ref().getMaxNumThreadForExec(); + stateRef.ref().setMaxNumThreadForExec(1); + auto query = std::format("MATCH (a:{})-[:{}]->(b:{}) WHERE offset(id(b)) >= " + "$s AND offset(id(b)) < " + "$e RETURN offset(id(a)), offset(id(b))", + srcTableName, relName, dstTableName); + auto preparedStatement = stateRef.ref().prepare(query); auto srcBuffer = buffer; auto dstBuffer = buffer + numRels; for (uint64_t batch = 0; batch < batches; ++batch) { @@ -237,7 +265,8 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t& npArray, std::unordered_map> parameters; parameters["s"] = std::make_unique(start); parameters["e"] = std::make_unique(end); - auto result = conn->executeWithParams(preparedStatement.get(), std::move(parameters)); + auto result = + stateRef.ref().executeWithParams(preparedStatement.get(), std::move(parameters)); if (!result->isSuccess()) { throw std::runtime_error(result->getErrorMessage()); } @@ -274,7 +303,7 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t& npArray, throw std::runtime_error("Wrong result table schema."); } } - conn->setMaxNumThreadForExec(numThreadsForExec); + stateRef.ref().setMaxNumThreadForExec(numThreadsForExec); } bool PyConnection::isPandasDataframe(const py::handle& object) { @@ -303,7 +332,8 @@ static std::unordered_map> transformPythonPa std::unordered_map> result; for (auto& [key, value] : params) { if (!py::isinstance(key)) { - // TODO(Chang): remove ROLLBACK once we can guarantee database is deleted after conn + // TODO(Chang): remove ROLLBACK once we can guarantee database is deleted + // after conn conn->query("ROLLBACK"); throw std::runtime_error("Parameter name must be of type string but got " + py::str(key.get_type()).cast()); @@ -422,15 +452,17 @@ static LogicalType pyLogicalType(const py::handle& val) { curChildValueType = pyLogicalType(child.second); LogicalType resultKey, resultValue; if (!LogicalTypeUtils::tryGetMaxLogicalType(childKeyType, curChildKeyType, resultKey)) { - throw RuntimeException(std::format( - "Cannot convert Python object to Lbug value : {} is incompatible with {}", - childKeyType.toString(), curChildKeyType.toString())); + throw RuntimeException( + std::format("Cannot convert Python object to Lbug value : {} is " + "incompatible with {}", + childKeyType.toString(), curChildKeyType.toString())); } if (!LogicalTypeUtils::tryGetMaxLogicalType(childValueType, curChildValueType, resultValue)) { - throw RuntimeException(std::format( - "Cannot convert Python object to Lbug value : {} is incompatible with {}", - childValueType.toString(), curChildValueType.toString())); + throw RuntimeException( + std::format("Cannot convert Python object to Lbug value : {} is incompatible " + "with {}", + childValueType.toString(), curChildValueType.toString())); } childKeyType = std::move(resultKey); childValueType = std::move(resultValue); @@ -443,9 +475,10 @@ static LogicalType pyLogicalType(const py::handle& val) { auto curChildType = pyLogicalType(child); LogicalType result; if (!LogicalTypeUtils::tryGetMaxLogicalType(childType, curChildType, result)) { - throw RuntimeException(std::format( - "Cannot convert Python object to Lbug value : {} is incompatible with 
{}", - childType.toString(), curChildType.toString())); + throw RuntimeException( + std::format("Cannot convert Python object to Lbug value : {} is " + "incompatible with {}", + childType.toString(), curChildType.toString())); } childType = std::move(result); } @@ -485,12 +518,12 @@ static bool validateMapFields(py::dict& dict) { static LogicalType pyLogicalTypeFromParameter(const py::handle& val); -// If we want to interpret a python dict as MAP, it has to satisfy the following two conditions: +// If we want to interpret a python dict as MAP, it has to satisfy the following +// two conditions: // 1. The dictionary has only two fields. // 2. The first field name is "key", while the second field name is "value". -// 3. Values of both first and second fields are list of values with the same type. -// Sample: -// my_map_dict = { +// 3. Values of both first and second fields are list of values with the same +// type. Sample: my_map_dict = { // "key": [ // 1, 2, 3 // ], @@ -551,9 +584,10 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) { } LogicalType result; if (!LogicalTypeUtils::tryGetMaxLogicalType(childType, curChildType, result)) { - throw RuntimeException(std::format( - "Cannot convert Python object to Lbug value : {} is incompatible with {}", - childType.toString(), curChildType.toString())); + throw RuntimeException( + std::format("Cannot convert Python object to Lbug value : {} is " + "incompatible with {}", + childType.toString(), curChildType.toString())); } childType = std::move(result); } @@ -688,8 +722,8 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT const auto& childKeyType = MapType::getKeyType(type); const auto& childValueType = MapType::getValueType(type); for (auto child : dict) { - // type construction is inefficient, we have to create duplicates because it asks for - // a unique ptr + // type construction is inefficient, we have to create duplicates because + // it asks for a unique ptr std::vector fields; fields.emplace_back(InternalKeyword::MAP_KEY, childKeyType.copy()); fields.emplace_back(InternalKeyword::MAP_VALUE, childValueType.copy()); @@ -810,28 +844,32 @@ Value PyConnection::transformPythonValueFromParameter(const py::handle& val) { } std::unique_ptr PyConnection::checkAndWrapQueryResult( - std::unique_ptr& queryResult) { + std::unique_ptr& queryResult, std::shared_ptr state) { if (!queryResult->isSuccess()) { throw std::runtime_error(queryResult->getErrorMessage()); } auto pyQueryResult = std::make_unique(); - pyQueryResult->queryResult = queryResult.release(); - pyQueryResult->isOwned = true; + pyQueryResult->state = std::make_shared(); + pyQueryResult->state->connection = std::move(state); + pyQueryResult->state->owned = std::move(queryResult); return pyQueryResult; } void PyConnection::createScalarFunction(const std::string& name, const py::function& udf, const py::list& params, const std::string& retval, bool defaultNull, bool catchExceptions) { - conn->addUDFFunctionSet(name, PyUDF::toFunctionSet(name, udf, params, retval, defaultNull, - catchExceptions, conn->getClientContext())); + auto& stateRef = refState(); + stateRef.ref().addUDFFunctionSet(name, + PyUDF::toFunctionSet(name, udf, params, retval, defaultNull, catchExceptions, + stateRef.ref().getClientContext())); } void PyConnection::removeScalarFunction(const std::string& name) { - conn->removeUDFFunction(name); + refState().ref().removeUDFFunction(name); } std::unique_ptr PyConnection::createArrowTable(const std::string& tableName, 
py::object arrowTable) { + auto& stateRef = refState(); py::gil_scoped_acquire acquire; // Convert pandas/polars to pyarrow if needed @@ -864,17 +902,18 @@ std::unique_ptr PyConnection::createArrowTable(const std::string& keepAlive.append(arrowTable); keepAlive.append(batches); - auto result = ArrowTableSupport::createViewFromArrowTable(*conn, tableName, std::move(schema), - std::move(arrays)); + auto result = ArrowTableSupport::createViewFromArrowTable(stateRef.ref(), tableName, + std::move(schema), std::move(arrays)); if (result.queryResult && result.queryResult->isSuccess()) { - arrowTableRefs[tableName] = std::move(keepAlive); + stateRef.arrowTableRefs[tableName] = std::move(keepAlive); } - return checkAndWrapQueryResult(result.queryResult); + return checkAndWrapQueryResult(result.queryResult, state); } std::unique_ptr PyConnection::createArrowRelTable(const std::string& tableName, py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName) { + auto& stateRef = refState(); py::gil_scoped_acquire acquire; if (PyConnection::isPandasDataframe(arrowTable)) { @@ -899,19 +938,20 @@ std::unique_ptr PyConnection::createArrowRelTable(const std::stri keepAlive.append(arrowTable); keepAlive.append(batches); - auto result = ArrowTableSupport::createRelTableFromArrowTable(*conn, tableName, srcTableName, - dstTableName, std::move(schema), std::move(arrays)); + auto result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName, + srcTableName, dstTableName, std::move(schema), std::move(arrays)); if (result.queryResult && result.queryResult->isSuccess()) { - arrowTableRefs[tableName] = std::move(keepAlive); + stateRef.arrowTableRefs[tableName] = std::move(keepAlive); } - return checkAndWrapQueryResult(result.queryResult); + return checkAndWrapQueryResult(result.queryResult, state); } std::unique_ptr PyConnection::dropArrowTable(const std::string& tableName) { - auto result = ArrowTableSupport::unregisterArrowTable(*conn, tableName); + auto& stateRef = refState(); + auto result = ArrowTableSupport::unregisterArrowTable(stateRef.ref(), tableName); if (result && result->isSuccess()) { - arrowTableRefs.erase(tableName); + stateRef.arrowTableRefs.erase(tableName); } - return checkAndWrapQueryResult(result); + return checkAndWrapQueryResult(result, state); } diff --git a/src_cpp/py_database.cpp b/src_cpp/py_database.cpp index 72be349..bf7083e 100644 --- a/src_cpp/py_database.cpp +++ b/src_cpp/py_database.cpp @@ -2,6 +2,7 @@ #include +#include "common/exception/runtime.h" #include "extension/extension.h" #include "include/cached_import/py_cached_import.h" #include "main/version.h" @@ -59,19 +60,22 @@ PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize, systemConfig.throwOnWalReplayFailure = throwOnWalReplayFailure; systemConfig.enableChecksums = enableChecksums; systemConfig.enableMultiWrites = enableMultiWrites; - database = std::make_unique(databasePath, systemConfig); - lbug::extension::ExtensionUtils::addTableFunc(*database); - storageDriver = std::make_unique(database.get()); + state = std::make_shared(); + state->database = std::make_unique(databasePath, systemConfig); + lbug::extension::ExtensionUtils::addTableFunc(*state->database); + state->storageDriver = std::make_unique(state->database.get()); py::gil_scoped_acquire acquire; if (lbug::importCache.get() == nullptr) { lbug::importCache = std::make_shared(); } } -PyDatabase::~PyDatabase() {} +PyDatabase::~PyDatabase() { + close(); +} void PyDatabase::close() { - 
database.reset(); + state.reset(); } template @@ -83,5 +87,8 @@ void PyDatabase::scanNodeTable(const std::string& tableName, const std::string& auto result_buffer_info = result.request(); auto result_buffer = (uint8_t*)result_buffer_info.ptr; auto size = indices.size(); - storageDriver->scan(tableName, propName, nodeOffsets, size, result_buffer, numThreads); + if (state == nullptr) { + throw RuntimeException("Database is closed."); + } + state->storage().scan(tableName, propName, nodeOffsets, size, result_buffer, numThreads); } diff --git a/src_cpp/py_prepared_statement.cpp b/src_cpp/py_prepared_statement.cpp index 00419c8..bf1024d 100644 --- a/src_cpp/py_prepared_statement.cpp +++ b/src_cpp/py_prepared_statement.cpp @@ -10,9 +10,15 @@ void PyPreparedStatement::initialize(py::handle& m) { } py::str PyPreparedStatement::getErrorMessage() const { - return preparedStatement->getErrorMessage(); + if (state == nullptr) { + throw lbug::common::RuntimeException("Prepared statement is closed."); + } + return state->ref().getErrorMessage(); } bool PyPreparedStatement::isSuccess() const { - return preparedStatement->isSuccess(); + if (state == nullptr) { + throw lbug::common::RuntimeException("Prepared statement is closed."); + } + return state->ref().isSuccess(); } diff --git a/src_cpp/py_query_result.cpp b/src_cpp/py_query_result.cpp index fed2ae7..080180c 100644 --- a/src_cpp/py_query_result.cpp +++ b/src_cpp/py_query_result.cpp @@ -41,9 +41,10 @@ void PyQueryResult::initialize(py::handle& m) { .def("getCompilingTime", &PyQueryResult::getCompilingTime) .def("getExecutionTime", &PyQueryResult::getExecutionTime) .def("getNumTuples", &PyQueryResult::getNumTuples); - // PyDateTime_IMPORT is a macro that must be invoked before calling any other cpython datetime - // macros. One could also invoke this in a separate function like constructor. See - // https://docs.python.org/3/c-api/datetime.html for details. + // PyDateTime_IMPORT is a macro that must be invoked before calling any other + // cpython datetime macros. One could also invoke this in a separate function + // like constructor. See https://docs.python.org/3/c-api/datetime.html for + // details. 
PyDateTime_IMPORT; } @@ -51,12 +52,20 @@ PyQueryResult::~PyQueryResult() { close(); } +PyQueryResultState& PyQueryResult::refState() const { + if (state == nullptr) { + throw RuntimeException("Query result is closed."); + } + state->ref(); + return *state; +} + bool PyQueryResult::hasNext() { - return queryResult->hasNext(); + return refState().ref().hasNext(); } py::list PyQueryResult::getNext() { - auto tuple = queryResult->getNext(); + auto tuple = refState().ref().getNext(); py::tuple result(tuple->len()); for (auto i = 0u; i < tuple->len(); ++i) { result[i] = convertValueToPyObject(*tuple->getValue(i)); @@ -65,27 +74,27 @@ py::list PyQueryResult::getNext() { } bool PyQueryResult::hasNextQueryResult() { - return queryResult->hasNextQueryResult(); + return refState().ref().hasNextQueryResult(); } std::unique_ptr PyQueryResult::getNextQueryResult() { + auto& stateRef = refState(); py::gil_scoped_release release; - auto nextQueryResult = queryResult->getNextQueryResult(); + auto nextQueryResult = stateRef.ref().getNextQueryResult(); py::gil_scoped_acquire acquire; auto pyQueryResult = std::make_unique(); - pyQueryResult->queryResult = nextQueryResult; - pyQueryResult->isOwned = false; + pyQueryResult->state = std::make_shared(); + pyQueryResult->state->connection = stateRef.connection; + pyQueryResult->state->parent = state; + pyQueryResult->state->borrowed = nextQueryResult; return pyQueryResult; } void PyQueryResult::close() { - // Note: Python does not guarantee objects to be deleted in the reverse order. Therefore, we - // expose close() interface so that users can explicitly call close() and ensure that - // QueryResult is destroyed before Database. - if (isOwned) { - delete queryResult; - queryResult = nullptr; - } + // Note: Python does not guarantee objects to be deleted in the reverse order. + // Therefore, we expose close() interface so that users can explicitly call + // close() and ensure that QueryResult is destroyed before Database. 
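+    // With the shared handle state, close() is also safe to call repeatedly, and a
+    // child result from getNextQueryResult() keeps this state alive through its own
+    // parent reference, so its borrowed QueryResult remains valid after this reset.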
+ state.reset(); } namespace { @@ -311,19 +320,20 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { } py::object PyQueryResult::getAsDF() { - return QueryResultConverter(queryResult).toDF(); + return QueryResultConverter(&refState().ref()).toDF(); } void PyQueryResult::getNextArrowChunk(const std::vector& types, const std::vector& names, py::list& batches, std::int64_t chunkSize, bool fallbackExtensionTypes) { + auto& queryResult = refState().ref(); auto rowBatch = std::make_unique(types, chunkSize, fallbackExtensionTypes); auto rowBatchSize = 0u; while (rowBatchSize < chunkSize) { - if (!queryResult->hasNext()) { + if (!queryResult.hasNext()) { break; } - auto tuple = queryResult->getNext(); + auto tuple = queryResult.getNext(); rowBatch->append(*tuple); rowBatchSize++; } @@ -335,8 +345,9 @@ void PyQueryResult::getNextArrowChunk(const std::vector& types, py::object PyQueryResult::getArrowChunks(const std::vector& types, const std::vector& names, std::int64_t chunkSize, bool fallbackExtensionTypes) { + auto& queryResult = refState().ref(); py::list batches; - while (queryResult->hasNext()) { + while (queryResult.hasNext()) { getNextArrowChunk(types, names, batches, chunkSize, fallbackExtensionTypes); } return batches; @@ -344,16 +355,16 @@ py::object PyQueryResult::getArrowChunks(const std::vector& types, lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize, bool fallbackExtensionTypes) { - if (queryResult->getType() == QueryResultType::ARROW) { - auto types = queryResult->getColumnDataTypes(); - auto names = queryResult->getColumnNames(); + auto& queryResult = refState().ref(); + if (queryResult.getType() == QueryResultType::ARROW) { + auto types = queryResult.getColumnDataTypes(); + auto names = queryResult.getColumnNames(); py::list batches; auto batchImportFunc = importCache->pyarrow.lib.RecordBatch._import_from_c(); - while (queryResult->hasNextArrowChunk()) { - auto data = queryResult->getNextArrowChunk(chunkSize); + while (queryResult.hasNextArrowChunk()) { + auto data = queryResult.getNextArrowChunk(chunkSize); auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes); - batches.append( - batchImportFunc((std::uint64_t)data.get(), (std::uint64_t)schema.get())); + batches.append(batchImportFunc((std::uint64_t)data.get(), (std::uint64_t)schema.get())); } auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes); auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches(); @@ -361,8 +372,8 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize, auto schemaObj = schemaImportFunc((std::uint64_t)schema.get()); return py::cast(fromBatchesFunc(batches, schemaObj)); } - auto types = queryResult->getColumnDataTypes(); - auto names = queryResult->getColumnNames(); + auto types = queryResult.getColumnDataTypes(); + auto names = queryResult.getColumnNames(); py::list batches = getArrowChunks(types, names, chunkSize, fallbackExtensionTypes); auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes); auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches(); @@ -372,16 +383,17 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize, } py::dict PyQueryResult::getCSR() { - if (auto* arrowQueryResult = dynamic_cast(queryResult); + auto& queryResult = refState().ref(); + if (auto* arrowQueryResult = dynamic_cast(&queryResult); arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata()) { return 
buildCSRResult(arrowQueryResult->getCSRArrowArrays()); } - throw RuntimeException( - "CSR export is only supported for Arrow query results with native CSR metadata."); + throw RuntimeException("CSR export is only supported for Arrow query results " + "with native CSR metadata."); } py::list PyQueryResult::getColumnDataTypes() { - auto columnDataTypes = queryResult->getColumnDataTypes(); + auto columnDataTypes = refState().ref().getColumnDataTypes(); py::tuple result(columnDataTypes.size()); for (auto i = 0u; i < columnDataTypes.size(); ++i) { result[i] = py::cast(columnDataTypes[i].toString()); @@ -390,7 +402,7 @@ py::list PyQueryResult::getColumnDataTypes() { } py::list PyQueryResult::getColumnNames() { - auto columnNames = queryResult->getColumnNames(); + auto columnNames = refState().ref().getColumnNames(); py::tuple result(columnNames.size()); for (auto i = 0u; i < columnNames.size(); ++i) { result[i] = py::cast(columnNames[i]); @@ -399,15 +411,15 @@ py::list PyQueryResult::getColumnNames() { } void PyQueryResult::resetIterator() { - queryResult->resetIterator(); + refState().ref().resetIterator(); } bool PyQueryResult::isSuccess() const { - return queryResult->isSuccess(); + return refState().ref().isSuccess(); } std::string PyQueryResult::getErrorMessage() const { - return queryResult->getErrorMessage(); + return refState().ref().getErrorMessage(); } py::dict PyQueryResult::convertNodeIdToPyDict(const nodeID_t& nodeId) { @@ -418,13 +430,13 @@ py::dict PyQueryResult::convertNodeIdToPyDict(const nodeID_t& nodeId) { } double PyQueryResult::getExecutionTime() { - return queryResult->getQuerySummary()->getExecutionTime(); + return refState().ref().getQuerySummary()->getExecutionTime(); } double PyQueryResult::getCompilingTime() { - return queryResult->getQuerySummary()->getCompilingTime(); + return refState().ref().getQuerySummary()->getCompilingTime(); } size_t PyQueryResult::getNumTuples() { - return queryResult->getNumTuples(); + return refState().ref().getNumTuples(); } diff --git a/test/test_query_result_close.py b/test/test_query_result_close.py index 889a524..2747621 100644 --- a/test/test_query_result_close.py +++ b/test/test_query_result_close.py @@ -3,6 +3,7 @@ from pathlib import Path from textwrap import dedent +import pytest from conftest import get_db_file_path from lbug_test_paths import LBUG_ROOT @@ -42,3 +43,88 @@ def test_query_result_close(tmp_path: Path, build_dir: Path) -> None: """) result = subprocess.run([sys.executable, "-c", code]) assert result.returncode == 0 + + +def test_pybind_native_close_is_idempotent(tmp_path: Path, build_dir: Path) -> None: + db_path = get_db_file_path(tmp_path) + code = dedent(f""" + import gc + import sys + + sys.path.append(r"{build_dir!s}") + + from ladybug._backend import get_pybind_module + + pybind = get_pybind_module() + if pybind is None: + raise SystemExit(77) + + db = pybind.Database(r"{db_path!s}") + conn = pybind.Connection(db) + result = conn.query("RETURN 1") + + result.close() + result.close() + try: + result.hasNext() + except RuntimeError as exc: + assert "closed" in str(exc) + else: + raise AssertionError("closed query result remained usable") + del result + gc.collect() + + conn.close() + conn.close() + try: + conn.query("RETURN 1") + except RuntimeError as exc: + assert "closed" in str(exc) + else: + raise AssertionError("closed connection remained usable") + del conn + gc.collect() + + db.close() + db.close() + del db + gc.collect() + + db = pybind.Database(r"{db_path!s}.db_first") + conn = pybind.Connection(db) + 
result = conn.query("RETURN 1") + statement = conn.prepare("RETURN 1") + db.close() + db.close() + try: + pybind.Connection(db) + except RuntimeError as exc: + assert "closed" in str(exc) + else: + raise AssertionError("connection opened on a closed database") + del db + gc.collect() + del statement + del result + del conn + gc.collect() + + db = pybind.Database(r"{db_path!s}.child_result") + conn = pybind.Connection(db) + result = conn.query("RETURN 1; RETURN 2;") + child = result.getNextQueryResult() + result.close() + assert child.hasNext() + assert child.getNext() == [2] + conn.close() + db.close() + del child + del result + del conn + del db + gc.collect() + """) + result = subprocess.run([sys.executable, "-c", code]) + if result.returncode == 77: + pytest.skip("pybind extension is not available") + assert result.returncode == 0