Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 126 additions & 9 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import ast
import atexit
import ctypes
import ctypes.util
import datetime as dt
import os
import sys
import threading
import uuid
import weakref
from dataclasses import dataclass
from decimal import Decimal
from pathlib import Path
Expand Down Expand Up @@ -186,6 +188,24 @@ def _resolve_library_path() -> str:

_dlopen_mode = getattr(ctypes, "RTLD_GLOBAL", 0) | getattr(ctypes, "RTLD_NOW", 0)
_LIB = ctypes.CDLL(_resolve_library_path(), mode=_dlopen_mode)
_CAPI_DATABASES: weakref.WeakSet[Any] = weakref.WeakSet()
_CAPI_CONNECTIONS: weakref.WeakSet[Any] = weakref.WeakSet()
_ARROW_ATEXIT_REGISTERED = False


def _close_capi_connections() -> None:
    """Close every live C-API connection, then every live database.

    Connections are torn down first so no connection outlives the
    database it points at. The weak sets are snapshotted into tuples
    because ``close()`` removes entries while we iterate.
    """
    live_connections = tuple(_CAPI_CONNECTIONS)
    for conn in live_connections:
        conn.close()
    live_databases = tuple(_CAPI_DATABASES)
    for db in live_databases:
        db.close()


def _ensure_arrow_atexit_cleanup() -> None:
    """Register the C-API cleanup hook with :mod:`atexit`, at most once."""
    global _ARROW_ATEXIT_REGISTERED
    if _ARROW_ATEXIT_REGISTERED:
        return
    atexit.register(_close_capi_connections)
    _ARROW_ATEXIT_REGISTERED = True


_LBUG_SUCCESS = 0

Expand Down Expand Up @@ -288,6 +308,35 @@ def _setup_signatures() -> None:
]
_LIB.lbug_connection_execute.restype = ctypes.c_int

_LIB.lbug_connection_create_arrow_table.argtypes = [
ctypes.POINTER(_LbugConnection),
ctypes.c_char_p,
ctypes.POINTER(_ArrowSchema),
ctypes.POINTER(_ArrowArray),
ctypes.c_uint64,
ctypes.POINTER(_LbugQueryResult),
]
_LIB.lbug_connection_create_arrow_table.restype = ctypes.c_int

_LIB.lbug_connection_create_arrow_rel_table.argtypes = [
ctypes.POINTER(_LbugConnection),
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.POINTER(_ArrowSchema),
ctypes.POINTER(_ArrowArray),
ctypes.c_uint64,
ctypes.POINTER(_LbugQueryResult),
]
_LIB.lbug_connection_create_arrow_rel_table.restype = ctypes.c_int

_LIB.lbug_connection_drop_arrow_table.argtypes = [
ctypes.POINTER(_LbugConnection),
ctypes.c_char_p,
ctypes.POINTER(_LbugQueryResult),
]
_LIB.lbug_connection_drop_arrow_table.restype = ctypes.c_int

_LIB.lbug_prepared_statement_destroy.argtypes = [
ctypes.POINTER(_LbugPreparedStatement)
]
Expand Down Expand Up @@ -1065,13 +1114,15 @@ def __init__(
database_path.encode("utf-8"), config, ctypes.byref(self._database)
)
_check_state(state, "Failed to initialize database")
_CAPI_DATABASES.add(self)

def close(self) -> None:
    """Destroy the native database handle; safe to call repeatedly.

    After the first call the handle pointer is nulled, so subsequent
    calls only re-discard this object from the module registry.
    """
    if self._database._database:
        lib = _LIB
        if lib is not None:
            # Release the native database through the C API.
            lib.lbug_database_destroy(ctypes.byref(self._database))
        # Null the pointer even when the library reference is gone, so a
        # second close() can never attempt a double destroy.
        self._database._database = None
    _CAPI_DATABASES.discard(self)

@staticmethod
def get_version() -> str:
Expand Down Expand Up @@ -2014,6 +2065,7 @@ def __init__(self, database: Database, num_threads: int = 0):
),
"Failed to initialize connection",
)
_CAPI_CONNECTIONS.add(self)
if num_threads > 0:
self.set_max_threads_for_exec(num_threads)

Expand All @@ -2023,6 +2075,7 @@ def close(self) -> None:
if lib is not None:
lib.lbug_connection_destroy(ctypes.byref(self._connection))
self._connection._connection = None
_CAPI_CONNECTIONS.discard(self)

