Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ arrow = { workspace = true, features = ["pyarrow"] }
datafusion = { workspace = true }
datafusion-ffi = { workspace = true }
paimon = { path = "../../crates/paimon", features = ["storage-all"] }
paimon-datafusion = { path = "../../crates/integrations/datafusion" }
paimon-datafusion = { path = "../../crates/integrations/datafusion", features = ["fulltext"] }
pyo3 = { version = "0.28", features = ["abi3-py310"] }
tokio = { workspace = true }
36 changes: 35 additions & 1 deletion bindings/python/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use datafusion::catalog::CatalogProvider;
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use paimon::{CatalogFactory, Options};
use paimon_datafusion::{PaimonCatalogProvider, SQLContext};
use paimon_datafusion::{
register_full_text_search, register_vector_search, PaimonCatalogProvider, SQLContext,
};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

Expand Down Expand Up @@ -148,6 +151,37 @@ impl PySQLContext {
.map_err(df_to_py_err)
}

/// Installs one of Paimon's built-in table-valued functions (UDTFs) on this
/// session, making it callable from SQL, for example
/// `SELECT * FROM vector_search('items', 'embedding', '[1.0, 0.0]', 10)`.
///
/// * `name` - which function to install: `vector_search` or
///   `full_text_search`. Any other value is rejected with a `ValueError`.
/// * `default_database` - database used to resolve the table-name argument
///   the function receives in SQL; falls back to `"default"` when omitted.
///
/// The function is bound to the session's current catalog, so a catalog must
/// already be registered (the first `register_catalog` call also makes it the
/// current one). The catalog lookup happens before `name` is validated, so a
/// missing catalog is reported first.
#[pyo3(signature = (name, default_database=None))]
fn register_table_function(
    &self,
    name: String,
    default_database: Option<String>,
) -> PyResult<()> {
    let bound_catalog = self.inner.current_catalog().map_err(df_to_py_err)?;
    let database = match default_database.as_deref() {
        Some(explicit) => explicit,
        None => "default",
    };
    let session = self.inner.ctx();

    if name == "vector_search" {
        register_vector_search(session, bound_catalog, database);
    } else if name == "full_text_search" {
        register_full_text_search(session, bound_catalog, database);
    } else {
        let message = format!(
            "unknown table function '{name}'; \
             supported: 'vector_search', 'full_text_search'"
        );
        return Err(PyValueError::new_err(message));
    }

    Ok(())
}

fn sql(&self, py: Python<'_>, sql: String) -> PyResult<Vec<Py<PyAny>>> {
let rt = runtime();
let batches = rt.block_on(async {
Expand Down
49 changes: 49 additions & 0 deletions bindings/python/tests/test_datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,52 @@ def test_register_batch_invalid_catalog():
assert False, "Expected an error for unknown catalog"
except Exception as e:
assert "unknown_catalog" in str(e).lower() or "not a paimon" in str(e).lower() or "unknown" in str(e).lower()


def test_register_table_function_vector_search():
with tempfile.TemporaryDirectory() as warehouse:
ctx = SQLContext()
ctx.register_catalog("paimon", {"warehouse": warehouse})

# Registering against the current catalog should not raise.
ctx.register_table_function("vector_search")


def test_register_table_function_full_text_search():
with tempfile.TemporaryDirectory() as warehouse:
ctx = SQLContext()
ctx.register_catalog("paimon", {"warehouse": warehouse})

ctx.register_table_function("full_text_search")


def test_register_table_function_with_default_database():
with tempfile.TemporaryDirectory() as warehouse:
ctx = SQLContext()
ctx.register_catalog("paimon", {"warehouse": warehouse})

# The optional default_database keyword is accepted.
ctx.register_table_function("vector_search", default_database="default")


def test_register_table_function_unknown_name():
with tempfile.TemporaryDirectory() as warehouse:
ctx = SQLContext()
ctx.register_catalog("paimon", {"warehouse": warehouse})

try:
ctx.register_table_function("does_not_exist")
assert False, "Expected an error for an unknown table function"
except Exception as e:
assert "unknown table function" in str(e).lower()
assert "does_not_exist" in str(e)


def test_register_table_function_without_catalog():
# With no catalog registered there is no current catalog to bind to.
ctx = SQLContext()
try:
ctx.register_table_function("vector_search")
assert False, "Expected an error when no catalog is registered"
except Exception as e:
assert "catalog" in str(e).lower()
7 changes: 6 additions & 1 deletion crates/integrations/datafusion/src/sql_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,12 @@ impl SQLContext {
.clone()
}

fn current_catalog(&self) -> DFResult<Arc<dyn Catalog>> {
/// Returns the Paimon catalog currently set as default.
///
/// Exposed so callers that need the registered [`Catalog`] (for example to
/// register a table-valued function against it) can retrieve it without
/// keeping a duplicate handle of their own.
pub fn current_catalog(&self) -> DFResult<Arc<dyn Catalog>> {
let name = self.current_catalog_name();
self.catalogs.get(&name).cloned().ok_or_else(|| {
DataFusionError::Plan(
Expand Down
Loading