Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#pragma once

#include <memory>
#include <unordered_map>

#include "main/storage_driver.h"
#include "py_database.h"
#include "py_handle_state.h"
#include "py_prepared_statement.h"
#include "py_query_result.h"

Expand All @@ -20,7 +22,7 @@ class PyConnection {

void close();

~PyConnection() = default;
~PyConnection();

void setQueryTimeout(uint64_t timeoutInMS);
void interrupt();
Expand All @@ -29,8 +31,7 @@ class PyConnection {
const py::dict& params);

std::unique_ptr<PyQueryResult> query(const std::string& statement);
std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement,
int64_t chunkSize);
std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement, int64_t chunkSize);

void setMaxNumThreadForExec(uint64_t numThreads);

Expand Down Expand Up @@ -65,10 +66,10 @@ class PyConnection {
const LogicalType& type);

private:
std::unique_ptr<StorageDriver> storageDriver;
std::unique_ptr<Connection> conn;
std::unordered_map<std::string, py::object> arrowTableRefs;
PyConnectionState& refState() const;

std::shared_ptr<PyConnectionState> state;

static std::unique_ptr<PyQueryResult> checkAndWrapQueryResult(
std::unique_ptr<QueryResult>& queryResult);
std::unique_ptr<QueryResult>& queryResult, std::shared_ptr<PyConnectionState> state);
};
6 changes: 4 additions & 2 deletions src_cpp/include/py_database.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#pragma once

#include <memory>

#include "main/lbug.h"
#include "main/storage_driver.h"
#include "py_handle_state.h"
#include "pybind_include.h" // IWYU pragma: keep (used for py:: namespace)
#define PYBIND11_DETAILED_ERROR_MESSAGES
using namespace lbug::main;
Expand Down Expand Up @@ -30,6 +33,5 @@ class PyDatabase {
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads);

private:
std::unique_ptr<Database> database;
std::unique_ptr<StorageDriver> storageDriver;
std::shared_ptr<PyDatabaseState> state;
};
94 changes: 94 additions & 0 deletions src_cpp/include/py_handle_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#pragma once

#include <memory>
#include <string>
#include <unordered_map>

#include "common/exception/runtime.h"
#include "main/lbug.h"
#include "main/prepared_statement.h"
#include "main/storage_driver.h"
#include "pybind_include.h"

struct PyDatabaseState {
std::unique_ptr<lbug::main::Database> database;
std::unique_ptr<lbug::main::StorageDriver> storageDriver;

~PyDatabaseState() { closeNative(); }

void closeNative() {
storageDriver.reset();
database.reset();
}

lbug::main::Database& ref() const {
if (database == nullptr) {
throw lbug::common::RuntimeException("Database is closed.");
}
return *database;
}

lbug::main::StorageDriver& storage() const {
if (storageDriver == nullptr) {
throw lbug::common::RuntimeException("Database is closed.");
}
return *storageDriver;
}
};

struct PyConnectionState {
std::shared_ptr<PyDatabaseState> database;
std::unique_ptr<lbug::main::StorageDriver> storageDriver;
std::unique_ptr<lbug::main::Connection> conn;
std::unordered_map<std::string, py::object> arrowTableRefs;

~PyConnectionState() { closeNative(); }

void closeNative() {
arrowTableRefs.clear();
conn.reset();
storageDriver.reset();
database.reset();
}

lbug::main::Connection& ref() const {
if (conn == nullptr) {
throw lbug::common::RuntimeException("Connection is closed.");
}
return *conn;
}

lbug::main::StorageDriver& storage() const {
if (storageDriver == nullptr) {
throw lbug::common::RuntimeException("Connection is closed.");
}
return *storageDriver;
}
};

struct PyPreparedStatementState {
std::shared_ptr<PyConnectionState> connection;
std::unique_ptr<lbug::main::PreparedStatement> preparedStatement;

lbug::main::PreparedStatement& ref() const {
if (preparedStatement == nullptr) {
throw lbug::common::RuntimeException("Prepared statement is closed.");
}
return *preparedStatement;
}
};

struct PyQueryResultState {
std::shared_ptr<PyConnectionState> connection;
std::shared_ptr<PyQueryResultState> parent;
std::unique_ptr<lbug::main::QueryResult> owned;
lbug::main::QueryResult* borrowed = nullptr;

lbug::main::QueryResult& ref() const {
auto* result = owned != nullptr ? owned.get() : borrowed;
if (result == nullptr) {
throw lbug::common::RuntimeException("Query result is closed.");
}
return *result;
}
};
3 changes: 2 additions & 1 deletion src_cpp/include/py_prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "main/lbug.h"
#include "main/prepared_statement.h"
#include "py_handle_state.h"
#include "pybind_include.h"

using namespace lbug::main;
Expand All @@ -17,5 +18,5 @@ class PyPreparedStatement {
bool isSuccess() const;

private:
std::unique_ptr<PreparedStatement> preparedStatement;
std::shared_ptr<PyPreparedStatementState> state;
};
6 changes: 4 additions & 2 deletions src_cpp/include/py_query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "arrow_array.h"
#include "common/arrow/arrow.h"
#include "main/lbug.h"
#include "py_handle_state.h"
#include "pybind_include.h"

using namespace lbug::main;
Expand Down Expand Up @@ -54,6 +55,8 @@ class PyQueryResult {
size_t getNumTuples();

private:
PyQueryResultState& refState() const;

static py::dict convertNodeIdToPyDict(const lbug::common::nodeID_t& nodeId);

void getNextArrowChunk(const std::vector<lbug::common::LogicalType>& types,
Expand All @@ -63,6 +66,5 @@ class PyQueryResult {
const std::vector<std::string>& names, std::int64_t chunkSize, bool fallbackExtensionTypes);

private:
QueryResult* queryResult = nullptr;
bool isOwned = false;
std::shared_ptr<PyQueryResultState> state;
};
Loading
Loading