def set_max_threads_for_exec(self, num_threads: int) -> None:
_check_state(
Expand Down Expand Up @@ -2119,17 +2172,81 @@ def create_function(self, *_args: Any, **_kwargs: Any) -> None:
def remove_function(self, *_args: Any, **_kwargs: Any) -> None:
raise NotImplementedError("UDF removal is not yet implemented in C-API backend")

def create_arrow_table(self, *_args: Any, **_kwargs: Any) -> Any:
raise NotImplementedError(
"Arrow memory table APIs are not yet implemented in C-API backend"
@staticmethod
def _as_arrow_table(dataframe: Any) -> Any:
    """Coerce *dataframe* into a :class:`pyarrow.Table`.

    Accepts a pandas DataFrame, a polars DataFrame, or a pyarrow Table.
    Detection compares the object's *root* package instead of a string
    prefix, so unrelated packages whose names merely start with
    "pandas"/"polars"/"pyarrow" (e.g. ``pandasql``) are not mistaken
    for the real libraries.

    Raises:
        RuntimeError: if *dataframe* is none of the supported types.
    """
    import pyarrow as pa

    _ensure_arrow_atexit_cleanup()
    # Root package of the object's type, e.g. "pandas" for
    # "pandas.core.frame.DataFrame".
    root_package = type(dataframe).__module__.partition(".")[0]
    if root_package == "pandas":
        return pa.Table.from_pandas(dataframe)
    if root_package == "polars":
        return dataframe.to_arrow()
    # isinstance also accepts pyarrow.Table subclasses, which the old
    # class-name check rejected.
    if isinstance(dataframe, pa.Table):
        return dataframe
    msg = "Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"
    raise RuntimeError(msg)

@staticmethod
def _export_arrow_table(dataframe: Any) -> tuple[Any, _ArrowSchema, Any, Any]:
    """Export *dataframe* through the Arrow C data interface.

    Returns ``(table, schema, arrays, batches)``: the coerced pyarrow
    table, a populated ``_ArrowSchema`` struct, a ctypes array of
    ``_ArrowArray`` structs (one per record batch), and the Python-side
    record batches. The caller must keep all four referenced while the
    C library reads the exported structs — dropping ``table`` or
    ``batches`` early could free buffers the C side still points into.
    """
    table = Connection._as_arrow_table(dataframe)
    schema = _ArrowSchema()
    # pyarrow fills the C struct in place at this address.
    table.schema._export_to_c(ctypes.addressof(schema))
    batches = table.to_batches()
    # One contiguous ctypes array with an ArrowArray slot per batch.
    array_type = _ArrowArray * len(batches)
    arrays = array_type()
    for idx, batch in enumerate(batches):
        batch._export_to_c(ctypes.addressof(arrays[idx]))
    return table, schema, arrays, batches

def create_arrow_table(self, table_name: str, dataframe: Any) -> QueryResult:
    """Create a table named *table_name* backed by *dataframe*.

    *dataframe* may be a pyarrow Table, polars DataFrame, or pandas
    DataFrame. Returns the query result produced by the C library.
    """
    # Keep every exported Arrow object referenced until the call returns.
    exported = self._export_arrow_table(dataframe)
    _table, schema, arrays, _batches = exported
    raw_result = _LbugQueryResult()
    status = _LIB.lbug_connection_create_arrow_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        ctypes.byref(schema),
        arrays,
        len(arrays),
        ctypes.byref(raw_result),
    )
    # A populated result carries its own error details; only raise when
    # the call failed without producing one.
    failed_without_result = status != _LBUG_SUCCESS and not raw_result._query_result
    if failed_without_result:
        _check_state(status, "Failed to create Arrow table")
    return QueryResult(raw_result)

def drop_arrow_table(self, *_args: Any, **_kwargs: Any) -> Any:
raise NotImplementedError(
"Arrow memory table APIs are not yet implemented in C-API backend"
def drop_arrow_table(self, table_name: str) -> QueryResult:
    """Drop the Arrow-backed table named *table_name*.

    Returns the query result produced by the C library.
    """
    raw_result = _LbugQueryResult()
    status = _LIB.lbug_connection_drop_arrow_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        ctypes.byref(raw_result),
    )
    # Only raise when the call failed and left no result to inspect.
    failed_without_result = status != _LBUG_SUCCESS and not raw_result._query_result
    if failed_without_result:
        _check_state(status, "Failed to drop Arrow table")
    return QueryResult(raw_result)

def create_arrow_rel_table(self, *_args: Any, **_kwargs: Any) -> Any:
raise NotImplementedError(
"Arrow memory table APIs are not yet implemented in C-API backend"
def create_arrow_rel_table(
    self,
    table_name: str,
    dataframe: Any,
    src_table_name: str,
    dst_table_name: str,
) -> QueryResult:
    """Create a relationship table *table_name* backed by *dataframe*.

    *src_table_name* and *dst_table_name* name the node tables the
    relationship connects. Returns the query result produced by the C
    library.
    """
    # Keep the exported Arrow objects alive for the duration of the call.
    exported = self._export_arrow_table(dataframe)
    _table, schema, arrays, _batches = exported
    raw_result = _LbugQueryResult()
    status = _LIB.lbug_connection_create_arrow_rel_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        src_table_name.encode("utf-8"),
        dst_table_name.encode("utf-8"),
        ctypes.byref(schema),
        arrays,
        len(arrays),
        ctypes.byref(raw_result),
    )
    # Only raise when the call failed without producing a result object.
    failed_without_result = status != _LBUG_SUCCESS and not raw_result._query_result
    if failed_without_result:
        _check_state(status, "Failed to create Arrow relationship table")
    return QueryResult(raw_result)
123 changes: 123 additions & 0 deletions src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import json
import re
import uuid
import warnings
from typing import TYPE_CHECKING, Any
from weakref import WeakSet
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, database: Database, num_threads: int = 0):
self._prefer_pybind = False
self._query_timeout_ms = 0
self._query_results: WeakSet[QueryResult] = WeakSet()
self._capi_scan_tables: set[str] = set()
self.database._register_connection(self)
self.init_connection()

