From 2db8db87211dc6354f4e45759f24e7803cbabbb8 Mon Sep 17 00:00:00 2001 From: Andrew de Waal Date: Mon, 15 Jun 2026 07:18:44 -0700 Subject: [PATCH 1/2] feat: register databases by key and validate paths at setup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace path-as-identifier IPC with explicit key registration so the frontend and Rust callers open databases by stable keys (e.g. "MAIN") instead of filesystem paths. Paths are resolved once in Rust during plugin setup, validated, and cached; later opens only look up the key. Problems solved: - Security: untrusted frontend can no longer supply arbitrary paths over IPC; unknown keys fail with PATH_NOT_REGISTERED. - Cross-language identity: avoids TS/Rust path string mismatches (separators, symlinks, canonicalization) on every load. - Runtime path discovery: `on_setup` + `SetupRegistrar` register paths from `app.path()` / platform resolvers once, without repeat JNI or resolver work on each open. - Consistent open path: `load`, `Connection::connect`, and migrations share `connect_to_database` and the same `DbInstances` cache, along with assurance that migrations have completed before connecting to the database. - Safer registration: `validate_database_path` rejects relative paths, traversal, null bytes, and canonicalizes file paths at startup (fail-fast INVALID_PATH / PATH_TRAVERSAL). API changes: - `add_migrations(path, migrator)` → `register_database(key, path, migrator?)` (returns Result; adds `Builder::on_setup` / `SetupRegistrar`) - IPC/command args: `db` → `dbKey`; attached `databasePath` → `databaseKey` - `MigrationEvent`: adds `dbKey`, `dbPath` is now absolute PathBuf - `TransactionToken`: `dbPath` → `dbKey` - `Connection` trait on `AppHandle` for Rust-side opens by key Also: expose `canonicalize_database_path` from conn-mgr, tighten `is_memory_database` query-param matching, add toolkit `:memory:` tests, update README/guest-js rustdoc, and add `validate.rs` (replaces load-time path resolution in `resolve.rs`). --- CHANGELOG.md | 40 + Cargo.toml | 1 + README.md | 243 +++++-- api-iife.js | 2 +- crates/sqlx-sqlite-conn-mgr/README.md | 3 +- crates/sqlx-sqlite-conn-mgr/src/config.rs | 1 + crates/sqlx-sqlite-conn-mgr/src/lib.rs | 2 + crates/sqlx-sqlite-conn-mgr/src/registry.rs | 45 +- .../sqlx-sqlite-toolkit/tests/memory_tests.rs | 81 +++ guest-js/index.test.ts | 196 +++-- guest-js/index.ts | 164 +++-- package.json | 3 +- src/commands.rs | 204 ++---- src/error.rs | 23 + src/lib.rs | 684 ++++++++++++++++-- src/resolve.rs | 221 ------ src/subscriptions.rs | 16 +- src/validate.rs | 171 +++++ 18 files changed, 1468 insertions(+), 632 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 crates/sqlx-sqlite-toolkit/tests/memory_tests.rs delete mode 100644 src/resolve.rs create mode 100644 src/validate.rs diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..330fe9c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,40 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +## [0.2.0] - Unreleased + +### Breaking Changes + +#### Database registration by key + +Databases must be registered on the Rust side with a stable key before they can be opened. The frontend and Rust callers open databases by **key**, not filesystem path. + +- **`Builder::add_migrations(path, migrator)`** replaced by **`Builder::register_database(key, path, migrator?)`**, which returns `Result`. +- Added **`Builder::on_setup`** and **`SetupRegistrar`** for runtime path registration (e.g. from `app.path().app_data_dir()`). +- **`Builder`** is now generic over Tauri `Runtime`; use `Builder::::new()` where a turbofish is required. + +#### Frontend API + +- First argument to **`Database.load()`** and **`Database.get()`** is a registration **key**, not a path. Passing a path string type-checks but fails at runtime with `PATH_NOT_REGISTERED`. +- **`Database.get(dbKey)`** is synchronous again; it defers connection until the first operation that requires a loaded database. Use **`Database.load(dbKey, customConfig?)`** to connect eagerly or pass pool configuration. +- IPC command arguments: **`db`** → **`dbKey`**; attached **`databasePath`** → **`databaseKey`**. +- **`MigrationEvent`**: adds **`dbKey`**; **`dbPath`** is now an absolute path. +- **`TransactionToken`**: **`dbPath`** → **`dbKey`**. + +#### Rust API + +- Added **`Connection`** trait on **`AppHandle`** for Rust-side opens by registration key. +- **`Builder::build()`** returns **`Result`**; duplicate paths across distinct registration keys fail with **`INVALID_CONFIG`**. +- File paths must be **absolute** at registration; relative path resolution at load time was removed. +- Invalid registration paths fail at startup (`INVALID_PATH`, `PATH_TRAVERSAL`); unregistered keys fail at open time (`PATH_NOT_REGISTERED`). + +### Added + +- Path validation module (`validate.rs`) with canonicalization at registration time. +- Parent directory auto-creation during registration validation for file paths. +- CI check that committed `api-iife.js` matches a fresh Rollup build. + +### Fixed + +- Regenerated `api-iife.js` so all IPC calls use `dbKey` consistently. diff --git a/Cargo.toml b/Cargo.toml index 7271450..a5864fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,5 +47,6 @@ futures = "0.3.31" tauri-plugin = { version = "2.5.1", features = ["build"] } [dev-dependencies] +tauri = { version = "2.9.3", features = ["test"] } tempfile = "3.23.0" tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros"] } diff --git a/README.md b/README.md index dc3aa6b..5c013bc 100644 --- a/README.md +++ b/README.md @@ -119,12 +119,97 @@ Register the plugin in your Tauri application: ```rust fn main() { tauri::Builder::default() - .plugin(tauri_plugin_sqlite::Builder::new().build()) + .plugin(tauri_plugin_sqlite::Builder::new().build().expect("failed to build sqlite plugin")) .run(tauri::generate_context!()) .expect("error while running tauri application"); } ``` +### Registering Databases + +Every database must be **registered** on the Rust side before it can be opened. +Registration assigns a stable **key** (for example `"MAIN"`) to a filesystem path or +in-memory URI. The frontend and Rust callers open databases by **key**, not by path. + +**Registration rules** (enforced when you call `register_database`): + + * File paths must be **absolute**, with no `..` components or null bytes + * File paths are **canonicalized** once at registration (symlink-safe when the path or + parent exists) + * In-memory URIs (`:memory:`, `file::memory:*`, and `file:` URIs with an exact + `mode=memory` query parameter) are accepted as-is + * Each registration **key** must map to a **distinct** database path; two keys for the + same path fail with `INVALID_CONFIG` when calling `build()` or during plugin setup + after `on_setup` merges registrations + +Invalid registration paths fail at startup with `INVALID_PATH` or `PATH_TRAVERSAL`. + +**Why keys:** + +1) `Database.load()` is callable from the frontend over IPC. The frontend sends + only a registration key; unregistered keys are rejected with `PATH_NOT_REGISTERED`. + This prevents untrusted frontend code from opening arbitrary files on disk. + +2) Keys avoid cross-language path string mismatches. With path-at-load, TS and Rust would + need identical canonical path strings on every open (slashes, symlinks, etc.). Keys + resolve the path once at registration; all later opens use the plain string key. + +3) Without registration keys, every call site would repeat that path discovery or keep its + own `PathBuf`. Registration stores the key-to-path mapping once; `connect` reuses the + key so callers do not supply a filesystem path on every open. + On mobile, path discovery is not a cheap string join. Resolvers such as + [tauri-plugin-fs-resolver](https://github.com/silvermine/tauri-plugin-fs-resolver) + call platform-native APIs so paths match OS sandbox rules. On Android that means a + JNI call into Kotlin `Context` (e.g. `getFilesDir()`) on each resolve — noticeably + more expensive than a local HashMap lookup, and a different kind of boundary than + TypeScript-to-Rust IPC (in-process JNI vs webview bridge). Register the resolved + `PathBuf` once in `on_setup`; every later `connect(database_key)` only looks up that + key in [`RegisteredDatabases`] — no repeat native or JNI work. + +Because legitimate paths usually depend on runtime values (app data directory, platform +path resolvers, etc.), registration normally happens in the `on_setup` hook: + +```rust +use tauri_plugin_sqlite::Builder; +use tauri::Manager; + +const MAIN_DB_KEY: &str = "MAIN"; + +fn main() { + tauri::Builder::default() + .plugin( + Builder::new() + .on_setup(|app, reg| { + let db = app.path().app_data_dir()?.join("main.db"); + reg.register_database(MAIN_DB_KEY, db, None)?; + Ok(()) + }) + .build() + .expect("failed to build sqlite plugin") + ) + .run(tauri::generate_context!()) + .expect("error while running tauri application"); +} +``` + +A static registration (compile-time path) can be registered on the builder directly: + +```rust +use tauri_plugin_sqlite::Builder; +use std::path::PathBuf; + +const MAIN_DB_KEY: &str = "MAIN"; + +# fn main() -> tauri_plugin_sqlite::Result<()> { +let _plugin = Builder::new() + .register_database(MAIN_DB_KEY, PathBuf::from("/var/lib/myapp/main.db"), None)? + .build()?; +# Ok(()) +# } +``` + +The frontend then calls `Database.load(MAIN_DB_KEY)` (see [Connecting](#connecting)). + ### Migrations This plugin uses [SQLx's migration system][sqlx-migrate]. Create numbered `.sql` @@ -139,23 +224,40 @@ src-tauri/migrations/ └── 0003_create_posts.sql ``` -Register migrations using SQLx's `migrate!()` macro, which embeds them at compile time: +Register migrations using SQLx's `migrate!()` macro, which embeds them at compile time. +Pass the migrator as the third argument to `register_database`. The `on_setup` hook is the +usual place to register app-derived paths: ```rust use tauri_plugin_sqlite::Builder; +use tauri::Manager; + +const MAIN_DB_KEY: &str = "MAIN"; fn main() { tauri::Builder::default() .plugin( Builder::new() - .add_migrations("main.db", sqlx::migrate!("./migrations")) + .on_setup(|app, reg| { + let db = app.path().app_data_dir()?.join("main.db"); + reg.register_database( + MAIN_DB_KEY, + db, + Some(sqlx::migrate!("./migrations")), + )?; + Ok(()) + }) .build() + .expect("failed to build sqlite plugin") ) .run(tauri::generate_context!()) .expect("error while running tauri application"); } ``` +The frontend must call `Database.load()` with the same **registration key** so migrations +are awaited correctly. + **Timing:** Migrations start automatically at plugin setup (non-blocking). When TypeScript calls `Database.load()`, it waits for migrations to complete before returning. If migrations fail, `load()` returns an error. Applied migrations are @@ -168,12 +270,15 @@ Use `getMigrationEvents()` to retrieve cached events: ```typescript import Database from '@silvermine/tauri-plugin-sqlite'; -const db = await Database.load('mydb.db'); +const MAIN_DB_KEY = 'MAIN'; + +// Same registration key used in Rust `register_database` +const db = await Database.load(MAIN_DB_KEY); // Get all migration events (including ones emitted before listener could be registered) const events = await db.getMigrationEvents(); for (const event of events) { - console.info(`${event.status}: ${event.dbPath}`); + console.info(`${event.status}: ${event.dbKey} (${event.dbPath})`); if (event.status === 'failed') { console.error(`Migration error: ${event.error}`); } @@ -188,29 +293,41 @@ import { listen } from '@tauri-apps/api/event'; import type { MigrationEvent } from '@silvermine/tauri-plugin-sqlite'; await listen('sqlite:migration', (event) => { - const { dbPath, status, migrationCount, error } = event.payload; - console.info(`Migration ${status} for ${dbPath}: ${migrationCount} migrations`, error); + const { dbKey, dbPath, status, migrationCount, error } = event.payload; + console.info(`Migration ${status} for ${dbKey} (${dbPath}): ${migrationCount} migrations`, error); }); ``` ### Connecting +Pass the **registration key** from Rust `register_database` (see +[Registering Databases](#registering-databases)). + ```typescript import Database from '@silvermine/tauri-plugin-sqlite'; -// Path is relative to app config directory (no sqlite: prefix needed) -let db = await Database.load('mydb.db'); +const MAIN_DB_KEY = 'MAIN'; + +// Connect (no sqlite: prefix needed) +let db = await Database.load(MAIN_DB_KEY); // With custom configuration -db = await Database.load('mydb.db', { +db = await Database.load(MAIN_DB_KEY, { maxReadConnections: 10, // default: 6 idleTimeoutSecs: 60 // default: 30 }); -// Lazy initialization (connects on first query) -db = Database.get('mydb.db'); +// Lazy initialization (connects on first query; sync — no await) +db = Database.get(MAIN_DB_KEY); + +// In-memory: register first, then load by key +// reg.register_database('MEM', ':memory:', None)? in on_setup +const mem = await Database.load('MEM'); ``` +An unregistered key throws `PATH_NOT_REGISTERED`. Invalid paths are rejected at +registration time on the Rust side (`INVALID_PATH`, `PATH_TRAVERSAL`). + ### Parameter Binding All query methods use `$1`, `$2`, etc. syntax with `SqlValue` types: @@ -400,16 +517,20 @@ tables. **Builder Pattern:** All query methods (`execute`, `executeTransaction`, `fetchAll`, `fetchOne`, `fetchPage`) return builders that support `.attach()` -for cross-database operations. +for cross-database operations. Each attached database must already be loaded and is +identified by its **registration key** (see +[Registering Databases](#registering-databases)). ```typescript +const ORDERS_DB_KEY = 'ORDERS'; + // Join data from multiple databases const results = await db.fetchAll( 'SELECT u.name, o.total FROM users u JOIN orders.orders o ON u.id = o.user_id', [] ).attach([ { - databasePath: 'orders.db', + databaseKey: ORDERS_DB_KEY, schemaName: 'orders', mode: 'readOnly' } @@ -422,14 +543,13 @@ await db.execute( ['archived'] ).attach([ { - databasePath: 'archive.db', + databaseKey: 'ARCHIVE', schemaName: 'archive', mode: 'readOnly' } ]); // Atomic writes across multiple databases -// Assuming userId and total are defined in your application context const userId = 123; const total = 99.99; @@ -438,7 +558,7 @@ await db.executeTransaction([ ['UPDATE stats.order_count SET count = count + 1', []] ]).attach([ { - databasePath: 'stats.db', + databaseKey: 'STATS', schemaName: 'stats', mode: 'readWrite' } @@ -536,7 +656,9 @@ Common error codes: * `SQLITE_CONSTRAINT` - Constraint violation (unique, foreign key, etc.) * `SQLITE_NOTFOUND` - Table or column not found * `DATABASE_NOT_LOADED` - Database hasn't been loaded yet - * `INVALID_PATH` - Invalid database path + * `INVALID_PATH` - Invalid path at registration (relative or failed canonicalization) + * `PATH_NOT_REGISTERED` - Registration key not found + * `PATH_TRAVERSAL` - Registration path contains `..` or null bytes * `IO_ERROR` - File system error * `MIGRATION_ERROR` - Migration failed * `MULTIPLE_ROWS_RETURNED` - `fetchOne()` returned multiple rows @@ -557,8 +679,8 @@ await db.remove(); // Close and DELETE database file(s) - irreversible | Method | Description | | ------ | ----------- | -| `Database.load(path, config?)` | Connect and return Database instance (or existing) | -| `Database.get(path)` | Get instance without connecting (lazy init) | +| `Database.load(dbKey, config?)` | Connect eagerly and return Database instance (or existing) | +| `Database.get(dbKey)` | Sync handle; connects on first query (no `customConfig`) | | `Database.close_all()` | Close all database connections | ### Instance Methods @@ -619,7 +741,7 @@ interface CustomConfig { } interface AttachedDatabaseSpec { - databasePath: string; // Path relative to app config directory + databaseKey: string; // Registration key of a database already loaded via load() schemaName: string; // Schema name for accessing tables (e.g., 'orders') mode: 'readOnly' | 'readWrite'; } @@ -672,25 +794,44 @@ type TableChangeEvent = ## Rust-Only API -For Rust code that needs direct database access without going through Tauri commands, -use `DatabaseWrapper`. - -### Setup (Rust) +For Rust code in a Tauri app, register databases first, then open by key using the +[`Connection`](src/lib.rs) trait on `AppHandle`. This uses the same open path as the +frontend `load` command (`connect_to_database`). -```rust -use tauri_plugin_sqlite::DatabaseWrapper; -use std::path::PathBuf; +For standalone Rust projects without the plugin, use `DatabaseWrapper::connect(path)` +from [`sqlx-sqlite-toolkit`](crates/sqlx-sqlite-toolkit/) directly (no registration). -// Load a database -let mut db = DatabaseWrapper::load(PathBuf::from("/path/to/mydb.db"), None).await?; +### Setup (Tauri plugin) -// With custom configuration -use tauri_plugin_sqlite::CustomConfig; -let config = CustomConfig { - max_read_connections: Some(10), - idle_timeout_secs: Some(60), -}; -db = DatabaseWrapper::load(PathBuf::from("/path/to/mydb.db"), Some(config)).await?; +```rust +use tauri::{Manager, Runtime}; +use tauri_plugin_sqlite::{Builder, Connection, SqliteDatabaseConfig}; + +const MAIN_DB_KEY: &str = "MAIN"; + +// In lib.rs setup — register key + path +Builder::new() + .on_setup(|app, reg| { + let db = app.path().app_data_dir()?.join("main.db"); + reg.register_database(MAIN_DB_KEY, db, None)?; + Ok(()) + }) + .build()?; + +// Anywhere with AppHandle +async fn example(app: tauri::AppHandle) -> tauri_plugin_sqlite::Result<()> { + let db = app.connect(MAIN_DB_KEY).await?; + + let db = app.connect_with_config( + MAIN_DB_KEY, + SqliteDatabaseConfig { + max_read_connections: 10, + idle_timeout_secs: 60, + }, + ).await?; + + Ok(()) +} ``` ### Basic Operations @@ -824,17 +965,17 @@ tx.commit().await?; ### Cross-Database Operations -Attach other databases for cross-database queries. For Rust API usage, you need to load -both databases first, then create `AttachedSpec` instances using their inner database -references: +Attach other databases for cross-database queries. Load each database by registration +key first (`app.connect("STATS").await?`), then create `AttachedSpec` instances using +their inner database references: ```rust -use tauri_plugin_sqlite::{DatabaseWrapper, AttachedSpec, AttachedMode}; +use tauri_plugin_sqlite::{Connection, AttachedSpec, AttachedMode}; use std::sync::Arc; -// Load both databases -let main_db = DatabaseWrapper::load("/path/to/main.db".into(), None).await?; -let stats_db = DatabaseWrapper::load("/path/to/stats.db".into(), None).await?; +// After registering and connecting both databases by key +let main_db = app.connect("MAIN").await?; +let stats_db = app.connect("STATS").await?; // Create attached spec using the inner database reference let stats_spec = AttachedSpec { @@ -853,8 +994,7 @@ let results = main_db.execute_transaction(vec![ println!("Cross-database transaction completed: {} statements", results.len()); // Interruptible transaction with attached database -// Load the inventory database -let inventory_db = DatabaseWrapper::load("/path/to/inventory.db".into(), None).await?; +let inventory_db = app.connect("INVENTORY").await?; // Create spec for inventory database let inv_spec = AttachedSpec { @@ -888,7 +1028,7 @@ db.remove().await?; // Close and DELETE database file(s) | Method | Description | | ------ | ----------- | -| `load(path, config?)` | Load database, returns `DatabaseWrapper` | +| `connect(path, config?)` | Open database by path, returns `DatabaseWrapper` | | `execute(query, values)` | Execute write query | | `execute_transaction(statements)` | Execute statements atomically (builder) | | `begin_interruptible_transaction()` | Begin interruptible transaction (builder) | @@ -936,7 +1076,7 @@ fn init_tracing() {} fn main() { init_tracing(); tauri::Builder::default() - .plugin(tauri_plugin_sqlite::Builder::new().build()) + .plugin(tauri_plugin_sqlite::Builder::new().build().expect("failed to build sqlite plugin")) .run(tauri::generate_context!()) .expect("error while running tauri application"); } @@ -985,9 +1125,10 @@ pagination to keep memory usage bounded on both the Rust and TypeScript sides. ### Path Validation -Database paths are validated to prevent directory traversal. Absolute paths, -`..` segments, and null bytes are rejected. All paths are resolved relative to -the app config directory. +Registration validates filesystem paths once (absolute, no traversal, canonicalized). +In-memory URIs are accepted as-is. At runtime, `Database.load(dbKey)` and +`Connection::connect(dbKey)` only accept **registered keys**; unknown keys return +`PATH_NOT_REGISTERED`. ## Development diff --git a/api-iife.js b/api-iife.js index 5fdfc7a..74ff231 100644 --- a/api-iife.js +++ b/api-iife.js @@ -1 +1 @@ -if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(t){"use strict";function e(t,e,s,i){if("function"==typeof e?t!==e||!i:!e.has(t))throw new TypeError("Cannot read private member from an object whose class did not declare it");return"m"===s?i:"a"===s?i.call(t):i?i.value:e.get(t)}function s(t,e,s,i,a){if("function"==typeof e||!e.has(t))throw new TypeError("Cannot write private member to an object whose class did not declare it");return e.set(t,s),s}var i,a,n,h;"function"==typeof SuppressedError&&SuppressedError;const r="__TAURI_TO_IPC_KEY__";class c{constructor(t){i.set(this,void 0),a.set(this,0),n.set(this,[]),h.set(this,void 0),s(this,i,t||(()=>{})),this.id=function(t,e=!1){return window.__TAURI_INTERNALS__.transformCallback(t,e)}(t=>{const r=t.index;if("end"in t)return void(r==e(this,a,"f")?this.cleanupCallback():s(this,h,r));const c=t.message;if(r==e(this,a,"f")){for(e(this,i,"f").call(this,c),s(this,a,e(this,a,"f")+1);e(this,a,"f")in e(this,n,"f");){const t=e(this,n,"f")[e(this,a,"f")];e(this,i,"f").call(this,t),delete e(this,n,"f")[e(this,a,"f")],s(this,a,e(this,a,"f")+1)}e(this,a,"f")===e(this,h,"f")&&this.cleanupCallback()}else e(this,n,"f")[r]=c})}cleanupCallback(){window.__TAURI_INTERNALS__.unregisterCallback(this.id)}set onmessage(t){s(this,i,t)}get onmessage(){return e(this,i,"f")}[(i=new WeakMap,a=new WeakMap,n=new WeakMap,h=new WeakMap,r)](){return`__CHANNEL__:${this.id}`}toJSON(){return this[r]()}}async function u(t,e={},s){return window.__TAURI_INTERNALS__.invoke(t,e,s)}class _{constructor(t,e){this._dbPath=t,this._transactionId=e}async read(t,e){return await u("plugin:sqlite|transaction_read",{token:{dbPath:this._dbPath,transactionId:this._transactionId},query:t,values:e??[]})}async continueWith(t){const e=await u("plugin:sqlite|transaction_continue",{token:{dbPath:this._dbPath,transactionId:this._transactionId},action:{type:"Continue",statements:t.map(([t,e])=>({query:t,values:e??[]}))}});return new _(e.dbPath,e.transactionId)}async commit(){await u("plugin:sqlite|transaction_continue",{token:{dbPath:this._dbPath,transactionId:this._transactionId},action:{type:"Commit"}})}async rollback(){await u("plugin:sqlite|transaction_continue",{token:{dbPath:this._dbPath,transactionId:this._transactionId},action:{type:"Rollback"}})}}class l{constructor(t){this._subscriptionId=t}get id(){return this._subscriptionId}async unsubscribe(){return await u("plugin:sqlite|unsubscribe",{subscriptionId:this._subscriptionId})}}class o{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|fetch_all",{db:this._db.path,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class d{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|fetch_one",{db:this._db.path,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class b{constructor(t,e,s,i,a){this._db=t,this._query=e,this._bindValues=s,this._keyset=i,this._pageSize=a,this._after=null,this._before=null,this._attached=[]}after(t){return this._after=t,this}before(t){return this._before=t,this}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|fetch_page",{db:this._db.path,query:this._query,values:this._bindValues,keyset:this._keyset,pageSize:this._pageSize,after:this._after,before:this._before,attached:this._attached.length>0?this._attached:null})}}class p{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){const[t,e]=await u("plugin:sqlite|execute",{db:this._db.path,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null});return{lastInsertId:e,rowsAffected:t}}}class f{constructor(t,e,s=[]){this._db=t,this._initialStatements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){const t=await u("plugin:sqlite|begin_interruptible_transaction",{db:this._db.path,initialStatements:this._initialStatements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null});return new _(t.dbPath,t.transactionId)}}class w{constructor(t,e,s=[]){this._db=t,this._statements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|execute_transaction",{db:this._db.path,statements:this._statements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null})}}class y{constructor(t){this.path=t}static async load(t,e){const s=await u("plugin:sqlite|load",{db:t,customConfig:e});return new y(s)}static get(t){return new y(t)}static async close_all(){await u("plugin:sqlite|close_all")}execute(t,e){return new p(this,t,e??[])}executeTransaction(t){return new w(this,t)}fetchAll(t,e){return new o(this,t,e??[])}fetchOne(t,e){return new d(this,t,e??[])}fetchPage(t,e,s,i){return new b(this,t,e,s,i)}async observe(t,e){await u("plugin:sqlite|observe",{db:this.path,tables:t,config:e??null})}async subscribe(t,e){const s=new c;s.onmessage=e;const i=await u("plugin:sqlite|subscribe",{db:this.path,tables:t,onEvent:s});return new l(i)}async unobserve(){await u("plugin:sqlite|unobserve",{db:this.path})}async close(){return await u("plugin:sqlite|close",{db:this.path})}async remove(){return await u("plugin:sqlite|remove",{db:this.path})}beginInterruptibleTransaction(t){return new f(this,t)}async getMigrationEvents(){return await u("plugin:sqlite|get_migration_events",{db:this.path})}}return t.InterruptibleTransaction=_,t.Subscription=l,t.default=y,Object.defineProperty(t,"__esModule",{value:!0}),t}({});Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} +if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(t){"use strict";function e(t,e,s,i){if("function"==typeof e?t!==e||!i:!e.has(t))throw new TypeError("Cannot read private member from an object whose class did not declare it");return"m"===s?i:"a"===s?i.call(t):i?i.value:e.get(t)}function s(t,e,s,i,a){if("function"==typeof e||!e.has(t))throw new TypeError("Cannot write private member to an object whose class did not declare it");return e.set(t,s),s}var i,a,n,r;"function"==typeof SuppressedError&&SuppressedError;const h="__TAURI_TO_IPC_KEY__";class c{constructor(t){i.set(this,void 0),a.set(this,0),n.set(this,[]),r.set(this,void 0),s(this,i,t||(()=>{})),this.id=function(t,e=!1){return window.__TAURI_INTERNALS__.transformCallback(t,e)}(t=>{const h=t.index;if("end"in t)return void(h==e(this,a,"f")?this.cleanupCallback():s(this,r,h));const c=t.message;if(h==e(this,a,"f")){for(e(this,i,"f").call(this,c),s(this,a,e(this,a,"f")+1);e(this,a,"f")in e(this,n,"f");){const t=e(this,n,"f")[e(this,a,"f")];e(this,i,"f").call(this,t),delete e(this,n,"f")[e(this,a,"f")],s(this,a,e(this,a,"f")+1)}e(this,a,"f")===e(this,r,"f")&&this.cleanupCallback()}else e(this,n,"f")[h]=c})}cleanupCallback(){window.__TAURI_INTERNALS__.unregisterCallback(this.id)}set onmessage(t){s(this,i,t)}get onmessage(){return e(this,i,"f")}[(i=new WeakMap,a=new WeakMap,n=new WeakMap,r=new WeakMap,h)](){return`__CHANNEL__:${this.id}`}toJSON(){return this[h]()}}async function u(t,e={},s){return window.__TAURI_INTERNALS__.invoke(t,e,s)}const _=new WeakMap;async function l(t){if(""!==t.path)return;let e=_.get(t);e||(e=u("plugin:sqlite|load",{dbKey:t.key}).then(e=>{t.path=e}),_.set(t,e)),await e}class o{constructor(t,e){this._dbKey=t,this._transactionId=e}async read(t,e){return await u("plugin:sqlite|transaction_read",{token:{dbKey:this._dbKey,transactionId:this._transactionId},query:t,values:e??[]})}async continueWith(t){const e=await u("plugin:sqlite|transaction_continue",{token:{dbKey:this._dbKey,transactionId:this._transactionId},action:{type:"Continue",statements:t.map(([t,e])=>({query:t,values:e??[]}))}});return new o(e.dbKey,e.transactionId)}async commit(){await u("plugin:sqlite|transaction_continue",{token:{dbKey:this._dbKey,transactionId:this._transactionId},action:{type:"Commit"}})}async rollback(){await u("plugin:sqlite|transaction_continue",{token:{dbKey:this._dbKey,transactionId:this._transactionId},action:{type:"Rollback"}})}}class d{constructor(t){this._subscriptionId=t}get id(){return this._subscriptionId}async unsubscribe(){return await u("plugin:sqlite|unsubscribe",{subscriptionId:this._subscriptionId})}}class b{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|fetch_all",{dbKey:this._db.key,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class y{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|fetch_one",{dbKey:this._db.key,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class w{constructor(t,e,s,i,a){this._db=t,this._query=e,this._bindValues=s,this._keyset=i,this._pageSize=a,this._after=null,this._before=null,this._attached=[]}after(t){return this._after=t,this}before(t){return this._before=t,this}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|fetch_page",{dbKey:this._db.key,query:this._query,values:this._bindValues,keyset:this._keyset,pageSize:this._pageSize,after:this._after,before:this._before,attached:this._attached.length>0?this._attached:null})}}class p{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){await l(this._db);const[t,e]=await u("plugin:sqlite|execute",{dbKey:this._db.key,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null});return{lastInsertId:e,rowsAffected:t}}}class f{constructor(t,e,s=[]){this._db=t,this._initialStatements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){await l(this._db);const t=await u("plugin:sqlite|begin_interruptible_transaction",{dbKey:this._db.key,initialStatements:this._initialStatements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null});return new o(t.dbKey,t.transactionId)}}class g{constructor(t,e,s=[]){this._db=t,this._statements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|execute_transaction",{dbKey:this._db.key,statements:this._statements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null})}}class q{constructor(t,e){this.key=t,this.path=e}static async load(t,e){const s=await u("plugin:sqlite|load",{dbKey:t,customConfig:e});return new q(t,s)}static get(t){return new q(t,"")}static async close_all(){await u("plugin:sqlite|close_all")}execute(t,e){return new p(this,t,e??[])}executeTransaction(t){return new g(this,t)}fetchAll(t,e){return new b(this,t,e??[])}fetchOne(t,e){return new y(this,t,e??[])}fetchPage(t,e,s,i){return new w(this,t,e,s,i)}async observe(t,e){await l(this),await u("plugin:sqlite|observe",{dbKey:this.key,tables:t,config:e??null})}async subscribe(t,e){await l(this);const s=new c;s.onmessage=e;const i=await u("plugin:sqlite|subscribe",{dbKey:this.key,tables:t,onEvent:s});return new d(i)}async unobserve(){await l(this),await u("plugin:sqlite|unobserve",{dbKey:this.key})}async close(){return await u("plugin:sqlite|close",{dbKey:this.key})}async remove(){return await u("plugin:sqlite|remove",{dbKey:this.key})}beginInterruptibleTransaction(t){return new f(this,t)}async getMigrationEvents(){return await u("plugin:sqlite|get_migration_events",{dbKey:this.key})}}return t.InterruptibleTransaction=o,t.Subscription=d,t.default=q,Object.defineProperty(t,"__esModule",{value:!0}),t}({});Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} diff --git a/crates/sqlx-sqlite-conn-mgr/README.md b/crates/sqlx-sqlite-conn-mgr/README.md index 578dddf..2b2c087 100644 --- a/crates/sqlx-sqlite-conn-mgr/README.md +++ b/crates/sqlx-sqlite-conn-mgr/README.md @@ -89,7 +89,8 @@ Migrations are tracked in `_sqlx_migrations` — calling `run_migrations()` mult times is safe (already-applied migrations are skipped). > **Note:** When using the Tauri plugin, migrations are handled automatically via -> `Builder::add_migrations()`. The plugin starts migrations at setup and waits for +> `Builder::register_database(..., Some(migrator))`. The plugin starts migrations at setup +> and waits for > completion when `load()` is called. ### Attached Databases diff --git a/crates/sqlx-sqlite-conn-mgr/src/config.rs b/crates/sqlx-sqlite-conn-mgr/src/config.rs index a8ab964..4baf501 100644 --- a/crates/sqlx-sqlite-conn-mgr/src/config.rs +++ b/crates/sqlx-sqlite-conn-mgr/src/config.rs @@ -25,6 +25,7 @@ use serde::{Deserialize, Serialize}; /// }; /// ``` #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct SqliteDatabaseConfig { /// Maximum number of concurrent read connections /// diff --git a/crates/sqlx-sqlite-conn-mgr/src/lib.rs b/crates/sqlx-sqlite-conn-mgr/src/lib.rs index c4a15b7..d3bc2b3 100644 --- a/crates/sqlx-sqlite-conn-mgr/src/lib.rs +++ b/crates/sqlx-sqlite-conn-mgr/src/lib.rs @@ -80,5 +80,7 @@ pub use write_guard::WriteGuard; // Re-export sqlx migrate types for convenience pub use sqlx::migrate::Migrator; +pub use registry::{canonicalize_database_path, is_memory_database}; + /// A type alias for Results with our custom Error type pub type Result = std::result::Result; diff --git a/crates/sqlx-sqlite-conn-mgr/src/registry.rs b/crates/sqlx-sqlite-conn-mgr/src/registry.rs index fb98009..f3915ff 100644 --- a/crates/sqlx-sqlite-conn-mgr/src/registry.rs +++ b/crates/sqlx-sqlite-conn-mgr/src/registry.rs @@ -16,14 +16,23 @@ fn registry() -> &'static RwLock>> { DATABASE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new())) } +/// Returns true when a `file:` URI query string contains an exact `mode=memory` parameter. +fn file_uri_has_mode_memory(path_str: &str) -> bool { + let Some(query) = path_str.split_once('?').map(|(_, query)| query) else { + return false; + }; + query.split('&').any(|param| param == "mode=memory") +} + /// Check if a path represents an in-memory SQLite database /// -/// Returns true for `:memory:` and `file::memory:*` URIs +/// Returns true for `:memory:` and `file::memory:*` URIs, and for `file:` URIs whose +/// query string includes a `mode=memory` parameter (not merely a substring match). pub fn is_memory_database(path: &Path) -> bool { let path_str = path.to_str().unwrap_or(""); path_str == ":memory:" || path_str.starts_with("file::memory:") - || path_str.contains("mode=memory") + || (path_str.starts_with("file:") && file_uri_has_mode_memory(path_str)) } /// Get or open a SQLite database connection @@ -44,7 +53,7 @@ where } // Canonicalize the path for consistent lookups - let canonical_path = canonicalize_path(path)?; + let canonical_path = canonicalize_database_path(path)?; // Try to get existing database with read lock (allows concurrent reads) { @@ -82,10 +91,9 @@ where Ok(arc_db) } -/// Helper to canonicalize a database path +/// Canonicalize a database file path for consistent registry lookups. /// -/// This function attempts to resolve paths to their canonical form to ensure -/// consistent cache lookups. It handles: +/// This function attempts to resolve paths to their canonical form. It handles: /// - Absolute path resolution /// - Symlink resolution (when file exists) /// - Parent directory canonicalization (when file doesn't exist yet) @@ -97,7 +105,7 @@ where /// least until the file is created and can be canonicalized properly. /// - Symlinks in filename: If the filename itself will be a symlink (rare for SQLite), /// different symlink names won't be resolved until the file exists. -fn canonicalize_path(path: &Path) -> std::io::Result { +pub fn canonicalize_database_path(path: &Path) -> std::io::Result { match path.canonicalize() { Ok(p) => Ok(p), Err(_) => { @@ -132,7 +140,7 @@ pub async fn uncache_database(path: &Path) -> std::io::Result<()> { } // Canonicalize path - let canonical_path = canonicalize_path(path)?; + let canonical_path = canonicalize_database_path(path)?; let mut registry = registry().write().await; registry.remove(&canonical_path); @@ -149,12 +157,12 @@ mod tests { let test_path = temp_dir.join("test.db"); // Test that path is canonicalized to absolute path - let canonical = canonicalize_path(&test_path).unwrap(); + let canonical = canonicalize_database_path(&test_path).unwrap(); assert!(canonical.is_absolute()); // Test relative path let relative_path = Path::new("./test_relative.db"); - let canonical_relative = canonicalize_path(relative_path).unwrap(); + let canonical_relative = canonicalize_database_path(relative_path).unwrap(); assert!(canonical_relative.is_absolute()); } @@ -164,7 +172,22 @@ mod tests { let nonexistent = temp_dir.join("nonexistent_dir").join("test.db"); // Should fail if parent directory doesn't exist - let result = canonicalize_path(&nonexistent); + let result = canonicalize_database_path(&nonexistent); assert!(result.is_err()); } + + #[test] + fn test_mode_memory_query_param() { + assert!(is_memory_database(Path::new("file:test?mode=memory"))); + assert!(is_memory_database(Path::new( + "file:/data/db?cache=shared&mode=memory" + ))); + } + + #[test] + fn test_mode_memory_substring_in_value_is_not_memory() { + assert!(!is_memory_database(Path::new( + "file:/home/user/real.db?x=mode=memory" + ))); + } } diff --git a/crates/sqlx-sqlite-toolkit/tests/memory_tests.rs b/crates/sqlx-sqlite-toolkit/tests/memory_tests.rs new file mode 100644 index 0000000..291c491 --- /dev/null +++ b/crates/sqlx-sqlite-toolkit/tests/memory_tests.rs @@ -0,0 +1,81 @@ +use std::path::Path; +use std::sync::Arc; + +use serde_json::json; +use sqlx_sqlite_toolkit::DatabaseWrapper; + +/// These tests verify the correct behavior of `DatabaseWrapper::connect` with an +/// in-memory database. +#[tokio::test] +async fn connect_memory_runs_ddl_and_dml() { + let db = DatabaseWrapper::connect(Path::new(":memory:"), None) + .await + .expect("Failed to connect to in-memory database"); + + db.execute( + "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)".into(), + vec![], + ) + .execute() + .await + .expect("CREATE TABLE should succeed"); + + let result = db + .execute( + "INSERT INTO t (name) VALUES ($1)".into(), + vec![json!("Alice")], + ) + .execute() + .await + .expect("INSERT should succeed"); + + assert_eq!((result.rows_affected, result.last_insert_id), (1, 1)); + + // Read back on the same write connection via an interruptible transaction. + let mut tx = db + .begin_interruptible_transaction() + .execute(vec![( + "INSERT INTO t (name) VALUES (?)", + vec![json!("Bob")], + )]) + .await + .expect("transaction should start"); + + let rows = tx + .read("SELECT name FROM t ORDER BY id".into(), vec![]) + .await + .expect("SELECT within transaction should succeed"); + + assert_eq!(rows.len(), 2); + tx.commit().await.expect("commit should succeed"); +} + +#[tokio::test] +async fn connect_memory_instances_are_independent() { + let db1 = DatabaseWrapper::connect(Path::new(":memory:"), None) + .await + .expect("Failed to connect first in-memory database"); + let db2 = DatabaseWrapper::connect(Path::new(":memory:"), None) + .await + .expect("Failed to connect second in-memory database"); + + assert!( + !Arc::ptr_eq(db1.inner(), db2.inner()), + ":memory: databases should not share the same SqliteDatabase instance" + ); + + db1.execute("CREATE TABLE test (id INTEGER)".into(), vec![]) + .execute() + .await + .expect("CREATE TABLE on first database should succeed"); + + let result = db2 + .fetch_all("SELECT * FROM test".into(), vec![]) + .execute() + .await; + + assert!( + result.is_err(), + "Second :memory: database should not see tables from the first" + ); +} diff --git a/guest-js/index.test.ts b/guest-js/index.test.ts index 39d1ed9..54a0cb4 100644 --- a/guest-js/index.test.ts +++ b/guest-js/index.test.ts @@ -16,12 +16,49 @@ import Database, { let lastCmd = '', lastArgs: Record = {}; +const DATABASE_KEYS = { + T: 'T', + TEST: 'TEST', + ARCHIVE: 'ARCHIVE', + STATS: 'STATS', + ORDERS: 'ORDERS', + MAIN: 'MAIN', +}; + +function getDatabasePath(key: string): string { + if (key === DATABASE_KEYS.ARCHIVE) { + return 'archive.db'; + } + + if (key === DATABASE_KEYS.STATS) { + return 'stats.db'; + } + + if (key === DATABASE_KEYS.ORDERS) { + return 'orders.db'; + } + + if (key === DATABASE_KEYS.T) { + return 't.db'; + } + + if (key === DATABASE_KEYS.TEST) { + return 'test.db'; + } + + if (key === DATABASE_KEYS.MAIN) { + return 'main.db'; + } + + throw new Error(`Unknown database key: ${key}`); +} + beforeEach(() => { mockIPC((cmd, args) => { lastCmd = cmd; lastArgs = args as Record; if (cmd === 'plugin:sqlite|load') { - return (args as { db: string }).db; + return getDatabasePath((args as { dbKey: string }).dbKey); } if (cmd === 'plugin:sqlite|execute') { return [ 1, 1 ]; @@ -30,13 +67,13 @@ beforeEach(() => { return []; } if (cmd === 'plugin:sqlite|begin_interruptible_transaction') { - return { dbPath: (args as { db: string }).db, transactionId: 'test-tx-id' }; + return { dbKey: (args as { dbKey: string }).dbKey, transactionId: 'test-tx-id' }; } if (cmd === 'plugin:sqlite|transaction_continue') { const action = (args as { action: { type: string } }).action; if (action.type === 'Continue') { - return { dbPath: 'test.db', transactionId: 'test-tx-id' }; + return { dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }; } return undefined; } @@ -84,32 +121,41 @@ afterEach(() => { return clearMocks(); }); describe('Database commands', () => { it('load', async () => { - await Database.load('test.db'); + await Database.load(DATABASE_KEYS.TEST); expect(lastCmd).toBe('plugin:sqlite|load'); - expect(lastArgs.db).toBe('test.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.TEST); + }); + + it('get is lazy until first operation', async () => { + const db = Database.get(DATABASE_KEYS.TEST); + + expect(db.path).toBe(''); + + await db.fetchAll('SELECT 1'); + expect(db.path).toBe(getDatabasePath(DATABASE_KEYS.TEST)); }); it('execute', async () => { - await Database.get('t.db').execute('INSERT INTO t VALUES ($1)', [ 1 ]); + await Database.get(DATABASE_KEYS.T).execute('INSERT INTO t VALUES ($1)', [ 1 ]); expect(lastCmd).toBe('plugin:sqlite|execute'); - expect(lastArgs).toMatchObject({ db: 't.db', query: 'INSERT INTO t VALUES ($1)', values: [ 1 ], attached: null }); + expect(lastArgs).toMatchObject({ dbKey: DATABASE_KEYS.T, query: 'INSERT INTO t VALUES ($1)', values: [ 1 ], attached: null }); }); it('execute with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .execute('UPDATE todos SET status = $1 WHERE id IN (SELECT todo_id FROM archive.completed)', [ 'archived' ]) .attach([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|execute'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'archive.db', + databaseKey: 'ARCHIVE', schemaName: 'archive', mode: 'readOnly', }, @@ -117,30 +163,30 @@ describe('Database commands', () => { }); it('execute_transaction', async () => { - await Database.get('t.db').executeTransaction([ [ 'DELETE FROM t' ] ]); + await Database.get(DATABASE_KEYS.T).executeTransaction([ [ 'DELETE FROM t' ] ]); expect(lastCmd).toBe('plugin:sqlite|execute_transaction'); expect(lastArgs.statements).toEqual([ { query: 'DELETE FROM t', values: [] } ]); expect(lastArgs.attached).toBe(null); }); it('execute_transaction with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .executeTransaction([ [ 'INSERT INTO orders (user_id, total) VALUES ($1, $2)', [ 1, 99.99 ] ], [ 'UPDATE stats.order_stats SET order_count = order_count + 1', [] ], ]) .attach([ { - databasePath: 'stats.db', + databaseKey: DATABASE_KEYS.STATS, schemaName: 'stats', mode: 'readWrite', }, ]); expect(lastCmd).toBe('plugin:sqlite|execute_transaction'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'stats.db', + databaseKey: DATABASE_KEYS.STATS, schemaName: 'stats', mode: 'readWrite', }, @@ -148,26 +194,26 @@ describe('Database commands', () => { }); it('fetch_all', async () => { - await Database.get('t.db').fetchAll('SELECT * FROM t'); + await Database.get(DATABASE_KEYS.T).fetchAll('SELECT * FROM t'); expect(lastCmd).toBe('plugin:sqlite|fetch_all'); - expect(lastArgs).toMatchObject({ db: 't.db', query: 'SELECT * FROM t', attached: null }); + expect(lastArgs).toMatchObject({ dbKey: DATABASE_KEYS.T, query: 'SELECT * FROM t', attached: null }); }); it('fetch_all with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .fetchAll('SELECT u.name, o.total FROM users u JOIN orders.orders o ON u.id = o.user_id', []) .attach([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|fetch_all'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, @@ -175,26 +221,26 @@ describe('Database commands', () => { }); it('fetch_one', async () => { - await Database.get('t.db').fetchOne('SELECT * FROM t WHERE id = $1', [ 1 ]); + await Database.get(DATABASE_KEYS.T).fetchOne('SELECT * FROM t WHERE id = $1', [ 1 ]); expect(lastCmd).toBe('plugin:sqlite|fetch_one'); - expect(lastArgs).toMatchObject({ db: 't.db', values: [ 1 ], attached: null }); + expect(lastArgs).toMatchObject({ dbKey: DATABASE_KEYS.T, values: [ 1 ], attached: null }); }); it('fetch_one with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .fetchOne('SELECT COUNT(*) as total FROM users u JOIN orders.orders o ON u.id = o.user_id', []) .attach([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|fetch_one'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, @@ -206,10 +252,10 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db').fetchPage('SELECT * FROM posts', [], keyset, 25); + await Database.get(DATABASE_KEYS.T).fetchPage('SELECT * FROM posts', [], keyset, 25); expect(lastCmd).toBe('plugin:sqlite|fetch_page'); expect(lastArgs).toMatchObject({ - db: 't.db', + dbKey: DATABASE_KEYS.T, query: 'SELECT * FROM posts', values: [], keyset: [ { name: 'id', direction: 'asc' } ], @@ -225,7 +271,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts', [], keyset, 25) .after([ 100 ]); @@ -239,7 +285,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts', [], keyset, 25) .before([ 50 ]); @@ -253,22 +299,22 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .fetchPage('SELECT * FROM posts', [], keyset, 25) .attach([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|fetch_page'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.after).toBeNull(); expect(lastArgs.before).toBeNull(); expect(lastArgs.attached).toEqual([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, @@ -282,7 +328,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts WHERE active = $1', [ true ], keyset, 50) .after([ 'tech', 95, 42 ]); @@ -305,7 +351,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts WHERE active = $1', [ true ], keyset, 50) .before([ 'tech', 95, 42 ]); @@ -322,9 +368,9 @@ describe('Database commands', () => { }); it('close', async () => { - await Database.get('t.db').close(); + await Database.get(DATABASE_KEYS.T).close(); expect(lastCmd).toBe('plugin:sqlite|close'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); }); it('close_all', async () => { @@ -333,15 +379,15 @@ describe('Database commands', () => { }); it('remove', async () => { - await Database.get('t.db').remove(); + await Database.get(DATABASE_KEYS.T).remove(); expect(lastCmd).toBe('plugin:sqlite|remove'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); }); it('getMigrationEvents', async () => { const mockEvents: MigrationEvent[] = [ - { dbPath: 't.db', status: 'running' }, - { dbPath: 't.db', status: 'completed', migrationCount: 5 }, + { dbKey: DATABASE_KEYS.T, dbPath: 't.db', status: 'running' }, + { dbKey: DATABASE_KEYS.T, dbPath: 't.db', status: 'completed', migrationCount: 5 }, ]; mockIPC((cmd, args) => { @@ -353,28 +399,28 @@ describe('Database commands', () => { return undefined; }); - const events = await Database.get('t.db').getMigrationEvents(); + const events = await Database.get(DATABASE_KEYS.T).getMigrationEvents(); expect(lastCmd).toBe('plugin:sqlite|get_migration_events'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(events).toEqual(mockEvents); }); it('getMigrationEvents - empty array', async () => { - const events = await Database.get('test.db').getMigrationEvents(); + const events = await Database.get(DATABASE_KEYS.TEST).getMigrationEvents(); expect(lastCmd).toBe('plugin:sqlite|get_migration_events'); - expect(lastArgs.db).toBe('test.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.TEST); expect(events).toEqual([]); }); it('beginInterruptibleTransaction', async () => { - const tx = await Database.get('t.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.T).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); expect(lastCmd).toBe('plugin:sqlite|begin_interruptible_transaction'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(lastArgs.initialStatements).toEqual([ { query: 'INSERT INTO users (name) VALUES ($1)', values: [ 'Alice' ] }, ]); @@ -383,26 +429,26 @@ describe('Database commands', () => { }); it('beginInterruptibleTransaction with attached databases', async () => { - const tx = await Database.get('main.db') + const tx = await Database.get(DATABASE_KEYS.MAIN) .beginInterruptibleTransaction([ [ 'DELETE FROM users WHERE id IN (SELECT user_id FROM archive.archived_users)' ], ]) .attach([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|begin_interruptible_transaction'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.initialStatements).toEqual([ { query: 'DELETE FROM users WHERE id IN (SELECT user_id FROM archive.archived_users)', values: [] }, ]); expect(lastArgs.attached).toEqual([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, @@ -411,7 +457,7 @@ describe('Database commands', () => { }); it('InterruptibleTransaction.continueWith()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); @@ -420,67 +466,70 @@ describe('Database commands', () => { ]); expect(lastCmd).toBe('plugin:sqlite|transaction_continue'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect((lastArgs.action as { type: string }).type).toBe('Continue'); expect(tx2).toBeInstanceOf(Object); }); it('InterruptibleTransaction.commit()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); await tx.commit(); expect(lastCmd).toBe('plugin:sqlite|transaction_continue'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect((lastArgs.action as { type: string }).type).toBe('Commit'); }); it('InterruptibleTransaction.rollback()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); await tx.rollback(); expect(lastCmd).toBe('plugin:sqlite|transaction_continue'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect((lastArgs.action as { type: string }).type).toBe('Rollback'); }); it('InterruptibleTransaction.read()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); await tx.read('SELECT * FROM users WHERE name = $1', [ 'Alice' ]); expect(lastCmd).toBe('plugin:sqlite|transaction_read'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect(lastArgs.query).toBe('SELECT * FROM users WHERE name = $1'); expect(lastArgs.values).toEqual([ 'Alice' ]); }); it('handles errors from backend', async () => { + const db = new Database(DATABASE_KEYS.T, getDatabasePath(DATABASE_KEYS.T)); + mockIPC(() => { throw new Error('Database error'); }); - await expect(Database.get('t.db').execute('SELECT 1', [])).rejects.toThrow('Database error'); + + await expect(db.execute('SELECT 1', [])).rejects.toThrow('Database error'); }); }); describe('Database.load with customConfig', () => { it('passes customConfig to backend', async () => { - await Database.load('test.db', { maxReadConnections: 10, idleTimeoutSecs: 60 }); + await Database.load(DATABASE_KEYS.TEST, { maxReadConnections: 10, idleTimeoutSecs: 60 }); expect(lastCmd).toBe('plugin:sqlite|load'); - expect(lastArgs.db).toBe('test.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.TEST); expect(lastArgs.customConfig).toEqual({ maxReadConnections: 10, idleTimeoutSecs: 60 }); }); }); describe('Observer commands', () => { it('observe', async () => { - await Database.get('t.db').observe([ 'users', 'posts' ]); + await Database.get(DATABASE_KEYS.T).observe([ 'users', 'posts' ]); expect(lastCmd).toBe('plugin:sqlite|observe'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(lastArgs.tables).toEqual([ 'users', 'posts' ]); expect(lastArgs.config).toBe(null); }); @@ -488,7 +537,7 @@ describe('Observer commands', () => { it('observe with config', async () => { const config: ObserverConfig = { channelCapacity: 512, captureValues: false }; - await Database.get('t.db').observe([ 'users' ], config); + await Database.get(DATABASE_KEYS.T).observe([ 'users' ], config); expect(lastCmd).toBe('plugin:sqlite|observe'); expect(lastArgs.tables).toEqual([ 'users' ]); expect(lastArgs.config).toEqual({ channelCapacity: 512, captureValues: false }); @@ -497,10 +546,10 @@ describe('Observer commands', () => { it('subscribe', async () => { const events: TableChangeEvent[] = []; - const sub = await Database.get('t.db').subscribe([ 'users' ], (e) => { events.push(e); }); + const sub = await Database.get(DATABASE_KEYS.T).subscribe([ 'users' ], (e) => { events.push(e); }); expect(lastCmd).toBe('plugin:sqlite|subscribe'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(lastArgs.tables).toEqual([ 'users' ]); expect(lastArgs.onEvent).toBeDefined(); expect(sub).toBeInstanceOf(Subscription); @@ -520,15 +569,16 @@ describe('Observer commands', () => { }); it('unobserve', async () => { - await Database.get('t.db').unobserve(); + await Database.get(DATABASE_KEYS.T).unobserve(); expect(lastCmd).toBe('plugin:sqlite|unobserve'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); }); }); describe('MigrationEvent type', () => { it('accepts running status', () => { const event: MigrationEvent = { + dbKey: 'TEST', dbPath: 'test.db', status: 'running', }; @@ -540,6 +590,7 @@ describe('MigrationEvent type', () => { it('accepts completed status with migrationCount', () => { const event: MigrationEvent = { + dbKey: 'TEST', dbPath: 'test.db', status: 'completed', migrationCount: 3, @@ -551,6 +602,7 @@ describe('MigrationEvent type', () => { it('accepts failed status with error', () => { const event: MigrationEvent = { + dbKey: 'TEST', dbPath: 'test.db', status: 'failed', error: 'Migration failed: syntax error', diff --git a/guest-js/index.ts b/guest-js/index.ts index 3537fae..2f3cb2e 100644 --- a/guest-js/index.ts +++ b/guest-js/index.ts @@ -1,5 +1,25 @@ import { Channel, invoke } from '@tauri-apps/api/core'; +const databaseLoadPromises = new WeakMap>(); + +async function ensureDatabaseLoaded(db: Database): Promise { + if (db.path !== '') { + return; + } + + let loadPromise = databaseLoadPromises.get(db); + + if (!loadPromise) { + loadPromise = invoke('plugin:sqlite|load', { dbKey: db.key }) + .then((resolvedPath) => { + db.path = resolvedPath; + }); + databaseLoadPromises.set(db, loadPromise); + } + + await loadPromise; +} + /** * Valid SQLite parameter binding value types. * @@ -23,9 +43,9 @@ export type AttachedDatabaseMode = 'readOnly' | 'readWrite'; export interface AttachedDatabaseSpec { /** - * Path to the database to attach (must be loaded via `Database.load()` first) + * Database key for identifying the database to attach */ - databasePath: string; + databaseKey: string; /** * Schema name to use for the attached database in queries @@ -77,11 +97,11 @@ export interface SqliteError { * Provides methods to read uncommitted data and execute additional statements. */ export class InterruptibleTransaction { - private readonly _dbPath: string; + private readonly _dbKey: string; private readonly _transactionId: string; - public constructor(dbPath: string, transactionId: string) { - this._dbPath = dbPath; + public constructor(dbKey: string, transactionId: string) { + this._dbKey = dbKey; this._transactionId = transactionId; } @@ -112,7 +132,7 @@ export class InterruptibleTransaction { */ public async read(query: string, bindValues?: SqlValue[]): Promise { return await invoke('plugin:sqlite|transaction_read', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, query, values: bindValues ?? [], }); @@ -137,10 +157,10 @@ export class InterruptibleTransaction { * ``` */ public async continueWith(statements: Array<[string, SqlValue[]?]>): Promise { - const token = await invoke<{ dbPath: string; transactionId: string }>( + const token = await invoke<{ dbKey: string; transactionId: string }>( 'plugin:sqlite|transaction_continue', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, action: { type: 'Continue', statements: statements.map(([ query, values ]) => { @@ -153,7 +173,7 @@ export class InterruptibleTransaction { } ); - return new InterruptibleTransaction(token.dbPath, token.transactionId); + return new InterruptibleTransaction(token.dbKey, token.transactionId); } /** @@ -170,7 +190,7 @@ export class InterruptibleTransaction { */ public async commit(): Promise { await invoke('plugin:sqlite|transaction_continue', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, action: { type: 'Commit' }, }); } @@ -189,7 +209,7 @@ export class InterruptibleTransaction { */ public async rollback(): Promise { await invoke('plugin:sqlite|transaction_continue', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, action: { type: 'Rollback' }, }); } @@ -220,7 +240,7 @@ export interface CustomConfig { * // Get all migration events (including ones emitted before registering listener) * const events = await db.getMigrationEvents() * for (const event of events) { - * console.log(`${event.status}: ${event.dbPath}`) + * console.log(`${event.status}: ${event.dbKey}`) * if (event.status === 'failed') { * console.error(`Migration error: ${event.error}`) * } @@ -230,17 +250,17 @@ export interface CustomConfig { * * // Listen for real-time events (may miss early events) * await listen('sqlite:migration', (event) => { - * const { dbPath, status, migrationCount, error } = event.payload + * const { dbKey, dbPath, status, migrationCount, error } = event.payload * * switch (status) { * case 'running': - * console.log(`Running migrations for ${dbPath}`) + * console.log(`Running migrations for ${dbKey}`) * break * case 'completed': - * console.log(`Completed ${migrationCount} migrations for ${dbPath}`) + * console.log(`Completed ${migrationCount} migrations for ${dbKey}`) * break * case 'failed': - * console.error(`Migration failed for ${dbPath}: ${error}`) + * console.error(`Migration failed for ${dbKey}: ${error}`) * break * } * }) @@ -248,7 +268,12 @@ export interface CustomConfig { */ export interface MigrationEvent { - /** Database path (relative, as registered with the plugin) */ + /** Database registration key (such as `MAIN`) */ + dbKey: string; + + /** Database path, the absolute path to the database file, such as + * `/var/lib/myapp/main.db`. + */ dbPath: string; /** Status: "running", "completed", "failed" */ @@ -448,8 +473,10 @@ class FetchAllBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + return await invoke('plugin:sqlite|fetch_all', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, attached: this._attached.length > 0 ? this._attached : null, @@ -497,8 +524,10 @@ class FetchOneBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + return await invoke('plugin:sqlite|fetch_one', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, attached: this._attached.length > 0 ? this._attached : null, @@ -577,8 +606,10 @@ class FetchPageBuilder implements PromiseLike> { } private async _execute(): Promise> { + await ensureDatabaseLoaded(this._db); + return await invoke>('plugin:sqlite|fetch_page', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, keyset: this._keyset, @@ -630,10 +661,12 @@ class ExecuteBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + const [ rowsAffected, lastInsertId ] = await invoke<[number, number]>( 'plugin:sqlite|execute', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, attached: this._attached.length > 0 ? this._attached : null, @@ -684,10 +717,12 @@ class InterruptibleTransactionBuilder implements PromiseLike { - const token = await invoke<{ dbPath: string; transactionId: string }>( + await ensureDatabaseLoaded(this._db); + + const token = await invoke<{ dbKey: string; transactionId: string }>( 'plugin:sqlite|begin_interruptible_transaction', { - db: this._db.path, + dbKey: this._db.key, initialStatements: this._initialStatements.map(([ query, values ]) => { return { query, @@ -698,7 +733,7 @@ class InterruptibleTransactionBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + return await invoke('plugin:sqlite|execute_transaction', { - db: this._db.path, + dbKey: this._db.key, statements: this._statements.map(([ query, values ]) => { return { query, @@ -765,9 +802,17 @@ class TransactionBuilder implements PromiseLike { * visible to reads in another, and closing a database affects all windows. */ export default class Database { + + /** Database key for identifying the database */ + public key: string; + + /** Database path, the absolute path to the database file, such as + * `/var/lib/myapp/main.db`. + */ public path: string; - public constructor(path: string) { + public constructor(key: string, path: string) { + this.key = key; this.path = path; } @@ -777,51 +822,58 @@ export default class Database { * A static initializer which connects to the underlying SQLite database and * returns a `Database` instance once a connection is established. * - * The path is relative to `tauri::path::BaseDirectory::AppConfig`. + * The key must be a key that the Rust side has registered with the plugin + * (via `Builder::register_database` / `SetupRegistrar::register_database`), + * or an in-memory database such as `:memory:`. * - * @param path - Database file path (relative to AppConfig directory) + * @param key - The key of the database to load * @param customConfig - Optional custom configuration for connection pools * * @example * ```ts + * const dbKey = 'MAIN'; + * * // Use default configuration - * const db = await Database.load("test.db"); + * const db = await Database.load(dbKey); * * // Use custom configuration - * const db = await Database.load("test.db", { + * const db2 = await Database.load(dbKey, { * maxReadConnections: 10, * idleTimeoutSecs: 60 * }); * ``` */ public static async load( - path: string, + dbKey: string, customConfig?: CustomConfig ): Promise { const resolvedPath = await invoke('plugin:sqlite|load', { - db: path, + dbKey, customConfig, }); - return new Database(resolvedPath); + return new Database(dbKey, resolvedPath); } /** * **get** * - * A static initializer which synchronously returns an instance of - * the Database class while deferring the actual database connection - * until the first invocation or selection on the database. + * Synchronously returns a `Database` handle for a registered key while deferring + * the actual connection until the first query or other operation that requires a + * loaded connection. Use {@link Database.load} when you need to pass + * `customConfig` or connect eagerly. * - * The path is relative to `tauri::path::BaseDirectory::AppConfig`. + * The key must be registered with the plugin on the Rust side + * (via `Builder::register_database` / `SetupRegistrar::register_database`). * * @example * ```ts - * const db = Database.get("test.db"); + * const db = Database.get('MAIN'); + * await db.fetchAll('SELECT 1'); * ``` */ - public static get(path: string): Database { - return new Database(path); + public static get(dbKey: string): Database { + return new Database(dbKey, ''); } /** @@ -866,7 +918,7 @@ export default class Database { * "(SELECT todo_id FROM archive.completed)", * [ "archived" ] * ).attach([{ - * databasePath: "archive.db", + * databaseKey: "ARCHIVE", * schemaName: "archive", * mode: "readOnly" * }]); @@ -918,7 +970,7 @@ export default class Database { * ['INSERT INTO main.orders (user_id, total) VALUES ($1, $2)', [userId, total]], * ['UPDATE archive.stats SET order_count = order_count + 1', []] * ]).attach([{ - * databasePath: "archive.db", + * databaseKey: "ARCHIVE", * schemaName: "archive", * mode: "readWrite" * }]); @@ -957,7 +1009,7 @@ export default class Database { * "SELECT u.name, o.total FROM users u JOIN orders.orders o ON u.id = o.user_id", * [] * ).attach([{ - * databasePath: "orders.db", + * databaseKey: "ORDERS", * schemaName: "orders", * mode: "readOnly" * }]); @@ -992,7 +1044,7 @@ export default class Database { * "SELECT COUNT(*) as total FROM users u JOIN orders.orders o ON u.id = o.user_id", * [] * ).attach([{ - * databasePath: "orders.db", + * databaseKey: "ORDERS", * schemaName: "orders", * mode: "readOnly" * }]); @@ -1061,7 +1113,7 @@ export default class Database { * keyset, * 25, * ).attach([{ - * databasePath: 'archive.db', + * databaseKey: 'ARCHIVE', * schemaName: 'archive', * mode: 'readOnly', * }]); @@ -1106,8 +1158,10 @@ export default class Database { * ``` */ public async observe(tables: string[], config?: ObserverConfig): Promise { + await ensureDatabaseLoaded(this); + await invoke('plugin:sqlite|observe', { - db: this.path, + dbKey: this.key, tables, config: config ?? null, }); @@ -1148,12 +1202,14 @@ export default class Database { tables: string[], onEvent: (event: TableChangeEvent) => void ): Promise { + await ensureDatabaseLoaded(this); + const channel = new Channel(); channel.onmessage = onEvent; const subscriptionId = await invoke('plugin:sqlite|subscribe', { - db: this.path, + dbKey: this.key, tables, onEvent: channel, }); @@ -1174,8 +1230,10 @@ export default class Database { * ``` */ public async unobserve(): Promise { + await ensureDatabaseLoaded(this); + await invoke('plugin:sqlite|unobserve', { - db: this.path, + dbKey: this.key, }); } @@ -1199,7 +1257,7 @@ export default class Database { */ public async close(): Promise { const success = await invoke('plugin:sqlite|close', { - db: this.path, + dbKey: this.key, }); return success; @@ -1229,7 +1287,7 @@ export default class Database { */ public async remove(): Promise { const success = await invoke('plugin:sqlite|remove', { - db: this.path, + dbKey: this.key, }); return success; @@ -1292,7 +1350,7 @@ export default class Database { * let tx = await db.beginInterruptibleTransaction([ * ['DELETE FROM users WHERE archived = 1'] * ]).attach([{ - * databasePath: 'archive.db', + * databaseKey: 'ARCHIVE', * schemaName: 'archive', * mode: 'readWrite' * }]); @@ -1328,7 +1386,7 @@ export default class Database { * // Get all migration events (including ones that happened before we could listen) * const events = await db.getMigrationEvents() * for (const event of events) { - * console.log(`${event.status}: ${event.dbPath}`) + * console.log(`${event.status}: ${event.dbKey}`) * if (event.status === 'failed') { * console.error(`Migration error: ${event.error}`) * } @@ -1337,7 +1395,7 @@ export default class Database { */ public async getMigrationEvents(): Promise { return await invoke('plugin:sqlite|get_migration_events', { - db: this.path, + dbKey: this.key, }); } } diff --git a/package.json b/package.json index fdfe120..e19c7c2 100644 --- a/package.json +++ b/package.json @@ -19,6 +19,7 @@ ], "scripts": { "build": "rollup -c", + "check:iife": "rollup -c && git diff --exit-code api-iife.js", "check-node-version": "check-node-version --npm 10.5.0", "commitlint": "commitlint --from ${COMMITLINT_FROM:-002bcc8} --to ${COMMITLINT_TO:-HEAD}", "eslint": "eslint .", @@ -27,7 +28,7 @@ "type-check": "run-p type-check:ts", "markdownlint": "markdownlint-cli2", "prepare": "rollup -c", - "standards": "npm run eslint && npm run type-check && npm run markdownlint && npm run rust:lint && npm run commitlint", + "standards": "npm run eslint && npm run type-check && npm run markdownlint && npm run check:iife && npm run rust:lint && npm run commitlint", "rust:lint": "cargo lint-clippy && cargo lint-fmt", "rust:lint:fix": "cargo fix-clippy && cargo fix-fmt", "test": "vitest run && cargo test --workspace --lib --test '*'", diff --git a/src/commands.rs b/src/commands.rs index 2222bda..438ad9a 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -12,14 +12,16 @@ use sqlx_sqlite_toolkit::{ ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions, DatabaseWrapper, Statement, TransactionWriter, WriteQueryResult, }; +use std::path::PathBuf; use std::sync::Arc; use tauri::ipc::Channel; use tauri::{AppHandle, Runtime, State}; use tracing::debug; use uuid::Uuid; +use crate::connect_to_database; use crate::{ - DbInstances, Error, MigrationEvent, MigrationStates, MigrationStatus, Result, + DbInstances, Error, MigrationEvent, MigrationStates, Result, subscriptions::{ ActiveSubscriptions, ObserverConfigParams, TableChangePayload, event_to_payload, }, @@ -29,7 +31,7 @@ use crate::{ #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TransactionToken { - pub db_path: String, + pub db_key: String, pub transaction_id: String, } @@ -46,8 +48,8 @@ pub enum TransactionAction { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct AttachedDatabaseSpec { - /// Path to the database to attach (must be loaded via `load()` first) - pub database_path: String, + /// Key of the database to attach (must be loaded via `load()` first) + pub database_key: String, /// Schema name to use for the attached database in queries pub schema_name: String, /// Access mode: "readOnly" or "readWrite" @@ -71,8 +73,8 @@ fn resolve_attached_specs( for spec in specs { let wrapper = db_instances - .get(&spec.database_path) - .ok_or_else(|| Error::DatabaseNotLoaded(spec.database_path.clone()))?; + .get(&spec.database_key) + .ok_or_else(|| Error::DatabaseNotLoaded(spec.database_key.clone()))?; let mode = match spec.mode { AttachedDatabaseMode::ReadOnly => sqlx_sqlite_conn_mgr::AttachedMode::ReadOnly, @@ -94,6 +96,10 @@ fn resolve_attached_specs( /// If the database is already loaded, returns the existing connection. /// Otherwise, creates a new connection with optional custom configuration. /// +/// `db_key` must be a registration key from +/// [`Builder::register_database`] / [`SetupRegistrar::register_database`]. +/// Unregistered keys are rejected with `PATH_NOT_REGISTERED`. +/// /// # Migration Timing /// /// If migrations are registered for this database, this function waits for them @@ -104,92 +110,18 @@ fn resolve_attached_specs( #[tauri::command] pub async fn load( app: AppHandle, - db_instances: State<'_, DbInstances>, - migration_states: State<'_, MigrationStates>, - db: String, + db_key: String, custom_config: Option, -) -> Result { - // Wait for migrations to complete if registered for this database - await_migrations(&migration_states, &db).await?; - - let instances = db_instances.inner.read().await; - - // Return cached if db was already loaded - if instances.contains_key(&db) { - return Ok(db); - } - - drop(instances); // Release read lock before acquiring write lock - - let mut instances = db_instances.inner.write().await; - - // Check database count limit before creating a new connection. - // This check is before entry() to avoid borrow conflicts, and the write lock - // prevents races between the len() check and the insert. - if !instances.contains_key(&db) && instances.len() >= db_instances.max { - return Err(Error::TooManyDatabases(db_instances.max)); - } - - // Use entry API to atomically check and insert, avoiding race conditions - // where two callers could both create wrappers - use std::collections::hash_map::Entry; - match instances.entry(db.clone()) { - Entry::Occupied(_) => { - // Another caller won the race and inserted while we waited for write lock - Ok(db) - } - Entry::Vacant(entry) => { - // We won the race, create and insert the wrapper - let wrapper = crate::resolve::connect(&db, &app, custom_config).await?; - entry.insert(wrapper); - Ok(db) - } - } -} - -/// Wait for migrations to complete for a database, if any are registered. -/// -/// Returns Ok(()) if: -/// - No migrations are registered for this database -/// - Migrations completed successfully -/// -/// Returns Err if migrations failed. -async fn await_migrations(migration_states: &State<'_, MigrationStates>, db: &str) -> Result<()> { - loop { - // Get notify handle before checking status - let notify = { - let states = migration_states.0.read().await; - match states.get(db) { - // No migrations registered for this database - None => return Ok(()), - - Some(state) => match &state.status { - // Migrations completed successfully - MigrationStatus::Complete => return Ok(()), - - // Migrations failed - return the error - MigrationStatus::Failed(error) => { - return Err(Error::Migration(sqlx::migrate::MigrateError::Source( - error.clone().into(), - ))); - } - - // Migrations still pending or running - wait for notification - MigrationStatus::Pending | MigrationStatus::Running => state.notify.clone(), - }, - } - }; - - // Wait for migration state change - notify.notified().await; - } +) -> Result { + let response = connect_to_database(&app, &db_key, custom_config).await?; + Ok(response.path) } /// Execute a write query (INSERT, UPDATE, DELETE, etc.) #[tauri::command] pub async fn execute( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, attached: Option>, @@ -197,8 +129,8 @@ pub async fn execute( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.execute(query, values); @@ -217,15 +149,15 @@ pub async fn execute( pub async fn execute_transaction( db_instances: State<'_, DbInstances>, regular_txs: State<'_, ActiveRegularTransactions>, - db: String, + db_key: String, statements: Vec, attached: Option>, ) -> Result> { let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; // Convert Statement structs to tuples for wrapper let stmt_tuples: Vec<(String, Vec)> = statements @@ -234,7 +166,7 @@ pub async fn execute_transaction( .collect(); // Generate unique key for tracking this transaction - let tx_key = format!("{}:{}", db, Uuid::new_v4()); + let tx_key = format!("{}:{}", db_key, Uuid::new_v4()); // Resolve attached specs if provided let resolved_specs = if let Some(specs) = attached { @@ -297,7 +229,7 @@ pub async fn execute_transaction( #[tauri::command] pub async fn fetch_all( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, attached: Option>, @@ -305,8 +237,8 @@ pub async fn fetch_all( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.fetch_all(query, values); @@ -324,7 +256,7 @@ pub async fn fetch_all( #[tauri::command] pub async fn fetch_one( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, attached: Option>, @@ -332,8 +264,8 @@ pub async fn fetch_one( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.fetch_one(query, values); @@ -352,7 +284,7 @@ pub async fn fetch_one( #[tauri::command] pub async fn fetch_page( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, keyset: Vec, @@ -370,8 +302,8 @@ pub async fn fetch_page( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.fetch_page(query, values, keyset, page_size); @@ -400,13 +332,13 @@ pub async fn fetch_page( pub async fn close( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, ) -> Result { - active_subs.remove_for_db(&db).await; + active_subs.remove_for_db(&db_key).await; let mut instances = db_instances.inner.write().await; - if let Some(wrapper) = instances.remove(&db) { + if let Some(wrapper) = instances.remove(&db_key) { wrapper.close().await?; Ok(true) } else { @@ -453,13 +385,13 @@ pub async fn close_all( pub async fn remove( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, ) -> Result { - active_subs.remove_for_db(&db).await; + active_subs.remove_for_db(&db_key).await; let mut instances = db_instances.inner.write().await; - if let Some(wrapper) = instances.remove(&db) { + if let Some(wrapper) = instances.remove(&db_key) { wrapper.remove().await?; Ok(true) } else { @@ -476,11 +408,11 @@ pub async fn remove( #[tauri::command] pub async fn get_migration_events( migration_states: State<'_, MigrationStates>, - db: String, + db_key: String, ) -> Result> { let states = migration_states.0.read().await; - match states.get(&db) { + match states.get(&db_key) { Some(state) => Ok(state.events.clone()), None => Ok(Vec::new()), } @@ -495,15 +427,15 @@ pub async fn get_migration_events( pub async fn begin_interruptible_transaction( db_instances: State<'_, DbInstances>, active_txs: State<'_, ActiveInterruptibleTransactions>, - db: String, + db_key: String, initial_statements: Vec, attached: Option>, ) -> Result { let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; // Generate unique transaction ID let transaction_id = Uuid::new_v4().to_string(); @@ -524,15 +456,15 @@ pub async fn begin_interruptible_transaction( // Execute initial statements let mut active_tx = - ActiveInterruptibleTransaction::new(db.clone(), transaction_id.clone(), writer); + ActiveInterruptibleTransaction::new(db_key.clone(), transaction_id.clone(), writer); active_tx.continue_with(initial_statements).await?; // Store transaction state - active_txs.insert(db.clone(), active_tx).await?; + active_txs.insert(db_key.clone(), active_tx).await?; Ok(TransactionToken { - db_path: db, + db_key, transaction_id, }) } @@ -550,14 +482,14 @@ pub async fn transaction_continue( TransactionAction::Continue { statements } => { // Remove transaction to get mutable access let mut tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; // Execute statements on the transaction match tx.continue_with(statements).await { Ok(_results) => { // Re-insert transaction - if this fails, tx is dropped and auto-rolled back - match active_txs.insert(token.db_path.clone(), tx).await { + match active_txs.insert(token.db_key.clone(), tx).await { Ok(()) => Ok(Some(token)), Err(e) => { // Transaction lost but will auto-rollback via Drop @@ -576,7 +508,7 @@ pub async fn transaction_continue( TransactionAction::Commit => { // Remove transaction and commit let tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; tx.commit().await?; @@ -586,7 +518,7 @@ pub async fn transaction_continue( TransactionAction::Rollback => { // Remove transaction and rollback let tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; tx.rollback().await?; @@ -608,14 +540,14 @@ pub async fn transaction_read( ) -> Result>> { // Remove transaction to get mutable access let mut tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; // Execute read on the transaction match tx.read(query, values).await { Ok(results) => { // Re-insert transaction - if this fails, tx is dropped and auto-rolled back - match active_txs.insert(token.db_path.clone(), tx).await { + match active_txs.insert(token.db_key.clone(), tx).await { Ok(()) => Ok(results), Err(e) => { // Transaction lost but will auto-rollback via Drop @@ -643,7 +575,7 @@ pub async fn transaction_read( pub async fn observe( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, tables: Vec, config: Option, ) -> Result<()> { @@ -659,13 +591,13 @@ pub async fn observe( // Abort plugin-level subscription tasks before the crate-level // enable_observation() drops the old broker - active_subs.remove_for_db(&db).await; + active_subs.remove_for_db(&db_key).await; let mut instances = db_instances.inner.write().await; let wrapper = instances - .get_mut(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get_mut(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut observer_config = sqlx_sqlite_observer::ObserverConfig::new().with_tables(tables); @@ -697,13 +629,13 @@ pub async fn observe( pub async fn subscribe( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, tables: Vec, on_event: Channel, ) -> Result { const MAX_SUBSCRIPTIONS_PER_DATABASE: usize = 100; - let sub_count = active_subs.count_for_db(&db).await; + let sub_count = active_subs.count_for_db(&db_key).await; if sub_count >= MAX_SUBSCRIPTIONS_PER_DATABASE { return Err(Error::TooManySubscriptions(MAX_SUBSCRIPTIONS_PER_DATABASE)); } @@ -711,12 +643,12 @@ pub async fn subscribe( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let observable = wrapper .observable() - .ok_or_else(|| Error::ObservationNotEnabled(db.clone()))?; + .ok_or_else(|| Error::ObservationNotEnabled(db_key.clone()))?; // Create subscription stream let mut stream = observable.subscribe_stream(tables); @@ -726,7 +658,7 @@ pub async fn subscribe( // Spawn task to forward stream events to the Tauri Channel let sub_id = subscription_id.clone(); - let db_path = db.clone(); + let db_key_clone = db_key.clone(); let handle = tokio::spawn(async move { while let Some(event) = stream.next().await { @@ -738,12 +670,12 @@ pub async fn subscribe( } } - debug!("Subscription {} for db {} ended", sub_id, db_path); + debug!("Subscription {} for db {} ended", sub_id, &db_key_clone); }); // Track subscription active_subs - .insert(subscription_id.clone(), db.clone(), handle.abort_handle()) + .insert(subscription_id.clone(), db_key, handle.abort_handle()) .await; Ok(subscription_id) @@ -767,16 +699,16 @@ pub async fn unsubscribe( pub async fn unobserve( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, ) -> Result<()> { // Abort all subscriptions for this database first - active_subs.remove_for_db(&db).await; + active_subs.remove_for_db(&db_key).await; let mut instances = db_instances.inner.write().await; let wrapper = instances - .get_mut(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get_mut(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; wrapper.disable_observation(); Ok(()) diff --git a/src/error.rs b/src/error.rs index 4fc5e1e..4b1fd77 100644 --- a/src/error.rs +++ b/src/error.rs @@ -31,6 +31,10 @@ pub enum Error { #[error("path traversal not allowed: {0}")] PathTraversal(String), + /// Database key is not registered with the plugin. + #[error("database key not registered: {0}")] + PathNotRegistered(String), + /// Attempted to access a database that hasn't been loaded. #[error("database {0} not loaded")] DatabaseNotLoaded(String), @@ -84,6 +88,7 @@ impl Error { Error::Migration(_) => "MIGRATION_ERROR".to_string(), Error::InvalidPath(_) => "INVALID_PATH".to_string(), Error::PathTraversal(_) => "PATH_TRAVERSAL".to_string(), + Error::PathNotRegistered(_) => "PATH_NOT_REGISTERED".to_string(), Error::DatabaseNotLoaded(_) => "DATABASE_NOT_LOADED".to_string(), Error::ObservationNotEnabled(_) => "OBSERVATION_NOT_ENABLED".to_string(), Error::TooManyDatabases(_) => "TOO_MANY_DATABASES".to_string(), @@ -123,6 +128,24 @@ mod tests { assert_eq!(err.error_code(), "INVALID_PATH"); } + #[test] + fn test_error_code_path_not_registered() { + let err = Error::PathNotRegistered("MAIN".into()); + assert_eq!(err.error_code(), "PATH_NOT_REGISTERED"); + } + + #[test] + fn test_error_serialization_path_not_registered() { + let err = Error::PathNotRegistered("MAIN".into()); + let json = serde_json::to_value(&err).unwrap(); + + assert_eq!(json["code"], "PATH_NOT_REGISTERED"); + assert_eq!( + json["message"].as_str().unwrap(), + "database key not registered: MAIN" + ); + } + #[test] fn test_error_code_unsupported_datatype() { let err = Error::Toolkit(sqlx_sqlite_toolkit::Error::UnsupportedDatatype( diff --git a/src/lib.rs b/src/lib.rs index f1ff209..f77b675 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,18 @@ use std::collections::HashMap; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::atomic::{AtomicU8, Ordering}; use serde::Serialize; use sqlx_sqlite_conn_mgr::Migrator; -use tauri::{Emitter, Manager, RunEvent, Runtime, plugin::Builder as PluginBuilder}; +use tauri::{AppHandle, Emitter, Manager, RunEvent, Runtime, plugin::Builder as PluginBuilder}; use tokio::sync::{Notify, RwLock}; use tracing::{debug, error, info, trace, warn}; mod commands; mod error; -mod resolve; mod subscriptions; +mod validate; pub use error::{Error, Result}; pub use sqlx_sqlite_conn_mgr::{ @@ -54,6 +55,8 @@ impl Drop for ExitGuard { /// This struct maintains a thread-safe map of database paths to their corresponding /// connection wrappers, with a configurable upper limit on how many databases can be /// loaded simultaneously. +/// +/// The string key is the registered database key. #[derive(Clone)] pub struct DbInstances { pub(crate) inner: Arc>>, @@ -79,6 +82,64 @@ impl DbInstances { } } +/// Tracks the paths of all registered databases. +/// The String value of the key is the database identifier, not the path. +/// For example, the value of the key `MAIN` would be something like +/// `/var/lib/myapp/main.db`. +/// +/// This key value is what will be used by the caller to interact with the database. +/// For example, when calling `load()` or `execute()`, the caller will pass the key value +/// to identify the database to which they want to connect. +#[derive(Clone, Default)] +pub struct RegisteredDatabases { + pub(crate) database_path_by_key: Arc>, +} + +/// Contains the information required for registering a database. +/// +/// When initializing or setting up the plugin, the caller will pass the path to the database +/// file and the migrator to use for the database. +/// +/// This information is then stored in the `RegisteredDatabases` struct, which is used to +/// track the paths of all registered databases. +/// +/// The `migrator` is not held by the app state, but rather is only used after +/// initialization to run the migrations for the database. +#[derive(Debug, Clone)] +struct DatabaseInfo { + path: PathBuf, + migrator: Option>, +} + +fn validated_database_info( + path: impl Into, + migrator: Option, +) -> Result { + let path = path.into(); + Ok(DatabaseInfo { + path: validate::validate_database_path(&path)?, + migrator: migrator.map(Arc::new), + }) +} + +/// Ensure each registration key maps to a distinct database path. +fn ensure_distinct_database_paths( + database_info_by_key: &HashMap, +) -> Result<()> { + let mut path_to_key = HashMap::new(); + + for (key, info) in database_info_by_key { + if let Some(existing_key) = path_to_key.insert(info.path.clone(), key.as_str()) { + return Err(Error::InvalidConfig(format!( + "database keys {existing_key} and {key} both register the same path: {}", + info.path.display() + ))); + } + } + + Ok(()) +} + /// Migration status for a database. #[derive(Debug, Clone)] pub enum MigrationStatus { @@ -119,6 +180,7 @@ impl MigrationState { } /// Tracks migration state for all databases. +/// The String value of the key is the database identifier, not the path. #[derive(Default)] pub struct MigrationStates(pub RwLock>); @@ -126,8 +188,12 @@ pub struct MigrationStates(pub RwLock>); #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct MigrationEvent { - /// Database path (relative, as registered) - pub db_path: String, + /// Database key, meant to be human readable (such as `MAIN`). + /// This is what is to be used by the client to interact with the database. + pub db_key: String, + /// Database path, the absolute path to the database file, such as + /// `/var/lib/myapp/main.db`. + pub db_path: PathBuf, /// Status: "running", "completed", "failed" pub status: String, /// Total number of migrations defined in the migrator (on "completed"), not just newly applied @@ -142,6 +208,17 @@ pub struct MigrationEvent { /// /// Use this to configure the plugin and build the plugin instance. /// +/// # Database registration +/// +/// Every database must be **registered** with a stable key and filesystem path (or +/// in-memory URI) before it can be opened. The frontend and Rust callers open databases +/// by **key** via `load()` / [`Connection::connect`]. Paths are validated and +/// canonicalized at registration time. +/// +/// Because legitimate paths usually depend on runtime values (for example +/// `app.path().app_data_dir()`), registration normally happens in the [`Builder::on_setup`] +/// hook. Static paths can be registered up front with [`Builder::register_database`]. +/// /// # Example /// /// ```ignore @@ -151,9 +228,9 @@ pub struct MigrationEvent { /// use tauri_plugin_sqlite::Builder; /// /// # fn main() { -/// // Basic setup (no migrations): +/// // Basic setup (no databases registered yet — register them in `on_setup`): /// tauri::Builder::default() -/// .plugin(Builder::new().build()) +/// .plugin(Builder::new().build().expect("failed to build sqlite plugin")) /// .run(tauri::generate_context!()) /// .expect("error while running tauri application"); /// # } @@ -166,64 +243,169 @@ pub struct MigrationEvent { /// // tauri::generate_context!() requires tauri.conf.json at compile time, /// // which cannot be provided in doc test environments. /// use tauri_plugin_sqlite::Builder; +/// use tauri::Manager; /// /// # fn main() { -/// // Setup with migrations: +/// // Resolve the database path from the app instance and register it with migrations. +/// // The frontend then calls `Database.load("MAIN")` with the registration key. /// tauri::Builder::default() /// .plugin( /// Builder::new() -/// .add_migrations("main.db", sqlx::migrate!("./migrations/main")) -/// .add_migrations("cache.db", sqlx::migrate!("./migrations/cache")) +/// .on_setup(|app, reg| { +/// let db = app.path().app_data_dir()?.join("main.db"); +/// reg.register_database( +/// "MAIN", +/// db, +/// Some(sqlx::migrate!("./migrations/main")), +/// )?; +/// Ok(()) +/// }) /// .build() +/// .expect("failed to build sqlite plugin") /// ) /// .run(tauri::generate_context!()) /// .expect("error while running tauri application"); /// # } /// ``` -#[derive(Debug, Default)] -pub struct Builder { - /// Migrations registered per database path - migrations: HashMap>, +/// +/// Collects database registrations from the [`Builder::on_setup`] hook. +/// +/// Passed to the `on_setup` closure during plugin setup, where the `app` instance is +/// available. Use it to register values that can only be computed at runtime (for example, +/// paths derived from `app.path().app_data_dir()`). +#[derive(Default)] +pub struct SetupRegistrar { + database_info_by_key: HashMap, +} + +impl SetupRegistrar { + /// Register a database path, optionally with migrations. See [`Builder::register_database`]. + /// + /// This invocation is to be used when the database path is known at runtime (such as + /// a path dependent on `app.path().app_data_dir()`). + /// + /// For a path that is known at compile time, use [`Builder::register_database`] + /// instead. + /// + /// The `key` is the identifier for the database. It is used to identify the database + /// when calling `load()` or `execute()`. + /// + /// The `path` is the absolute filesystem path or in-memory URI. It is validated and + /// canonicalized at registration time. + /// + /// Returns `Err` if the path fails validation (relative, traversal, or canonicalization). + /// + /// The `migrator` runs automatically at plugin initialization when provided. + /// + /// If the same key is registered more than once, the last registration will override + /// all previous ones. + /// + /// Distinct keys must map to distinct database paths. Duplicate paths are rejected + /// when the plugin initializes (see [`Builder::build`]). + pub fn register_database( + &mut self, + key: &str, + path: impl Into, + migrator: Option, + ) -> Result<()> { + self + .database_info_by_key + .insert(key.to_string(), validated_database_info(path, migrator)?); + Ok(()) + } +} + +/// Closure type for the deferred [`Builder::on_setup`] hook. +type OnSetupHook = Box, &mut SetupRegistrar) -> Result<()> + Send>; + +pub struct Builder { + /// Migrations registered per database path, keyed by the database key. + database_info_by_key: HashMap, /// Timeout for interruptible transactions. Defaults to 5 minutes. transaction_timeout: Option, /// Maximum number of concurrently loaded databases. Defaults to 50. max_databases: Option, + /// Deferred hook run during plugin setup with the app handle. Lets callers register + /// paths/migrations computed from `app`. Returning `Err` aborts app startup. + on_setup: Option>, +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for Builder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Builder") + .field("database_info_by_key", &self.database_info_by_key) + .field("transaction_timeout", &self.transaction_timeout) + .field("max_databases", &self.max_databases) + .field("on_setup", &self.on_setup.is_some()) + .finish() + } } -impl Builder { +impl Builder { /// Create a new builder instance. pub fn new() -> Self { Self { - migrations: HashMap::new(), + database_info_by_key: HashMap::new(), transaction_timeout: None, max_databases: None, + on_setup: None, } } - /// Register migrations for a database path. + /// Register a database by key and path, optionally with migrations. /// - /// Migrations will be run automatically at plugin initialization. - /// Multiple databases can have their own migrations. + /// Pass `None` for `migrator` when the database has no migrations. Migrations run + /// automatically at plugin initialization when provided. /// - /// # Arguments + /// Use this when the path is known at compile time. For paths derived from the `app` + /// instance (for example `app.path().app_data_dir()`), use [`on_setup`](Self::on_setup) + /// and [`SetupRegistrar::register_database`] instead. /// - /// * `path` - Database path (relative to app config directory) - /// * `migrator` - Migrator instance, typically from `sqlx::migrate!()` + /// The frontend must call `load()` with the registration **key**. /// /// # Example /// /// ```no_run /// use tauri_plugin_sqlite::Builder; + /// use std::path::PathBuf; + /// + /// const MAIN_DB_KEY: &str = "MAIN"; /// - /// # fn example() { - /// Builder::new() - /// .add_migrations("main.db", sqlx::migrate!("./doc-test-fixtures/migrations")) - /// .build::(); + /// # fn example() -> tauri_plugin_sqlite::Result<()> { + /// Builder::::new() + /// .register_database( + /// MAIN_DB_KEY, + /// PathBuf::from("/var/lib/myapp/main.db"), + /// Some(sqlx::migrate!("./doc-test-fixtures/migrations")), + /// )? + /// .build()?; + /// # Ok(()) /// # } /// ``` - pub fn add_migrations(mut self, path: &str, migrator: Migrator) -> Self { - self.migrations.insert(path.to_string(), Arc::new(migrator)); + /// + /// If the same key is registered more than once, the last registration will override + /// all previous ones. + /// + /// Distinct keys must map to distinct database paths. If two distinct keys register + /// the same path, plugin initialization returns [`Error::InvalidConfig`]. Registrations + /// from [`on_setup`](Self::on_setup) are validated when the merged map is initialized. + pub fn register_database( + mut self, + key: &str, + path: impl Into, + migrator: Option, + ) -> Result { self + .database_info_by_key + .insert(key.to_string(), validated_database_info(path, migrator)?); + + Ok(self) } /// Set the timeout for interruptible transactions. @@ -258,13 +440,62 @@ impl Builder { Ok(self) } + /// Register a hook that runs during plugin setup, once the `app` instance exists. + /// + /// This is the primary way to register database paths, because the legitimate absolute + /// paths usually depend on runtime values — for example paths derived from + /// `app.path().app_data_dir()`. The closure receives the app handle and a + /// [`SetupRegistrar`] on which you call [`register_database`](SetupRegistrar::register_database). + /// + /// Entries registered here are merged with those registered statically via + /// [`register_database`](Self::register_database); a later registration for the same + /// key overrides an earlier one. + /// + /// Returning `Err` from the hook aborts app startup (fail-fast). + /// + /// # Example + /// + /// ```no_run + /// use tauri_plugin_sqlite::Builder; + /// use tauri::Manager; + /// + /// const MAIN_DB_KEY: &str = "MAIN"; + /// + /// # fn example() -> tauri_plugin_sqlite::Result<()> { + /// Builder::::new() + /// .on_setup(|app, reg| { + /// let db = app.path().app_data_dir()?.join("main.db"); + /// reg.register_database( + /// MAIN_DB_KEY, + /// db, + /// Some(sqlx::migrate!("./doc-test-fixtures/migrations")) + /// )?; + /// Ok(()) + /// }) + /// .build()?; + /// # Ok(()) + /// # } + /// ``` + pub fn on_setup( + mut self, + f: impl FnOnce(&AppHandle, &mut SetupRegistrar) -> Result<()> + Send + 'static, + ) -> Self { + self.on_setup = Some(Box::new(f)); + self + } + /// Build the plugin with command registration and state management. - pub fn build(self) -> tauri::plugin::TauriPlugin { - let migrations = Arc::new(self.migrations); + /// + /// Duplicate paths across distinct registration keys are rejected during plugin + /// initialization (the setup hook), after [`on_setup`](Self::on_setup) registrations + /// are merged with static ones. + pub fn build(self) -> Result> { + let database_info_by_key = self.database_info_by_key; let transaction_timeout = self.transaction_timeout; let max_databases = self.max_databases; + let on_setup = self.on_setup; - PluginBuilder::::new("sqlite") + Ok(PluginBuilder::::new("sqlite") .invoke_handler(tauri::generate_handler![ commands::load, commands::execute, @@ -297,26 +528,40 @@ impl Builder { app.manage(ActiveRegularTransactions::default()); app.manage(subscriptions::ActiveSubscriptions::default()); - // Initialize migration states as Pending for all registered databases + // Run the deferred setup hook (if any), merge with static registrations. + // Paths are validated and canonicalized at registration time. Hook errors + // abort startup (fail-fast). + let mut database_info_by_key = database_info_by_key; + if let Some(on_setup_action) = on_setup { + let mut registrar = SetupRegistrar::default(); + on_setup_action(app, &mut registrar)?; + database_info_by_key.extend(registrar.database_info_by_key); + } + + ensure_distinct_database_paths(&database_info_by_key)?; + + app.manage(RegisteredDatabases { + database_path_by_key: Arc::new(database_info_by_key.iter().map(|(key, info)| (key.clone(), info.path.clone())).collect()), + }); + let migration_states = app.state::(); { let mut states = migration_states.0.blocking_write(); - for path in migrations.keys() { - states.insert(path.clone(), MigrationState::new()); + for key in database_info_by_key.keys() { + states.insert(key.clone(), MigrationState::new()); } } - // Spawn parallel migration tasks for each registered database - if !migrations.is_empty() { - info!("Starting migrations for {} database(s)", migrations.len()); + for (key, info) in &database_info_by_key { + if let Some(migrator) = &info.migrator { + info!("Starting migrations for database {}", key); - for (path, migrator) in migrations.iter() { + let key = key.clone(); + let migrator = migrator.clone(); + let path = info.path.clone(); let app_handle = app.clone(); - let path = path.clone(); - let migrator = Arc::clone(migrator); - tauri::async_runtime::spawn(async move { - run_migrations_for_database(app_handle, path, migrator).await; + run_migrations_for_database(app_handle, &key, &path, &migrator).await; }); } } @@ -446,13 +691,15 @@ impl Builder { } } }) - .build() + .build()) } } /// Initializes the plugin with default configuration. pub fn init() -> tauri::plugin::TauriPlugin { - Builder::new().build() + Builder::::new() + .build() + .expect("failed to build sqlite plugin") } /// Run migrations for a single database and emit events. @@ -472,38 +719,39 @@ pub fn init() -> tauri::plugin::TauriPlugin { /// global registry and is reused when `load` creates its own wrapper. async fn run_migrations_for_database( app: tauri::AppHandle, - path: String, - migrator: Arc, + key: &str, + path: &Path, + migrator: &Arc, ) { let migration_states = app.state::(); // Update state to Running { let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Running); } } // Emit running event - emit_migration_event(&app, &path, "running", None, None); + emit_migration_event(&app, key, path, "running", None, None); // Resolve absolute path and connect - let abs_path = match resolve_migration_path(&path, &app) { + let abs_path = match resolve_database_path(key, &app) { Ok(p) => p, Err(e) => { let error_msg = e.to_string(); error!( "Failed to resolve migration path for {}: {}", - path, error_msg + key, error_msg ); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Failed(error_msg.clone())); } - emit_migration_event(&app, &path, "failed", None, Some(error_msg)); + emit_migration_event(&app, key, path, "failed", None, Some(error_msg)); return; } }; @@ -513,14 +761,14 @@ async fn run_migrations_for_database( Ok(wrapper) => wrapper, Err(e) => { let error_msg = e.to_string(); - error!("Failed to connect for migrations {}: {}", path, error_msg); + error!("Failed to connect for migrations {}: {}", key, error_msg); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Failed(error_msg.clone())); } - emit_migration_event(&app, &path, "failed", None, Some(error_msg)); + emit_migration_event(&app, key, path, "failed", None, Some(error_msg)); return; } }; @@ -529,30 +777,30 @@ async fn run_migrations_for_database( // Note: SQLx's migrator.run() doesn't provide per-migration callbacks, // so we can only report start and finish. For detailed per-migration events, // we would need to iterate migrations manually. - trace!("Running migrations for {}", path); + trace!("Running migrations for {}", key); - match db.run_migrations(&migrator).await { + match db.run_migrations(migrator).await { Ok(()) => { - info!("Migrations completed successfully for {}", path); + info!("Migrations completed successfully for {}", key); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Complete); } let migration_count = migrator.iter().count(); - emit_migration_event(&app, &path, "completed", Some(migration_count), None); + emit_migration_event(&app, key, path, "completed", Some(migration_count), None); } Err(e) => { let error_msg = e.to_string(); - error!("Migration failed for {}: {}", path, error_msg); + error!("Migration failed for {}: {}", key, error_msg); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Failed(error_msg.clone())); } - emit_migration_event(&app, &path, "failed", None, Some(error_msg)); + emit_migration_event(&app, key, path, "failed", None, Some(error_msg)); } } } @@ -560,13 +808,15 @@ async fn run_migrations_for_database( /// Emit a migration event to the frontend and cache it. fn emit_migration_event( app: &tauri::AppHandle, - db_path: &str, + db_key: &str, + db_path: &Path, status: &str, migration_count: Option, error: Option, ) { let event = MigrationEvent { - db_path: db_path.to_string(), + db_key: db_key.to_string(), + db_path: db_path.to_path_buf(), status: status.to_string(), migration_count, error, @@ -575,7 +825,7 @@ fn emit_migration_event( // Cache event in migration state let migration_states = app.state::(); if let Ok(mut states) = migration_states.0.try_write() - && let Some(state) = states.get_mut(db_path) + && let Some(state) = states.get_mut(db_key) { state.cache_event(event.clone()); } @@ -585,36 +835,316 @@ fn emit_migration_event( } } -/// Resolve database path for migrations. +/// Connect to a registered database by its registration key. /// -/// Delegates to `resolve::resolve_database_path` to ensure consistent path validation -/// across all entry points. -fn resolve_migration_path( - path: &str, - app: &tauri::AppHandle, -) -> Result { - crate::resolve::resolve_database_path(path, app) +/// Opens the database through the same path as the frontend `load` IPC command +/// ([`connect_to_database`]): awaits migrations, enforces max-database limits, and +/// stores the wrapper in [`DbInstances`]. Returns a [`DatabaseWrapper`] for direct +/// toolkit use. +/// +/// The `database_key` must match a key registered via +/// [`Builder::register_database`] or [`SetupRegistrar::register_database`]. +/// +/// # Why use a key? +/// +/// Database paths are usually resolved once during plugin setup — for example +/// `app.path().app_data_dir()?.join("main.db")` in [`Builder::on_setup`]. Without +/// registration keys, every call site would repeat that path discovery or keep its own +/// `PathBuf`. Registration stores the key-to-path mapping once; `connect` reuses the key +/// so callers do not supply a filesystem path on every open. +/// +/// On mobile, path discovery is not a cheap string join. Resolvers such as +/// [tauri-plugin-fs-resolver](https://github.com/silvermine/tauri-plugin-fs-resolver) +/// call platform-native APIs so paths match OS sandbox rules. On Android that means a +/// JNI call into Kotlin `Context` (e.g. `getFilesDir()`) on each resolve — noticeably +/// more expensive than a local HashMap lookup, and a different kind of boundary than +/// TypeScript-to-Rust IPC (in-process JNI vs webview bridge). Register the resolved +/// `PathBuf` once in `on_setup`; every later `connect(database_key)` only looks up that +/// key in [`RegisteredDatabases`] — no repeat native or JNI work. +/// +/// For webview/frontend access, use `Database.load(dbKey)` instead. +/// +/// # Example +/// +/// ```ignore +/// use tauri::{Manager, Runtime}; +/// use tauri_plugin_sqlite::Connection; +/// +/// // During setup (on_setup): +/// // reg.register_database("MAIN", app.path().app_data_dir()?.join("main.db"), None); +/// +/// async fn read_users(app: tauri::AppHandle) -> tauri_plugin_sqlite::Result<()> { +/// let db = app.connect("MAIN").await?; +/// let rows = db.fetch_all("SELECT * FROM users".into(), vec![]).execute().await?; +/// Ok(()) +/// } +/// ``` +pub trait Connection { + /// Connect with default pool configuration. + fn connect(&self, database_key: &str) -> impl Future> + Send; + + /// Connect with custom [`SqliteDatabaseConfig`] (pool sizes, idle timeout). + fn connect_with_config( + &self, + database_key: &str, + config: SqliteDatabaseConfig, + ) -> impl Future> + Send; +} + +/// Delegates to [`connect_to_database`]: same open path as the `load` IPC command. +impl Connection for AppHandle { + async fn connect(&self, database_key: &str) -> Result { + let response = connect_to_database(self, database_key, None).await?; + Ok(response.wrapper) + } + + async fn connect_with_config( + &self, + database_key: &str, + config: SqliteDatabaseConfig, + ) -> Result { + let response = connect_to_database(self, database_key, Some(config)).await?; + Ok(response.wrapper) + } +} + +struct ConnectionResponse { + path: PathBuf, + wrapper: DatabaseWrapper, +} + +async fn connect_to_database( + app: &AppHandle, + db_key: &str, + custom_config: Option, +) -> Result { + let migration_states = app.state::(); + let db_instances = app.state::(); + + // Wait for migrations to complete if registered for this database + await_migrations(&migration_states, db_key).await?; + + let path = resolve_database_path(db_key, app)?; + + let instances = db_instances.inner.read().await; + + // Return cached if db was already loaded + if let Some(wrapper) = instances.get(db_key) { + return Ok(ConnectionResponse { + path, + wrapper: wrapper.clone(), + }); + } + + drop(instances); // Release read lock before acquiring write lock + + let mut instances = db_instances.inner.write().await; + + // Check database count limit before creating a new connection. + // This check is before entry() to avoid borrow conflicts, and the write lock + // prevents races between the len() check and the insert. + if !instances.contains_key(db_key) && instances.len() >= db_instances.max { + return Err(Error::TooManyDatabases(db_instances.max)); + } + + // Use entry API to atomically check and insert, avoiding race conditions + // where two callers could both create wrappers + use std::collections::hash_map::Entry; + match instances.entry(db_key.to_string()) { + Entry::Occupied(entry) => { + // Another caller won the race and inserted while we waited for write lock + Ok(ConnectionResponse { + path, + wrapper: entry.get().clone(), + }) + } + Entry::Vacant(entry) => { + // We won the race, create and insert the wrapper + let wrapper = DatabaseWrapper::connect(&path, custom_config).await?; + entry.insert(wrapper.clone()); + Ok(ConnectionResponse { path, wrapper }) + } + } +} + +/// Wait for migrations to complete for a database, if any are registered. +/// +/// Returns Ok(()) if: +/// - No migrations are registered for this database +/// - Migrations completed successfully +/// +/// Returns Err if migrations failed. +async fn await_migrations(migration_states: &MigrationStates, db_key: &str) -> Result<()> { + loop { + // Get notify handle before checking status + let notify = { + match migration_states.0.read().await.get(db_key) { + // No migrations registered for this database + None => return Ok(()), + + Some(state) => match &state.status { + // Migrations completed successfully + MigrationStatus::Complete => return Ok(()), + + // Migrations failed - return the error + MigrationStatus::Failed(error) => { + return Err(Error::Migration(sqlx::migrate::MigrateError::Source( + error.clone().into(), + ))); + } + + // Migrations still pending or running - wait for notification + MigrationStatus::Pending | MigrationStatus::Running => state.notify.clone(), + }, + } + }; + + // Wait for migration state change + notify.notified().await; + } +} + +/// Resolve a registered database path by key. +/// +/// The `db_key` must match a key registered via +/// [`crate::Builder::register_database`] / [`crate::SetupRegistrar::register_database`]. +/// +/// Returns `Err(Error::PathNotRegistered)` if the key is not registered. +fn resolve_database_path(db_key: &str, app: &AppHandle) -> Result { + let registered_databases = app.state::(); + + if let Some(path) = registered_databases.database_path_by_key.get(db_key) { + return Ok(path.clone()); + } + + Err(Error::PathNotRegistered(db_key.to_string())) } #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; + use tauri::plugin::Plugin; + use tauri::test::{MockRuntime, mock_app, mock_builder, mock_context, noop_assets}; + + fn builder_with_duplicate_paths(temp_dir: &tempfile::TempDir) -> Builder { + let path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + + Builder::::new() + .register_database("MAIN", &path, None) + .unwrap() + .register_database("BACKUP", &path, None) + .unwrap() + } + + fn mock_app_with_registrations( + database_path_by_key: HashMap, + ) -> tauri::App { + let app = tauri::test::mock_app(); + app.manage(RegisteredDatabases { + database_path_by_key: Arc::new(database_path_by_key), + }); + app.manage(DbInstances::default()); + app.manage(MigrationStates::default()); + app + } + + #[tokio::test] + async fn test_connect_to_database_registered_key() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let mut registrations = HashMap::new(); + registrations.insert("MAIN".to_string(), db_path.clone()); + + let app = mock_app_with_registrations(registrations); + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .unwrap(); + + assert_eq!(response.path, db_path); + } + + #[tokio::test] + async fn test_connect_unregistered_key_returns_path_not_registered() { + let app = mock_app_with_registrations(HashMap::new()); + let err = match connect_to_database(app.handle(), "MAIN", None).await { + Err(error) => error, + Ok(_) => panic!("expected unregistered key to fail"), + }; + + assert!(matches!(err, Error::PathNotRegistered(_))); + assert_eq!(err.to_string(), "database key not registered: MAIN"); + } + + #[test] + fn test_register_database_last_registration_wins() { + let temp_dir = tempfile::tempdir().unwrap(); + let first_path = validate::validate_database_path(temp_dir.path().join("first.db")).unwrap(); + let second_path = + validate::validate_database_path(temp_dir.path().join("second.db")).unwrap(); + + let mut database_path_by_key = HashMap::new(); + database_path_by_key.insert("MAIN".to_string(), first_path); + database_path_by_key.insert("MAIN".to_string(), second_path.clone()); + + let app = mock_app_with_registrations(database_path_by_key); + let path = resolve_database_path("MAIN", app.handle()).unwrap(); + + assert_eq!(path, second_path); + } + + #[test] + fn test_setup_rejects_duplicate_paths_for_distinct_keys() { + let temp_dir = tempfile::tempdir().unwrap(); + let mut plugin = builder_with_duplicate_paths(&temp_dir).build().unwrap(); + let app = mock_app(); + + let err = plugin + .initialize(app.handle(), serde_json::Value::default()) + .unwrap_err(); + + let err_msg = err.to_string(); + assert!(err_msg.contains("MAIN")); + assert!(err_msg.contains("BACKUP")); + assert!(err_msg.contains("invalid configuration")); + } + + #[test] + fn test_app_build_rejects_duplicate_paths_for_distinct_keys() { + let temp_dir = tempfile::tempdir().unwrap(); + let plugin = builder_with_duplicate_paths(&temp_dir).build().unwrap(); + + let err = mock_builder() + .plugin(plugin) + .build(mock_context(noop_assets())) + .unwrap_err(); + + match err { + tauri::Error::PluginInitialization(name, message) => { + assert_eq!(name, "sqlite"); + assert!(message.contains("MAIN")); + assert!(message.contains("BACKUP")); + assert!(message.contains("invalid configuration")); + } + other => panic!("expected PluginInitialization, got {other:?}"), + } + } #[test] fn test_max_databases_rejects_zero() { - let err = Builder::new().max_databases(0).unwrap_err(); + let err = Builder::::new().max_databases(0).unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_))); } #[test] fn test_max_databases_accepts_positive() { - let builder = Builder::new().max_databases(1).unwrap(); + let builder = Builder::::new().max_databases(1).unwrap(); assert_eq!(builder.max_databases, Some(1)); } #[test] fn test_transaction_timeout_rejects_zero() { - let err = Builder::new() + let err = Builder::::new() .transaction_timeout(std::time::Duration::ZERO) .unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_))); @@ -622,7 +1152,7 @@ mod tests { #[test] fn test_transaction_timeout_accepts_positive() { - let builder = Builder::new() + let builder = Builder::::new() .transaction_timeout(std::time::Duration::from_secs(1)) .unwrap(); assert_eq!( diff --git a/src/resolve.rs b/src/resolve.rs deleted file mode 100644 index 08d3b25..0000000 --- a/src/resolve.rs +++ /dev/null @@ -1,221 +0,0 @@ -use std::fs::create_dir_all; -use std::path::{Component, Path, PathBuf}; - -use sqlx_sqlite_conn_mgr::SqliteDatabaseConfig; -use sqlx_sqlite_toolkit::DatabaseWrapper; -use tauri::{AppHandle, Manager, Runtime}; - -use crate::Error; - -/// Connect to a SQLite database via the connection manager, resolving -/// the path relative to the app config directory. -/// -/// This is the Tauri-specific connection method that resolves relative paths -/// before delegating to the toolkit's `DatabaseWrapper::connect()`. -pub async fn connect( - path: &str, - app: &AppHandle, - custom_config: Option, -) -> Result { - let abs_path = resolve_database_path(path, app)?; - Ok(DatabaseWrapper::connect(&abs_path, custom_config).await?) -} - -/// Resolve database file path relative to app config directory. -/// -/// Paths are joined to `app_config_dir()` (e.g., `Library/Application Support/${bundleIdentifier}` -/// on iOS). Special paths like `:memory:` are passed through unchanged. -/// -/// Returns `Err(Error::PathTraversal)` if the path attempts to escape the app config directory -/// via absolute paths, `..` segments, or null bytes. -pub fn resolve_database_path(path: &str, app: &AppHandle) -> Result { - let app_path = app - .path() - .app_config_dir() - .map_err(|_| Error::InvalidPath("No app config path found".to_string()))?; - - create_dir_all(&app_path)?; - - validate_and_resolve(path, &app_path) -} - -/// Validate a user-supplied path and resolve it against a base directory. -/// -/// In-memory database paths are passed through unchanged. All other paths are validated -/// to ensure they cannot escape the base directory. -fn validate_and_resolve(path: &str, base: &Path) -> Result { - // Pass through in-memory database paths unchanged — they don't touch the filesystem. - // Matches the same patterns as `is_memory_database` in sqlx-sqlite-conn-mgr. - if is_memory_path(path) { - return Ok(PathBuf::from(path)); - } - - // Reject null bytes — these can truncate paths in C-level filesystem calls - if path.contains('\0') { - return Err(Error::PathTraversal("path contains null byte".to_string())); - } - - let rel = Path::new(path); - - // Reject absolute paths — PathBuf::join replaces the base when given an absolute path - if rel.is_absolute() { - return Err(Error::PathTraversal( - "absolute paths are not allowed".to_string(), - )); - } - - // Reject parent directory components — prevents escaping the base via `../` - for component in rel.components() { - if matches!(component, Component::ParentDir) { - return Err(Error::PathTraversal( - "parent directory references are not allowed".to_string(), - )); - } - } - - // Join and canonicalize to verify containment. The parent directory is canonicalized - // because the file may not exist yet. - let joined = base.join(rel); - let canonical_base = base - .canonicalize() - .map_err(|e| Error::InvalidPath(format!("cannot canonicalize base path: {e}")))?; - - let canonical_resolved = if joined.exists() { - joined.canonicalize() - } else { - // Ensure intermediate directories exist so that canonicalize can resolve the - // parent. This matches the caller's expectation that nested relative paths like - // "subdir/mydb.db" work without pre-creating "subdir/". - let parent = joined - .parent() - .ok_or_else(|| Error::InvalidPath("path has no parent".to_string()))?; - - create_dir_all(parent)?; - - parent - .canonicalize() - .map(|p| p.join(joined.file_name().unwrap_or_default())) - } - .map_err(|e| Error::InvalidPath(format!("cannot canonicalize path: {e}")))?; - - if !canonical_resolved.starts_with(&canonical_base) { - return Err(Error::PathTraversal( - "resolved path escapes the base directory".to_string(), - )); - } - - // Return the original (non-canonicalized) joined path for consistency with how the - // rest of the codebase references database paths. - Ok(joined) -} - -/// Check if a path string represents an in-memory SQLite database. -/// -/// Matches the same patterns as `is_memory_database` in `sqlx-sqlite-conn-mgr`: -/// `:memory:`, `file::memory:*` URIs, and `mode=memory` query parameters. -fn is_memory_path(path: &str) -> bool { - path == ":memory:" - || path.starts_with("file::memory:") - || (path.starts_with("file:") && path.contains("mode=memory")) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::fs; - - /// Helper that creates a temporary base directory for testing. - fn make_temp_base() -> PathBuf { - let dir = std::env::temp_dir().join(format!("tauri_sqlite_test_{}", std::process::id())); - fs::create_dir_all(&dir).unwrap(); - dir - } - - #[test] - fn test_simple_filename() { - let base = make_temp_base(); - let result = validate_and_resolve("mydb.db", &base).unwrap(); - assert_eq!(result, base.join("mydb.db")); - } - - #[test] - fn test_subdirectory_path() { - let base = make_temp_base(); - // Intermediate directories are auto-created — no manual setup needed - let result = validate_and_resolve("subdir/mydb.db", &base).unwrap(); - assert_eq!(result, base.join("subdir/mydb.db")); - assert!(base.join("subdir").is_dir()); - } - - #[test] - fn test_nested_subdirectory_path() { - let base = make_temp_base(); - let result = validate_and_resolve("a/b/c/mydb.db", &base).unwrap(); - assert_eq!(result, base.join("a/b/c/mydb.db")); - assert!(base.join("a/b/c").is_dir()); - } - - #[test] - fn test_memory_passthrough() { - let base = make_temp_base(); - assert_eq!( - validate_and_resolve(":memory:", &base).unwrap(), - PathBuf::from(":memory:"), - ); - } - - #[test] - fn test_file_memory_uri_passthrough() { - let base = make_temp_base(); - assert_eq!( - validate_and_resolve("file::memory:?cache=shared", &base).unwrap(), - PathBuf::from("file::memory:?cache=shared"), - ); - } - - #[test] - fn test_mode_memory_passthrough() { - let base = make_temp_base(); - assert_eq!( - validate_and_resolve("file:test?mode=memory", &base).unwrap(), - PathBuf::from("file:test?mode=memory"), - ); - } - - #[test] - fn test_rejects_parent_traversal() { - let base = make_temp_base(); - let err = validate_and_resolve("../../../etc/passwd", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_absolute_path() { - let base = make_temp_base(); - let err = validate_and_resolve("/etc/passwd", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_embedded_traversal() { - let base = make_temp_base(); - let err = validate_and_resolve("foo/../../bar", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_null_byte() { - let base = make_temp_base(); - let err = validate_and_resolve("path\0evil", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_non_uri_mode_memory() { - let base = make_temp_base(); - // A bare filename containing "mode=memory" is not a valid SQLite URI — - // it should go through normal path validation, not be passed through. - let result = validate_and_resolve("evil.db?mode=memory", &base).unwrap(); - assert_eq!(result, base.join("evil.db?mode=memory")); - } -} diff --git a/src/subscriptions.rs b/src/subscriptions.rs index 2805d6c..21207a2 100644 --- a/src/subscriptions.rs +++ b/src/subscriptions.rs @@ -113,8 +113,8 @@ pub struct ObserverConfigParams { struct ActiveSubscription { /// Abort handle for the subscription forwarding task. abort_handle: tokio::task::AbortHandle, - /// Database path this subscription is for. - db_path: String, + /// Database key this subscription is for. + db_key: String, } /// Global state tracking all active observer subscriptions. @@ -123,13 +123,13 @@ pub struct ActiveSubscriptions(Arc>>) impl ActiveSubscriptions { /// Insert a new subscription. - pub async fn insert(&self, id: String, db_path: String, abort_handle: tokio::task::AbortHandle) { + pub async fn insert(&self, id: String, db_key: String, abort_handle: tokio::task::AbortHandle) { let mut subs = self.0.write().await; subs.insert( id, ActiveSubscription { abort_handle, - db_path, + db_key, }, ); } @@ -146,11 +146,11 @@ impl ActiveSubscriptions { } /// Remove and abort all subscriptions for a specific database. - pub async fn remove_for_db(&self, db_path: &str) { + pub async fn remove_for_db(&self, db_key: &str) { let mut subs = self.0.write().await; let keys_to_remove: Vec = subs .iter() - .filter(|(_, sub)| sub.db_path == db_path) + .filter(|(_, sub)| sub.db_key == db_key) .map(|(k, _)| k.clone()) .collect(); @@ -162,9 +162,9 @@ impl ActiveSubscriptions { } /// Count active subscriptions for a specific database. - pub async fn count_for_db(&self, db_path: &str) -> usize { + pub async fn count_for_db(&self, db_key: &str) -> usize { let subs = self.0.read().await; - subs.values().filter(|sub| sub.db_path == db_path).count() + subs.values().filter(|sub| sub.db_key == db_key).count() } /// Abort all subscriptions (for cleanup on app exit). diff --git a/src/validate.rs b/src/validate.rs new file mode 100644 index 0000000..ee24103 --- /dev/null +++ b/src/validate.rs @@ -0,0 +1,171 @@ +use std::fs; +use std::path::{Component, Path, PathBuf}; + +use sqlx_sqlite_conn_mgr::{canonicalize_database_path, is_memory_database}; + +use crate::{Error, Result}; + +/// Validate and normalize a database path at registration time. +/// +/// In-memory databases (`:memory:`, `file::memory:*`, and `file:` URIs with an exact +/// `mode=memory` query parameter) are returned unchanged. File paths must not contain +/// null bytes or `..` components, must be absolute, and are canonicalized for consistent +/// lookups (symlink-safe when the path or its parent exists). +pub fn validate_database_path(path: impl AsRef) -> Result { + let path = path.as_ref(); + + if is_memory_database(path) { + return Ok(path.to_path_buf()); + } + + let path_str = path.to_str().ok_or_else(|| { + Error::InvalidPath(format!( + "database path is not valid UTF-8: {}", + path.display() + )) + })?; + + if path_str.contains('\0') { + return Err(Error::PathTraversal(format!( + "database path contains null byte: {path_str}" + ))); + } + + if path + .components() + .any(|component| matches!(component, Component::ParentDir)) + { + return Err(Error::PathTraversal(format!( + "database path contains parent traversal: {path_str}" + ))); + } + + if !path.is_absolute() { + return Err(Error::InvalidPath(format!( + "database path must be absolute: {path_str}" + ))); + } + + if let Some(parent) = path.parent() + && !parent.as_os_str().is_empty() + { + fs::create_dir_all(parent).map_err(|error| { + Error::InvalidPath(format!( + "failed to create parent directory for database path {path_str}: {error}" + )) + })?; + } + + canonicalize_database_path(path).map_err(|error| { + Error::InvalidPath(format!( + "failed to canonicalize database path {path_str}: {error}" + )) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::sync::atomic::{AtomicU64, Ordering}; + + fn make_temp_dir() -> PathBuf { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let n = COUNTER.fetch_add(1, Ordering::Relaxed); + let dir = + std::env::temp_dir().join(format!("tauri_sqlite_test_{}_{}", std::process::id(), n)); + fs::create_dir_all(&dir).unwrap(); + dir + } + + #[test] + fn test_memory_passthrough() { + assert_eq!( + validate_database_path(":memory:").unwrap(), + PathBuf::from(":memory:"), + ); + } + + #[test] + fn test_file_memory_uri_passthrough() { + assert_eq!( + validate_database_path("file::memory:?cache=shared").unwrap(), + PathBuf::from("file::memory:?cache=shared"), + ); + } + + #[test] + fn test_mode_memory_passthrough() { + assert_eq!( + validate_database_path("file:test?mode=memory").unwrap(), + PathBuf::from("file:test?mode=memory"), + ); + } + + #[test] + fn test_mode_memory_substring_in_value_is_not_treated_as_memory() { + let err = validate_database_path("file:/home/user/real.db?x=mode=memory").unwrap_err(); + assert!(matches!(err, Error::InvalidPath(_))); + } + + #[test] + fn test_creates_missing_parent_directory() { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let n = COUNTER.fetch_add(1, Ordering::Relaxed); + let base = std::env::temp_dir().join(format!( + "tauri_sqlite_test_missing_parent_{}_{}", + std::process::id(), + n + )); + let db_path = base.join("nested").join("main.db"); + + let result = validate_database_path(&db_path).unwrap(); + + assert!(base.join("nested").is_dir()); + assert_eq!( + result, + base.canonicalize().unwrap().join("nested").join("main.db") + ); + + fs::remove_dir_all(&base).unwrap(); + } + + #[test] + fn test_accepts_absolute_path() { + let dir = make_temp_dir(); + let abs = dir.join("exact.db"); + + let result = validate_database_path(&abs).unwrap(); + assert_eq!(result, dir.canonicalize().unwrap().join("exact.db")); + } + + #[test] + fn test_rejects_relative_path() { + let err = validate_database_path("relative.db").unwrap_err(); + assert!(matches!(err, Error::InvalidPath(_))); + } + + #[test] + fn test_rejects_absolute_path_with_parent_traversal() { + let dir = make_temp_dir(); + let abs_str = format!("{}/../escape.db", dir.to_str().unwrap()); + + let err = validate_database_path(&abs_str).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_absolute_path_with_embedded_traversal() { + let dir = make_temp_dir(); + let abs_str = format!("{}/sub/../../escape.db", dir.to_str().unwrap()); + + let err = validate_database_path(&abs_str).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_null_byte() { + let err = validate_database_path("path\0evil").unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } +} From 5aabb86ad306b72cd50857f8d37d5b3cc399f0d3 Mon Sep 17 00:00:00 2001 From: Andrew de Waal Date: Mon, 15 Jun 2026 08:48:47 -0700 Subject: [PATCH 2/2] refactor: support closing all connections to a db in Rust Previously, the only way to close the connection to a database was through a command invocation, which is only available through IPC. We need the ability to close any open connections to a db in Rust. For example if we are performing file operations (such as deleting and recreating the db) from the Rust layer, we need to first ensure all connections are closed before taking any further action. This simple refactor breaks the logic to close any open connections into a public function callable from Rust. --- crates/sqlx-sqlite-toolkit/src/lib.rs | 2 +- .../sqlx-sqlite-toolkit/src/transactions.rs | 48 +++++++++++++++ .../tests/transaction_state_tests.rs | 54 +++++++++++++++++ src/commands.rs | 27 +++++---- src/lib.rs | 59 +++++++++++++++++++ 5 files changed, 176 insertions(+), 14 deletions(-) diff --git a/crates/sqlx-sqlite-toolkit/src/lib.rs b/crates/sqlx-sqlite-toolkit/src/lib.rs index 201ba2d..627801c 100644 --- a/crates/sqlx-sqlite-toolkit/src/lib.rs +++ b/crates/sqlx-sqlite-toolkit/src/lib.rs @@ -46,7 +46,7 @@ pub use error::{Error, Result}; pub use pagination::{KeysetColumn, KeysetPage, SortDirection}; pub use transactions::{ ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions, - Statement, TransactionWriter, cleanup_all_transactions, + Statement, TransactionWriter, cleanup_all_transactions, cleanup_transactions_for_db, }; pub use wrapper::{ DatabaseWrapper, InterruptibleTransaction, InterruptibleTransactionBuilder, diff --git a/crates/sqlx-sqlite-toolkit/src/transactions.rs b/crates/sqlx-sqlite-toolkit/src/transactions.rs index 559d9eb..de91965 100644 --- a/crates/sqlx-sqlite-toolkit/src/transactions.rs +++ b/crates/sqlx-sqlite-toolkit/src/transactions.rs @@ -389,6 +389,24 @@ impl ActiveInterruptibleTransactions { } } + /// Roll back and remove the interruptible transaction for a single database, if any. + pub async fn abort_for_db(&self, db_key: &str) { + let maybe_tx = { + let mut txs = self.inner.lock().await; + txs.remove(db_key) + }; + + if let Some(tx) = maybe_tx { + debug!( + "Rolling back interruptible transaction for database: {}", + db_key + ); + if let Err(err) = tx.rollback().await { + warn!("rollback during abort_for_db failed (db: {db_key}): {err}"); + } + } + } + /// Remove and return transaction for commit/rollback. /// /// Returns `Err(Error::TransactionTimedOut)` if the transaction has exceeded the @@ -461,6 +479,26 @@ impl ActiveRegularTransactions { txs.clear(); } + + /// Abort in-flight regular transactions for a single database. + /// + /// Tracking entries are left in place until the cancelled task finishes and + /// removes itself. + pub async fn abort_for_db(&self, db_key: &str) { + let prefix = format!("{db_key}:"); + let handles: Vec<(String, AbortHandle)> = { + let txs = self.0.read().await; + txs.iter() + .filter(|(key, _)| key.starts_with(&prefix)) + .map(|(key, handle)| (key.clone(), handle.clone())) + .collect() + }; + + for (key, abort_handle) in handles { + debug!("Aborting regular transaction: {}", key); + abort_handle.abort(); + } + } } /// Cleanup all transactions on app exit. @@ -475,3 +513,13 @@ pub async fn cleanup_all_transactions( debug!("Transaction cleanup initiated"); } + +pub async fn cleanup_transactions_for_db( + db_key: &str, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result<()> { + interruptible_txs.abort_for_db(db_key).await; + regular_txs.abort_for_db(db_key).await; + Ok(()) +} diff --git a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs index 7ac675d..72fdc69 100644 --- a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs +++ b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs @@ -141,6 +141,34 @@ async fn test_abort_all_clears_transactions() { assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); } +#[tokio::test] +async fn test_abort_for_db_clears_only_matching_interruptible() { + let (db1, _temp1) = create_test_db("main.db").await; + let (db2, _temp2) = create_test_db("other.db").await; + + for db in [&db1, &db2] { + db.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)".into(), vec![]) + .await + .unwrap(); + } + + let state = ActiveInterruptibleTransactions::default(); + let main_tx = begin_transaction(&db1, "main").await; + let main_tx_id = main_tx.transaction_id().to_string(); + let other_tx = begin_transaction(&db2, "other").await; + let other_tx_id = other_tx.transaction_id().to_string(); + + state.insert("main".into(), main_tx).await.unwrap(); + state.insert("other".into(), other_tx).await.unwrap(); + + state.abort_for_db("main").await; + + let err = expect_err(state.remove("main", &main_tx_id).await); + assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); + + assert!(state.remove("other", &other_tx_id).await.is_ok()); +} + #[tokio::test] async fn test_abort_all_auto_rollbacks_uncommitted_writes() { let (db, _temp) = create_test_db("rollback.db").await; @@ -329,6 +357,32 @@ async fn test_regular_abort_all_clears_state() { state.insert("a".into(), h3.abort_handle()).await; } +#[tokio::test] +async fn test_regular_abort_for_db_only_matching_prefix() { + let state = ActiveRegularTransactions::default(); + + let main_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + let other_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + state + .insert("main:one".into(), main_handle.abort_handle()) + .await; + state + .insert("other:two".into(), other_handle.abort_handle()) + .await; + + state.abort_for_db("main").await; + + assert!(main_handle.await.unwrap_err().is_cancelled()); + assert!(!other_handle.is_finished()); + + state.remove("main:one").await; +} + // ============================================================================ // cleanup_all_transactions tests // ============================================================================ diff --git a/src/commands.rs b/src/commands.rs index 438ad9a..b315618 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -19,13 +19,13 @@ use tauri::{AppHandle, Runtime, State}; use tracing::debug; use uuid::Uuid; -use crate::connect_to_database; use crate::{ DbInstances, Error, MigrationEvent, MigrationStates, Result, subscriptions::{ ActiveSubscriptions, ObserverConfigParams, TableChangePayload, event_to_payload, }, }; +use crate::{close_database, connect_to_database}; /// Token representing an active interruptible transaction #[derive(Debug, Clone, Serialize, Deserialize)] @@ -323,27 +323,28 @@ pub async fn fetch_page( Ok(result) } -/// Close a specific database connection +/// Close the loaded instance for a registered database key. /// /// Returns `true` if the database was loaded and successfully closed. /// Returns `false` if the database was not loaded (nothing to close). -/// Any active subscriptions for this database are aborted before closing. +/// Active subscriptions and in-flight transactions for this key are aborted +/// before closing (10 second deadline for transactions). #[tauri::command] pub async fn close( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, + interruptible_txs: State<'_, ActiveInterruptibleTransactions>, + regular_txs: State<'_, ActiveRegularTransactions>, db_key: String, ) -> Result { - active_subs.remove_for_db(&db_key).await; - - let mut instances = db_instances.inner.write().await; - - if let Some(wrapper) = instances.remove(&db_key) { - wrapper.close().await?; - Ok(true) - } else { - Ok(false) // Database wasn't loaded - } + close_database( + &db_key, + &db_instances, + &active_subs, + &interruptible_txs, + ®ular_txs, + ) + .await } /// Close all database connections diff --git a/src/lib.rs b/src/lib.rs index f77b675..c72d209 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,8 @@ pub use sqlx_sqlite_toolkit::{ TransactionExecutionBuilder, WriteQueryResult, }; +use crate::subscriptions::ActiveSubscriptions; + /// Default maximum number of concurrently loaded databases. const DEFAULT_MAX_DATABASES: usize = 50; @@ -889,6 +891,14 @@ pub trait Connection { database_key: &str, config: SqliteDatabaseConfig, ) -> impl Future> + Send; + + /// Close the loaded instance for a registered database key. + /// + /// Returns `true` if the database was loaded and successfully closed. + /// Returns `false` if the database was not loaded (nothing to close). + /// Active subscriptions and in-flight transactions for this key are aborted + /// before closing (10 second deadline for transactions). + fn close(&self, database_key: &str) -> impl Future> + Send; } /// Delegates to [`connect_to_database`]: same open path as the `load` IPC command. @@ -906,6 +916,34 @@ impl Connection for AppHandle { let response = connect_to_database(self, database_key, Some(config)).await?; Ok(response.wrapper) } + + async fn close(&self, database_key: &str) -> Result { + let instances = self + .try_state::() + .ok_or(Error::Other("DbInstances state not found".to_string()))?; + let subs = self.try_state::().ok_or(Error::Other( + "ActiveSubscriptions state not found".to_string(), + ))?; + let interruptible_txs = + self + .try_state::() + .ok_or(Error::Other( + "ActiveInterruptibleTransactions state not found".to_string(), + ))?; + let regular_txs = self + .try_state::() + .ok_or(Error::Other( + "ActiveRegularTransactions state not found".to_string(), + ))?; + close_database( + database_key, + &instances, + &subs, + &interruptible_txs, + ®ular_txs, + ) + .await + } } struct ConnectionResponse { @@ -1004,6 +1042,27 @@ async fn await_migrations(migration_states: &MigrationStates, db_key: &str) -> R } } +pub(crate) async fn close_database( + db_key: &str, + db_instances: &DbInstances, + active_subs: &ActiveSubscriptions, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result { + active_subs.remove_for_db(db_key).await; + + sqlx_sqlite_toolkit::cleanup_transactions_for_db(db_key, interruptible_txs, regular_txs).await?; + + let mut instances = db_instances.inner.write().await; + + if let Some(wrapper) = instances.remove(db_key) { + wrapper.close().await?; + Ok(true) + } else { + Ok(false) // Database wasn't loaded + } +} + /// Resolve a registered database path by key. /// /// The `db_key` must match a key registered via