diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py index 3185e48..1fbacfa 100644 --- a/src_py/_lbug_capi.py +++ b/src_py/_lbug_capi.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import atexit import ctypes import ctypes.util import datetime as dt @@ -8,6 +9,7 @@ import sys import threading import uuid +import weakref from dataclasses import dataclass from decimal import Decimal from pathlib import Path @@ -186,6 +188,24 @@ def _resolve_library_path() -> str: _dlopen_mode = getattr(ctypes, "RTLD_GLOBAL", 0) | getattr(ctypes, "RTLD_NOW", 0) _LIB = ctypes.CDLL(_resolve_library_path(), mode=_dlopen_mode) +_CAPI_DATABASES: weakref.WeakSet[Any] = weakref.WeakSet() +_CAPI_CONNECTIONS: weakref.WeakSet[Any] = weakref.WeakSet() +_ARROW_ATEXIT_REGISTERED = False + + +def _close_capi_connections() -> None: + for connection in list(_CAPI_CONNECTIONS): + connection.close() + for database in list(_CAPI_DATABASES): + database.close() + + +def _ensure_arrow_atexit_cleanup() -> None: + global _ARROW_ATEXIT_REGISTERED + if not _ARROW_ATEXIT_REGISTERED: + atexit.register(_close_capi_connections) + _ARROW_ATEXIT_REGISTERED = True + _LBUG_SUCCESS = 0 @@ -288,6 +308,35 @@ def _setup_signatures() -> None: ] _LIB.lbug_connection_execute.restype = ctypes.c_int + _LIB.lbug_connection_create_arrow_table.argtypes = [ + ctypes.POINTER(_LbugConnection), + ctypes.c_char_p, + ctypes.POINTER(_ArrowSchema), + ctypes.POINTER(_ArrowArray), + ctypes.c_uint64, + ctypes.POINTER(_LbugQueryResult), + ] + _LIB.lbug_connection_create_arrow_table.restype = ctypes.c_int + + _LIB.lbug_connection_create_arrow_rel_table.argtypes = [ + ctypes.POINTER(_LbugConnection), + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.POINTER(_ArrowSchema), + ctypes.POINTER(_ArrowArray), + ctypes.c_uint64, + ctypes.POINTER(_LbugQueryResult), + ] + _LIB.lbug_connection_create_arrow_rel_table.restype = ctypes.c_int + + _LIB.lbug_connection_drop_arrow_table.argtypes = [ + ctypes.POINTER(_LbugConnection), + ctypes.c_char_p, + ctypes.POINTER(_LbugQueryResult), + ] + _LIB.lbug_connection_drop_arrow_table.restype = ctypes.c_int + _LIB.lbug_prepared_statement_destroy.argtypes = [ ctypes.POINTER(_LbugPreparedStatement) ] @@ -1065,6 +1114,7 @@ def __init__( database_path.encode("utf-8"), config, ctypes.byref(self._database) ) _check_state(state, "Failed to initialize database") + _CAPI_DATABASES.add(self) def close(self) -> None: lib = _LIB @@ -1072,6 +1122,7 @@ def close(self) -> None: if lib is not None: lib.lbug_database_destroy(ctypes.byref(self._database)) self._database._database = None + _CAPI_DATABASES.discard(self) @staticmethod def get_version() -> str: @@ -2014,6 +2065,7 @@ def __init__(self, database: Database, num_threads: int = 0): ), "Failed to initialize connection", ) + _CAPI_CONNECTIONS.add(self) if num_threads > 0: self.set_max_threads_for_exec(num_threads) @@ -2023,6 +2075,7 @@ def close(self) -> None: if lib is not None: lib.lbug_connection_destroy(ctypes.byref(self._connection)) self._connection._connection = None + _CAPI_CONNECTIONS.discard(self) def set_max_threads_for_exec(self, num_threads: int) -> None: _check_state( @@ -2119,17 +2172,81 @@ def create_function(self, *_args: Any, **_kwargs: Any) -> None: def remove_function(self, *_args: Any, **_kwargs: Any) -> None: raise NotImplementedError("UDF removal is not yet implemented in C-API backend") - def create_arrow_table(self, *_args: Any, **_kwargs: Any) -> Any: - raise NotImplementedError( - "Arrow memory table APIs are not yet implemented in C-API backend" + @staticmethod + def _as_arrow_table(dataframe: Any) -> Any: + import pyarrow as pa + + _ensure_arrow_atexit_cleanup() + module_name = type(dataframe).__module__ + if module_name.startswith("pandas"): + return pa.Table.from_pandas(dataframe) + if module_name.startswith("polars"): + return dataframe.to_arrow() + if ( + module_name.startswith("pyarrow") + and dataframe.__class__.__name__ == "Table" + ): + return dataframe + msg = "Expected a pyarrow Table, polars DataFrame, or pandas DataFrame" + raise RuntimeError(msg) + + @staticmethod + def _export_arrow_table(dataframe: Any) -> tuple[Any, _ArrowSchema, Any, Any]: + table = Connection._as_arrow_table(dataframe) + schema = _ArrowSchema() + table.schema._export_to_c(ctypes.addressof(schema)) + batches = table.to_batches() + array_type = _ArrowArray * len(batches) + arrays = array_type() + for idx, batch in enumerate(batches): + batch._export_to_c(ctypes.addressof(arrays[idx])) + return table, schema, arrays, batches + + def create_arrow_table(self, table_name: str, dataframe: Any) -> QueryResult: + _table, schema, arrays, _batches = self._export_arrow_table(dataframe) + result = _LbugQueryResult() + state = _LIB.lbug_connection_create_arrow_table( + ctypes.byref(self._connection), + table_name.encode("utf-8"), + ctypes.byref(schema), + arrays, + len(arrays), + ctypes.byref(result), ) + if state != _LBUG_SUCCESS and not result._query_result: + _check_state(state, "Failed to create Arrow table") + return QueryResult(result) - def drop_arrow_table(self, *_args: Any, **_kwargs: Any) -> Any: - raise NotImplementedError( - "Arrow memory table APIs are not yet implemented in C-API backend" + def drop_arrow_table(self, table_name: str) -> QueryResult: + result = _LbugQueryResult() + state = _LIB.lbug_connection_drop_arrow_table( + ctypes.byref(self._connection), + table_name.encode("utf-8"), + ctypes.byref(result), ) + if state != _LBUG_SUCCESS and not result._query_result: + _check_state(state, "Failed to drop Arrow table") + return QueryResult(result) - def create_arrow_rel_table(self, *_args: Any, **_kwargs: Any) -> Any: - raise NotImplementedError( - "Arrow memory table APIs are not yet implemented in C-API backend" + def create_arrow_rel_table( + self, + table_name: str, + dataframe: Any, + src_table_name: str, + dst_table_name: str, + ) -> QueryResult: + _table, schema, arrays, _batches = self._export_arrow_table(dataframe) + result = _LbugQueryResult() + state = _LIB.lbug_connection_create_arrow_rel_table( + ctypes.byref(self._connection), + table_name.encode("utf-8"), + src_table_name.encode("utf-8"), + dst_table_name.encode("utf-8"), + ctypes.byref(schema), + arrays, + len(arrays), + ctypes.byref(result), ) + if state != _LBUG_SUCCESS and not result._query_result: + _check_state(state, "Failed to create Arrow relationship table") + return QueryResult(result) diff --git a/src_py/connection.py b/src_py/connection.py index 946b3a8..417e458 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -3,6 +3,7 @@ import inspect import json import re +import uuid import warnings from typing import TYPE_CHECKING, Any from weakref import WeakSet @@ -49,6 +50,7 @@ def __init__(self, database: Database, num_threads: int = 0): self._prefer_pybind = False self._query_timeout_ms = 0 self._query_results: WeakSet[QueryResult] = WeakSet() + self._capi_scan_tables: set[str] = set() self.database._register_connection(self) self.init_connection() @@ -174,6 +176,113 @@ def _has_scan_pattern(self, query: str) -> bool: return False return re.search(r"(?i)\bFROM\b", query) is not None + @staticmethod + def _quote_identifier(identifier: str) -> str: + escaped = identifier.replace("`", "``") + return f"`{escaped}`" + + def _arrow_table_column_names(self, value: Any) -> list[str]: + table = get_capi_module().Connection._as_arrow_table(value) + return [field.name for field in table.schema] + + def _create_capi_scan_table(self, value: Any) -> tuple[str, list[str]]: + table_name = f"__lbug_capi_scan_{uuid.uuid4().hex}" + self._connection.create_arrow_table(table_name, value) + self._capi_scan_tables.add(table_name) + return table_name, self._arrow_table_column_names(value) + + def _replace_column_refs(self, text: str, columns: list[str], alias: str) -> str: + result = text + for column in sorted(columns, key=len, reverse=True): + quoted = self._quote_identifier(column) + result = re.sub( + rf"(? str: + alias = "_scan" + match_prefix = f"MATCH ({alias}:{self._quote_identifier(table_name)})" + rest = query[source_end:] + return_star = ", ".join( + f"{alias}.{self._quote_identifier(column)} AS {self._quote_identifier(column)}" + for column in columns + ) + return_match = re.search(r"(?i)\bRETURN\s+\*", rest) + if return_match is not None: + rest = ( + rest[: return_match.start()] + + f"RETURN {return_star}" + + rest[return_match.end() :] + ) + rest = self._replace_column_refs(rest, columns, alias) + return query[:source_start] + match_prefix + rest + + def _rewrite_copy_from_capi_scan( + self, + query: str, + source_start: int, + source_end: int, + table_name: str, + columns: list[str], + ) -> str: + alias = "_scan" + return_cols = ", ".join( + f"{alias}.{self._quote_identifier(column)}" for column in columns + ) + replacement = f"(MATCH ({alias}:{self._quote_identifier(table_name)}) RETURN {return_cols})" + return query[:source_start] + replacement + query[source_end:] + + def _rewrite_capi_python_scan( + self, + query: str, + parameters: dict[str, Any], + ) -> tuple[str, dict[str, Any]]: + if self._using_pybind_backend() or not self._has_scan_pattern(query): + return query, parameters + if self.database.read_only: + return query, parameters + + for key, value in list(parameters.items()): + if not isinstance(key, str) or not self._is_python_scan_object(value): + continue + match = re.search(rf"(?i)\bFROM\s+(\${re.escape(key)})\b", query) + if match is None: + continue + options_match = re.match(r"\s*\((.*?)\)", query[match.end() :], re.DOTALL) + if options_match is not None and re.search( + r"(?i)\bINVALID_OPTION\b", options_match.group(1) + ): + msg = "INVALID_OPTION Option not recognized by pyArrow scanner." + raise RuntimeError(msg) + table_name, columns = self._create_capi_scan_table(value) + if query.lstrip().upper().startswith("LOAD "): + source_start = len(query) - len(query.lstrip()) + query = self._rewrite_load_from_capi_scan( + query, source_start, match.end(), table_name, columns + ) + else: + query = self._rewrite_copy_from_capi_scan( + query, + match.start(1), + match.end(1), + table_name, + columns, + ) + parameters = dict(parameters) + parameters.pop(key, None) + break + return query, parameters + def _lookup_python_object_in_frames(self, name: str) -> Any | None: frame = inspect.currentframe() if frame is None: @@ -328,8 +437,11 @@ def execute( msg = f"Parameters must be a dict; found {type(parameters)}." raise RuntimeError(msg) # noqa: TRY004 + scan_tables_before = set(self._capi_scan_tables) if isinstance(query, str): query, parameters = self._rewrite_local_scan_object(query, parameters) + query, parameters = self._rewrite_capi_python_scan(query, parameters) + scan_tables_to_drop = self._capi_scan_tables - scan_tables_before if ( not self._using_pybind_backend() @@ -372,6 +484,17 @@ def execute( ) if not query_result_internal.isSuccess(): raise RuntimeError(query_result_internal.getErrorMessage()) + for table_name in scan_tables_to_drop: + try: + drop_result = self._connection.drop_arrow_table(table_name) + if not drop_result.isSuccess(): + warnings.warn( + drop_result.getErrorMessage(), + RuntimeWarning, + stacklevel=2, + ) + finally: + self._capi_scan_tables.discard(table_name) current_query_result = QueryResult(self, query_result_internal) self._register_query_result(current_query_result) if not query_result_internal.hasNextQueryResult(): diff --git a/test/capi_xfails.py b/test/capi_xfails.py index b04b6f5..58b83ec 100644 --- a/test/capi_xfails.py +++ b/test/capi_xfails.py @@ -2,16 +2,7 @@ CAPI_XFAILS = frozenset( { - # Arrow memory-backed table APIs are pybind-only today. - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_basic", - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_filtering", - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pandas", - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pyarrow", - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_empty_result", - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_count", - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_arrow_node_and_rel_table", - "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_native_node_and_arrow_rel_table", - # Scanning from Python-owned DataFrame/Arrow/Polars objects is still pybind-only. + # Some Python-owned DataFrame/Polars scan cases still need pybind-compatible conversion. "test/test_async_connection.py::test_async_scan_df", "test/test_scan_pandas.py::test_scan_pandas", "test/test_scan_pandas.py::test_scan_pandas_timestamp", @@ -69,18 +60,6 @@ "test/test_scan_polars.py::test_scan_from_parameterized_df_docs_example_1", "test/test_scan_polars.py::test_scan_from_parameterized_df_docs_example_2", "test/test_scan_polars.py::test_scan_from_df_docs_example", - "test/test_scan_pyarrow.py::test_create_arrow_table_keeps_pyarrow_memory_alive", - "test/test_scan_pyarrow.py::test_pyarrow_basic", - "test/test_scan_pyarrow.py::test_pyarrow_copy_from_parameterized_df", - "test/test_scan_pyarrow.py::test_create_arrow_table_from_pyarrow_table", - "test/test_scan_pyarrow.py::test_pyarrow_to_filtered_pyarrow_table", - "test/test_scan_pyarrow.py::test_pyarrow_copy_from_invalid_source", - "test/test_scan_pyarrow.py::test_pyarrow_copy_from", - "test/test_scan_pyarrow.py::test_pyarrow_scan_ignore_errors", - "test/test_scan_pyarrow.py::test_pyarrow_scan_invalid_option", - "test/test_scan_pyarrow.py::test_copy_from_pyarrow_multi_pairs", - "test/test_scan_pyarrow.py::test_create_arrow_rel_table_from_pyarrow_table_query_results", - "test/test_scan_pyarrow.py::test_arrow_node_and_arrow_rel_with_filtering_query", # UDF registration is still routed through pybind. "test/test_blob_parameter.py::test_bytes_param_udf", "test/test_udf.py::test_udf",