Expand Down Expand Up @@ -174,6 +176,113 @@ def _has_scan_pattern(self, query: str) -> bool:
return False
return re.search(r"(?i)\bFROM\b", query) is not None

@staticmethod
def _quote_identifier(identifier: str) -> str:
escaped = identifier.replace("`", "``")
return f"`{escaped}`"

def _arrow_table_column_names(self, value: Any) -> list[str]:
    """Return the column names *value* exposes once coerced to Arrow."""
    capi_connection_cls = get_capi_module().Connection
    arrow_table = capi_connection_cls._as_arrow_table(value)
    return [schema_field.name for schema_field in arrow_table.schema]

def _create_capi_scan_table(self, value: Any) -> tuple[str, list[str]]:
table_name = f"__lbug_capi_scan_{uuid.uuid4().hex}"
self._connection.create_arrow_table(table_name, value)
self._capi_scan_tables.add(table_name)
return table_name, self._arrow_table_column_names(value)

def _replace_column_refs(self, text: str, columns: list[str], alias: str) -> str:
result = text
for column in sorted(columns, key=len, reverse=True):
quoted = self._quote_identifier(column)
result = re.sub(
rf"(?<![\w.`]){re.escape(column)}(?![\w`])",
f"{alias}.{quoted}",
result,
)
return result

def _rewrite_load_from_capi_scan(
self,
query: str,
source_start: int,
source_end: int,
table_name: str,
columns: list[str],
) -> str:
alias = "_scan"
match_prefix = f"MATCH ({alias}:{self._quote_identifier(table_name)})"
rest = query[source_end:]
return_star = ", ".join(
f"{alias}.{self._quote_identifier(column)} AS {self._quote_identifier(column)}"
for column in columns
)
return_match = re.search(r"(?i)\bRETURN\s+\*", rest)
if return_match is not None:
rest = (
rest[: return_match.start()]
+ f"RETURN {return_star}"
+ rest[return_match.end() :]
)
rest = self._replace_column_refs(rest, columns, alias)
return query[:source_start] + match_prefix + rest

def _rewrite_copy_from_capi_scan(
self,
query: str,
source_start: int,
source_end: int,
table_name: str,
columns: list[str],
) -> str:
alias = "_scan"
return_cols = ", ".join(
f"{alias}.{self._quote_identifier(column)}" for column in columns
)
replacement = f"(MATCH ({alias}:{self._quote_identifier(table_name)}) RETURN {return_cols})"
return query[:source_start] + replacement + query[source_end:]

def _rewrite_capi_python_scan(
    self,
    query: str,
    parameters: dict[str, Any],
) -> tuple[str, dict[str, Any]]:
    """Rewrite a query that scans a Python object on the C-API backend.

    When *query* contains ``FROM $param`` and *param* is bound to a
    scannable Python object, the object is materialized as a temporary
    C-API table and the query is rewritten to scan that table instead.
    Returns the (possibly rewritten) query and a parameter dict with
    the consumed key removed; at most one parameter is rewritten per
    call.

    Raises:
        RuntimeError: if the scan options contain ``INVALID_OPTION``.
    """
    # The pybind backend scans Python objects natively; no rewrite needed.
    if self._using_pybind_backend() or not self._has_scan_pattern(query):
        return query, parameters
    # Creating the scan table requires a write; skip in read-only mode.
    if self.database.read_only:
        return query, parameters

    for key, value in list(parameters.items()):
        if not isinstance(key, str) or not self._is_python_scan_object(value):
            continue
        # Locate the "$key" reference used as a FROM source.
        match = re.search(rf"(?i)\bFROM\s+(\${re.escape(key)})\b", query)
        if match is None:
            continue
        # Scanner options appear in parentheses right after the source;
        # reject the INVALID_OPTION sentinel with the pyArrow-style error.
        options_match = re.match(r"\s*\((.*?)\)", query[match.end() :], re.DOTALL)
        if options_match is not None and re.search(
            r"(?i)\bINVALID_OPTION\b", options_match.group(1)
        ):
            msg = "INVALID_OPTION Option not recognized by pyArrow scanner."
            raise RuntimeError(msg)
        table_name, columns = self._create_capi_scan_table(value)
        if query.lstrip().upper().startswith("LOAD "):
            # LOAD FROM: rewrite from the statement's first non-blank
            # character through the "$key" reference.
            source_start = len(query) - len(query.lstrip())
            query = self._rewrite_load_from_capi_scan(
                query, source_start, match.end(), table_name, columns
            )
        else:
            # COPY ... FROM: swap only the "$key" span for a subquery.
            query = self._rewrite_copy_from_capi_scan(
                query,
                match.start(1),
                match.end(1),
                table_name,
                columns,
            )
        # Copy before mutating so the caller's dict is left untouched.
        parameters = dict(parameters)
        parameters.pop(key, None)
        break
    return query, parameters

def _lookup_python_object_in_frames(self, name: str) -> Any | None:
frame = inspect.currentframe()
if frame is None:
Expand Down Expand Up @@ -328,8 +437,11 @@ def execute(
msg = f"Parameters must be a dict; found {type(parameters)}."
raise RuntimeError(msg) # noqa: TRY004

scan_tables_before = set(self._capi_scan_tables)
if isinstance(query, str):
query, parameters = self._rewrite_local_scan_object(query, parameters)
query, parameters = self._rewrite_capi_python_scan(query, parameters)
scan_tables_to_drop = self._capi_scan_tables - scan_tables_before

if (
not self._using_pybind_backend()
Expand Down Expand Up @@ -372,6 +484,17 @@ def execute(
)
if not query_result_internal.isSuccess():
raise RuntimeError(query_result_internal.getErrorMessage())
for table_name in scan_tables_to_drop:
try:
drop_result = self._connection.drop_arrow_table(table_name)
if not drop_result.isSuccess():
warnings.warn(
drop_result.getErrorMessage(),
RuntimeWarning,
stacklevel=2,
)
finally:
self._capi_scan_tables.discard(table_name)
current_query_result = QueryResult(self, query_result_internal)
self._register_query_result(current_query_result)
if not query_result_internal.hasNextQueryResult():
Expand Down
23 changes: 1 addition & 22 deletions test/capi_xfails.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@

CAPI_XFAILS = frozenset(
{
# Arrow memory-backed table APIs are pybind-only today.
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_basic",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_filtering",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pandas",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pyarrow",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_empty_result",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_count",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_arrow_node_and_rel_table",
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_native_node_and_arrow_rel_table",
# Scanning from Python-owned DataFrame/Arrow/Polars objects is still pybind-only.
# Some Python-owned DataFrame/Polars scan cases still need pybind-compatible conversion.
"test/test_async_connection.py::test_async_scan_df",
"test/test_scan_pandas.py::test_scan_pandas",
"test/test_scan_pandas.py::test_scan_pandas_timestamp",
Expand Down Expand Up @@ -69,18 +60,6 @@
"test/test_scan_polars.py::test_scan_from_parameterized_df_docs_example_1",
"test/test_scan_polars.py::test_scan_from_parameterized_df_docs_example_2",
"test/test_scan_polars.py::test_scan_from_df_docs_example",
"test/test_scan_pyarrow.py::test_create_arrow_table_keeps_pyarrow_memory_alive",
"test/test_scan_pyarrow.py::test_pyarrow_basic",
"test/test_scan_pyarrow.py::test_pyarrow_copy_from_parameterized_df",
"test/test_scan_pyarrow.py::test_create_arrow_table_from_pyarrow_table",
"test/test_scan_pyarrow.py::test_pyarrow_to_filtered_pyarrow_table",
"test/test_scan_pyarrow.py::test_pyarrow_copy_from_invalid_source",
"test/test_scan_pyarrow.py::test_pyarrow_copy_from",
"test/test_scan_pyarrow.py::test_pyarrow_scan_ignore_errors",
"test/test_scan_pyarrow.py::test_pyarrow_scan_invalid_option",
"test/test_scan_pyarrow.py::test_copy_from_pyarrow_multi_pairs",
"test/test_scan_pyarrow.py::test_create_arrow_rel_table_from_pyarrow_table_query_results",
"test/test_scan_pyarrow.py::test_arrow_node_and_arrow_rel_with_filtering_query",
# UDF registration is still routed through pybind.
"test/test_blob_parameter.py::test_bytes_param_udf",
"test/test_udf.py::test_udf",
Expand Down
Loading