diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..330fe9c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,40 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +## [0.2.0] - Unreleased + +### Breaking Changes + +#### Database registration by key + +Databases must be registered on the Rust side with a stable key before they can be opened. The frontend and Rust callers open databases by **key**, not filesystem path. + +- **`Builder::add_migrations(path, migrator)`** replaced by **`Builder::register_database(key, path, migrator?)`**, which returns `Result`. +- Added **`Builder::on_setup`** and **`SetupRegistrar`** for runtime path registration (e.g. from `app.path().app_data_dir()`). +- **`Builder`** is now generic over Tauri `Runtime`; use `Builder::::new()` where a turbofish is required. + +#### Frontend API + +- First argument to **`Database.load()`** and **`Database.get()`** is a registration **key**, not a path. Passing a path string type-checks but fails at runtime with `PATH_NOT_REGISTERED`. +- **`Database.get(dbKey)`** is synchronous again; it defers connection until the first operation that requires a loaded database. Use **`Database.load(dbKey, customConfig?)`** to connect eagerly or pass pool configuration. +- IPC command arguments: **`db`** → **`dbKey`**; attached **`databasePath`** → **`databaseKey`**. +- **`MigrationEvent`**: adds **`dbKey`**; **`dbPath`** is now an absolute path. +- **`TransactionToken`**: **`dbPath`** → **`dbKey`**. + +#### Rust API + +- Added **`Connection`** trait on **`AppHandle`** for Rust-side opens by registration key. +- **`Builder::build()`** returns **`Result`**; duplicate paths across distinct registration keys fail with **`INVALID_CONFIG`**. +- File paths must be **absolute** at registration; relative path resolution at load time was removed. +- Invalid registration paths fail at startup (`INVALID_PATH`, `PATH_TRAVERSAL`); unregistered keys fail at open time (`PATH_NOT_REGISTERED`). + +### Added + +- Path validation module (`validate.rs`) with canonicalization at registration time. +- Parent directory auto-creation during registration validation for file paths. +- CI check that committed `api-iife.js` matches a fresh Rollup build. + +### Fixed + +- Regenerated `api-iife.js` so all IPC calls use `dbKey` consistently. diff --git a/Cargo.toml b/Cargo.toml index 7271450..a5864fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,5 +47,6 @@ futures = "0.3.31" tauri-plugin = { version = "2.5.1", features = ["build"] } [dev-dependencies] +tauri = { version = "2.9.3", features = ["test"] } tempfile = "3.23.0" tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros"] } diff --git a/README.md b/README.md index dc3aa6b..5c013bc 100644 --- a/README.md +++ b/README.md @@ -119,12 +119,97 @@ Register the plugin in your Tauri application: ```rust fn main() { tauri::Builder::default() - .plugin(tauri_plugin_sqlite::Builder::new().build()) + .plugin(tauri_plugin_sqlite::Builder::new().build().expect("failed to build sqlite plugin")) .run(tauri::generate_context!()) .expect("error while running tauri application"); } ``` +### Registering Databases + +Every database must be **registered** on the Rust side before it can be opened. +Registration assigns a stable **key** (for example `"MAIN"`) to a filesystem path or +in-memory URI. The frontend and Rust callers open databases by **key**, not by path. + +**Registration rules** (enforced when you call `register_database`): + + * File paths must be **absolute**, with no `..` components or null bytes + * File paths are **canonicalized** once at registration (symlink-safe when the path or + parent exists) + * In-memory URIs (`:memory:`, `file::memory:*`, and `file:` URIs with an exact + `mode=memory` query parameter) are accepted as-is + * Each registration **key** must map to a **distinct** database path; two keys for the + same path fail with `INVALID_CONFIG` when calling `build()` or during plugin setup + after `on_setup` merges registrations + +Invalid registration paths fail at startup with `INVALID_PATH` or `PATH_TRAVERSAL`. + +**Why keys:** + +1) `Database.load()` is callable from the frontend over IPC. The frontend sends + only a registration key; unregistered keys are rejected with `PATH_NOT_REGISTERED`. + This prevents untrusted frontend code from opening arbitrary files on disk. + +2) Keys avoid cross-language path string mismatches. With path-at-load, TS and Rust would + need identical canonical path strings on every open (slashes, symlinks, etc.). Keys + resolve the path once at registration; all later opens use the plain string key. + +3) Without registration keys, every call site would repeat that path discovery or keep its + own `PathBuf`. Registration stores the key-to-path mapping once; `connect` reuses the + key so callers do not supply a filesystem path on every open. + On mobile, path discovery is not a cheap string join. Resolvers such as + [tauri-plugin-fs-resolver](https://github.com/silvermine/tauri-plugin-fs-resolver) + call platform-native APIs so paths match OS sandbox rules. On Android that means a + JNI call into Kotlin `Context` (e.g. `getFilesDir()`) on each resolve — noticeably + more expensive than a local HashMap lookup, and a different kind of boundary than + TypeScript-to-Rust IPC (in-process JNI vs webview bridge). Register the resolved + `PathBuf` once in `on_setup`; every later `connect(database_key)` only looks up that + key in [`RegisteredDatabases`] — no repeat native or JNI work. + +Because legitimate paths usually depend on runtime values (app data directory, platform +path resolvers, etc.), registration normally happens in the `on_setup` hook: + +```rust +use tauri_plugin_sqlite::Builder; +use tauri::Manager; + +const MAIN_DB_KEY: &str = "MAIN"; + +fn main() { + tauri::Builder::default() + .plugin( + Builder::new() + .on_setup(|app, reg| { + let db = app.path().app_data_dir()?.join("main.db"); + reg.register_database(MAIN_DB_KEY, db, None)?; + Ok(()) + }) + .build() + .expect("failed to build sqlite plugin") + ) + .run(tauri::generate_context!()) + .expect("error while running tauri application"); +} +``` + +A static registration (compile-time path) can be registered on the builder directly: + +```rust +use tauri_plugin_sqlite::Builder; +use std::path::PathBuf; + +const MAIN_DB_KEY: &str = "MAIN"; + +# fn main() -> tauri_plugin_sqlite::Result<()> { +let _plugin = Builder::new() + .register_database(MAIN_DB_KEY, PathBuf::from("/var/lib/myapp/main.db"), None)? + .build()?; +# Ok(()) +# } +``` + +The frontend then calls `Database.load(MAIN_DB_KEY)` (see [Connecting](#connecting)). + ### Migrations This plugin uses [SQLx's migration system][sqlx-migrate]. Create numbered `.sql` @@ -139,23 +224,40 @@ src-tauri/migrations/ └── 0003_create_posts.sql ``` -Register migrations using SQLx's `migrate!()` macro, which embeds them at compile time: +Register migrations using SQLx's `migrate!()` macro, which embeds them at compile time. +Pass the migrator as the third argument to `register_database`. The `on_setup` hook is the +usual place to register app-derived paths: ```rust use tauri_plugin_sqlite::Builder; +use tauri::Manager; + +const MAIN_DB_KEY: &str = "MAIN"; fn main() { tauri::Builder::default() .plugin( Builder::new() - .add_migrations("main.db", sqlx::migrate!("./migrations")) + .on_setup(|app, reg| { + let db = app.path().app_data_dir()?.join("main.db"); + reg.register_database( + MAIN_DB_KEY, + db, + Some(sqlx::migrate!("./migrations")), + )?; + Ok(()) + }) .build() + .expect("failed to build sqlite plugin") ) .run(tauri::generate_context!()) .expect("error while running tauri application"); } ``` +The frontend must call `Database.load()` with the same **registration key** so migrations +are awaited correctly. + **Timing:** Migrations start automatically at plugin setup (non-blocking). When TypeScript calls `Database.load()`, it waits for migrations to complete before returning. If migrations fail, `load()` returns an error. Applied migrations are @@ -168,12 +270,15 @@ Use `getMigrationEvents()` to retrieve cached events: ```typescript import Database from '@silvermine/tauri-plugin-sqlite'; -const db = await Database.load('mydb.db'); +const MAIN_DB_KEY = 'MAIN'; + +// Same registration key used in Rust `register_database` +const db = await Database.load(MAIN_DB_KEY); // Get all migration events (including ones emitted before listener could be registered) const events = await db.getMigrationEvents(); for (const event of events) { - console.info(`${event.status}: ${event.dbPath}`); + console.info(`${event.status}: ${event.dbKey} (${event.dbPath})`); if (event.status === 'failed') { console.error(`Migration error: ${event.error}`); } @@ -188,29 +293,41 @@ import { listen } from '@tauri-apps/api/event'; import type { MigrationEvent } from '@silvermine/tauri-plugin-sqlite'; await listen('sqlite:migration', (event) => { - const { dbPath, status, migrationCount, error } = event.payload; - console.info(`Migration ${status} for ${dbPath}: ${migrationCount} migrations`, error); + const { dbKey, dbPath, status, migrationCount, error } = event.payload; + console.info(`Migration ${status} for ${dbKey} (${dbPath}): ${migrationCount} migrations`, error); }); ``` ### Connecting +Pass the **registration key** from Rust `register_database` (see +[Registering Databases](#registering-databases)). + ```typescript import Database from '@silvermine/tauri-plugin-sqlite'; -// Path is relative to app config directory (no sqlite: prefix needed) -let db = await Database.load('mydb.db'); +const MAIN_DB_KEY = 'MAIN'; + +// Connect (no sqlite: prefix needed) +let db = await Database.load(MAIN_DB_KEY); // With custom configuration -db = await Database.load('mydb.db', { +db = await Database.load(MAIN_DB_KEY, { maxReadConnections: 10, // default: 6 idleTimeoutSecs: 60 // default: 30 }); -// Lazy initialization (connects on first query) -db = Database.get('mydb.db'); +// Lazy initialization (connects on first query; sync — no await) +db = Database.get(MAIN_DB_KEY); + +// In-memory: register first, then load by key +// reg.register_database('MEM', ':memory:', None)? in on_setup +const mem = await Database.load('MEM'); ``` +An unregistered key throws `PATH_NOT_REGISTERED`. Invalid paths are rejected at +registration time on the Rust side (`INVALID_PATH`, `PATH_TRAVERSAL`). + ### Parameter Binding All query methods use `$1`, `$2`, etc. syntax with `SqlValue` types: @@ -400,16 +517,20 @@ tables. **Builder Pattern:** All query methods (`execute`, `executeTransaction`, `fetchAll`, `fetchOne`, `fetchPage`) return builders that support `.attach()` -for cross-database operations. +for cross-database operations. Each attached database must already be loaded and is +identified by its **registration key** (see +[Registering Databases](#registering-databases)). ```typescript +const ORDERS_DB_KEY = 'ORDERS'; + // Join data from multiple databases const results = await db.fetchAll( 'SELECT u.name, o.total FROM users u JOIN orders.orders o ON u.id = o.user_id', [] ).attach([ { - databasePath: 'orders.db', + databaseKey: ORDERS_DB_KEY, schemaName: 'orders', mode: 'readOnly' } @@ -422,14 +543,13 @@ await db.execute( ['archived'] ).attach([ { - databasePath: 'archive.db', + databaseKey: 'ARCHIVE', schemaName: 'archive', mode: 'readOnly' } ]); // Atomic writes across multiple databases -// Assuming userId and total are defined in your application context const userId = 123; const total = 99.99; @@ -438,7 +558,7 @@ await db.executeTransaction([ ['UPDATE stats.order_count SET count = count + 1', []] ]).attach([ { - databasePath: 'stats.db', + databaseKey: 'STATS', schemaName: 'stats', mode: 'readWrite' } @@ -536,7 +656,9 @@ Common error codes: * `SQLITE_CONSTRAINT` - Constraint violation (unique, foreign key, etc.) * `SQLITE_NOTFOUND` - Table or column not found * `DATABASE_NOT_LOADED` - Database hasn't been loaded yet - * `INVALID_PATH` - Invalid database path + * `INVALID_PATH` - Invalid path at registration (relative or failed canonicalization) + * `PATH_NOT_REGISTERED` - Registration key not found + * `PATH_TRAVERSAL` - Registration path contains `..` or null bytes * `IO_ERROR` - File system error * `MIGRATION_ERROR` - Migration failed * `MULTIPLE_ROWS_RETURNED` - `fetchOne()` returned multiple rows @@ -557,8 +679,8 @@ await db.remove(); // Close and DELETE database file(s) - irreversible | Method | Description | | ------ | ----------- | -| `Database.load(path, config?)` | Connect and return Database instance (or existing) | -| `Database.get(path)` | Get instance without connecting (lazy init) | +| `Database.load(dbKey, config?)` | Connect eagerly and return Database instance (or existing) | +| `Database.get(dbKey)` | Sync handle; connects on first query (no `customConfig`) | | `Database.close_all()` | Close all database connections | ### Instance Methods @@ -619,7 +741,7 @@ interface CustomConfig { } interface AttachedDatabaseSpec { - databasePath: string; // Path relative to app config directory + databaseKey: string; // Registration key of a database already loaded via load() schemaName: string; // Schema name for accessing tables (e.g., 'orders') mode: 'readOnly' | 'readWrite'; } @@ -672,25 +794,44 @@ type TableChangeEvent = ## Rust-Only API -For Rust code that needs direct database access without going through Tauri commands, -use `DatabaseWrapper`. - -### Setup (Rust) +For Rust code in a Tauri app, register databases first, then open by key using the +[`Connection`](src/lib.rs) trait on `AppHandle`. This uses the same open path as the +frontend `load` command (`connect_to_database`). -```rust -use tauri_plugin_sqlite::DatabaseWrapper; -use std::path::PathBuf; +For standalone Rust projects without the plugin, use `DatabaseWrapper::connect(path)` +from [`sqlx-sqlite-toolkit`](crates/sqlx-sqlite-toolkit/) directly (no registration). -// Load a database -let mut db = DatabaseWrapper::load(PathBuf::from("/path/to/mydb.db"), None).await?; +### Setup (Tauri plugin) -// With custom configuration -use tauri_plugin_sqlite::CustomConfig; -let config = CustomConfig { - max_read_connections: Some(10), - idle_timeout_secs: Some(60), -}; -db = DatabaseWrapper::load(PathBuf::from("/path/to/mydb.db"), Some(config)).await?; +```rust +use tauri::{Manager, Runtime}; +use tauri_plugin_sqlite::{Builder, Connection, SqliteDatabaseConfig}; + +const MAIN_DB_KEY: &str = "MAIN"; + +// In lib.rs setup — register key + path +Builder::new() + .on_setup(|app, reg| { + let db = app.path().app_data_dir()?.join("main.db"); + reg.register_database(MAIN_DB_KEY, db, None)?; + Ok(()) + }) + .build()?; + +// Anywhere with AppHandle +async fn example(app: tauri::AppHandle) -> tauri_plugin_sqlite::Result<()> { + let db = app.connect(MAIN_DB_KEY).await?; + + let db = app.connect_with_config( + MAIN_DB_KEY, + SqliteDatabaseConfig { + max_read_connections: 10, + idle_timeout_secs: 60, + }, + ).await?; + + Ok(()) +} ``` ### Basic Operations @@ -824,17 +965,17 @@ tx.commit().await?; ### Cross-Database Operations -Attach other databases for cross-database queries. For Rust API usage, you need to load -both databases first, then create `AttachedSpec` instances using their inner database -references: +Attach other databases for cross-database queries. Load each database by registration +key first (`app.connect("STATS").await?`), then create `AttachedSpec` instances using +their inner database references: ```rust -use tauri_plugin_sqlite::{DatabaseWrapper, AttachedSpec, AttachedMode}; +use tauri_plugin_sqlite::{Connection, AttachedSpec, AttachedMode}; use std::sync::Arc; -// Load both databases -let main_db = DatabaseWrapper::load("/path/to/main.db".into(), None).await?; -let stats_db = DatabaseWrapper::load("/path/to/stats.db".into(), None).await?; +// After registering and connecting both databases by key +let main_db = app.connect("MAIN").await?; +let stats_db = app.connect("STATS").await?; // Create attached spec using the inner database reference let stats_spec = AttachedSpec { @@ -853,8 +994,7 @@ let results = main_db.execute_transaction(vec![ println!("Cross-database transaction completed: {} statements", results.len()); // Interruptible transaction with attached database -// Load the inventory database -let inventory_db = DatabaseWrapper::load("/path/to/inventory.db".into(), None).await?; +let inventory_db = app.connect("INVENTORY").await?; // Create spec for inventory database let inv_spec = AttachedSpec { @@ -888,7 +1028,7 @@ db.remove().await?; // Close and DELETE database file(s) | Method | Description | | ------ | ----------- | -| `load(path, config?)` | Load database, returns `DatabaseWrapper` | +| `connect(path, config?)` | Open database by path, returns `DatabaseWrapper` | | `execute(query, values)` | Execute write query | | `execute_transaction(statements)` | Execute statements atomically (builder) | | `begin_interruptible_transaction()` | Begin interruptible transaction (builder) | @@ -936,7 +1076,7 @@ fn init_tracing() {} fn main() { init_tracing(); tauri::Builder::default() - .plugin(tauri_plugin_sqlite::Builder::new().build()) + .plugin(tauri_plugin_sqlite::Builder::new().build().expect("failed to build sqlite plugin")) .run(tauri::generate_context!()) .expect("error while running tauri application"); } @@ -985,9 +1125,10 @@ pagination to keep memory usage bounded on both the Rust and TypeScript sides. ### Path Validation -Database paths are validated to prevent directory traversal. Absolute paths, -`..` segments, and null bytes are rejected. All paths are resolved relative to -the app config directory. +Registration validates filesystem paths once (absolute, no traversal, canonicalized). +In-memory URIs are accepted as-is. At runtime, `Database.load(dbKey)` and +`Connection::connect(dbKey)` only accept **registered keys**; unknown keys return +`PATH_NOT_REGISTERED`. ## Development diff --git a/api-iife.js b/api-iife.js index 5fdfc7a..74ff231 100644 --- a/api-iife.js +++ b/api-iife.js @@ -1 +1 @@ -if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(t){"use strict";function e(t,e,s,i){if("function"==typeof e?t!==e||!i:!e.has(t))throw new TypeError("Cannot read private member from an object whose class did not declare it");return"m"===s?i:"a"===s?i.call(t):i?i.value:e.get(t)}function s(t,e,s,i,a){if("function"==typeof e||!e.has(t))throw new TypeError("Cannot write private member to an object whose class did not declare it");return e.set(t,s),s}var i,a,n,h;"function"==typeof SuppressedError&&SuppressedError;const r="__TAURI_TO_IPC_KEY__";class c{constructor(t){i.set(this,void 0),a.set(this,0),n.set(this,[]),h.set(this,void 0),s(this,i,t||(()=>{})),this.id=function(t,e=!1){return window.__TAURI_INTERNALS__.transformCallback(t,e)}(t=>{const r=t.index;if("end"in t)return void(r==e(this,a,"f")?this.cleanupCallback():s(this,h,r));const c=t.message;if(r==e(this,a,"f")){for(e(this,i,"f").call(this,c),s(this,a,e(this,a,"f")+1);e(this,a,"f")in e(this,n,"f");){const t=e(this,n,"f")[e(this,a,"f")];e(this,i,"f").call(this,t),delete e(this,n,"f")[e(this,a,"f")],s(this,a,e(this,a,"f")+1)}e(this,a,"f")===e(this,h,"f")&&this.cleanupCallback()}else e(this,n,"f")[r]=c})}cleanupCallback(){window.__TAURI_INTERNALS__.unregisterCallback(this.id)}set onmessage(t){s(this,i,t)}get onmessage(){return e(this,i,"f")}[(i=new WeakMap,a=new WeakMap,n=new WeakMap,h=new WeakMap,r)](){return`__CHANNEL__:${this.id}`}toJSON(){return this[r]()}}async function u(t,e={},s){return window.__TAURI_INTERNALS__.invoke(t,e,s)}class _{constructor(t,e){this._dbPath=t,this._transactionId=e}async read(t,e){return await u("plugin:sqlite|transaction_read",{token:{dbPath:this._dbPath,transactionId:this._transactionId},query:t,values:e??[]})}async continueWith(t){const e=await u("plugin:sqlite|transaction_continue",{token:{dbPath:this._dbPath,transactionId:this._transactionId},action:{type:"Continue",statements:t.map(([t,e])=>({query:t,values:e??[]}))}});return new _(e.dbPath,e.transactionId)}async commit(){await u("plugin:sqlite|transaction_continue",{token:{dbPath:this._dbPath,transactionId:this._transactionId},action:{type:"Commit"}})}async rollback(){await u("plugin:sqlite|transaction_continue",{token:{dbPath:this._dbPath,transactionId:this._transactionId},action:{type:"Rollback"}})}}class l{constructor(t){this._subscriptionId=t}get id(){return this._subscriptionId}async unsubscribe(){return await u("plugin:sqlite|unsubscribe",{subscriptionId:this._subscriptionId})}}class o{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|fetch_all",{db:this._db.path,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class d{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|fetch_one",{db:this._db.path,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class b{constructor(t,e,s,i,a){this._db=t,this._query=e,this._bindValues=s,this._keyset=i,this._pageSize=a,this._after=null,this._before=null,this._attached=[]}after(t){return this._after=t,this}before(t){return this._before=t,this}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|fetch_page",{db:this._db.path,query:this._query,values:this._bindValues,keyset:this._keyset,pageSize:this._pageSize,after:this._after,before:this._before,attached:this._attached.length>0?this._attached:null})}}class p{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){const[t,e]=await u("plugin:sqlite|execute",{db:this._db.path,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null});return{lastInsertId:e,rowsAffected:t}}}class f{constructor(t,e,s=[]){this._db=t,this._initialStatements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){const t=await u("plugin:sqlite|begin_interruptible_transaction",{db:this._db.path,initialStatements:this._initialStatements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null});return new _(t.dbPath,t.transactionId)}}class w{constructor(t,e,s=[]){this._db=t,this._statements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await u("plugin:sqlite|execute_transaction",{db:this._db.path,statements:this._statements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null})}}class y{constructor(t){this.path=t}static async load(t,e){const s=await u("plugin:sqlite|load",{db:t,customConfig:e});return new y(s)}static get(t){return new y(t)}static async close_all(){await u("plugin:sqlite|close_all")}execute(t,e){return new p(this,t,e??[])}executeTransaction(t){return new w(this,t)}fetchAll(t,e){return new o(this,t,e??[])}fetchOne(t,e){return new d(this,t,e??[])}fetchPage(t,e,s,i){return new b(this,t,e,s,i)}async observe(t,e){await u("plugin:sqlite|observe",{db:this.path,tables:t,config:e??null})}async subscribe(t,e){const s=new c;s.onmessage=e;const i=await u("plugin:sqlite|subscribe",{db:this.path,tables:t,onEvent:s});return new l(i)}async unobserve(){await u("plugin:sqlite|unobserve",{db:this.path})}async close(){return await u("plugin:sqlite|close",{db:this.path})}async remove(){return await u("plugin:sqlite|remove",{db:this.path})}beginInterruptibleTransaction(t){return new f(this,t)}async getMigrationEvents(){return await u("plugin:sqlite|get_migration_events",{db:this.path})}}return t.InterruptibleTransaction=_,t.Subscription=l,t.default=y,Object.defineProperty(t,"__esModule",{value:!0}),t}({});Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} +if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(t){"use strict";function e(t,e,s,i){if("function"==typeof e?t!==e||!i:!e.has(t))throw new TypeError("Cannot read private member from an object whose class did not declare it");return"m"===s?i:"a"===s?i.call(t):i?i.value:e.get(t)}function s(t,e,s,i,a){if("function"==typeof e||!e.has(t))throw new TypeError("Cannot write private member to an object whose class did not declare it");return e.set(t,s),s}var i,a,n,r;"function"==typeof SuppressedError&&SuppressedError;const h="__TAURI_TO_IPC_KEY__";class c{constructor(t){i.set(this,void 0),a.set(this,0),n.set(this,[]),r.set(this,void 0),s(this,i,t||(()=>{})),this.id=function(t,e=!1){return window.__TAURI_INTERNALS__.transformCallback(t,e)}(t=>{const h=t.index;if("end"in t)return void(h==e(this,a,"f")?this.cleanupCallback():s(this,r,h));const c=t.message;if(h==e(this,a,"f")){for(e(this,i,"f").call(this,c),s(this,a,e(this,a,"f")+1);e(this,a,"f")in e(this,n,"f");){const t=e(this,n,"f")[e(this,a,"f")];e(this,i,"f").call(this,t),delete e(this,n,"f")[e(this,a,"f")],s(this,a,e(this,a,"f")+1)}e(this,a,"f")===e(this,r,"f")&&this.cleanupCallback()}else e(this,n,"f")[h]=c})}cleanupCallback(){window.__TAURI_INTERNALS__.unregisterCallback(this.id)}set onmessage(t){s(this,i,t)}get onmessage(){return e(this,i,"f")}[(i=new WeakMap,a=new WeakMap,n=new WeakMap,r=new WeakMap,h)](){return`__CHANNEL__:${this.id}`}toJSON(){return this[h]()}}async function u(t,e={},s){return window.__TAURI_INTERNALS__.invoke(t,e,s)}const _=new WeakMap;async function l(t){if(""!==t.path)return;let e=_.get(t);e||(e=u("plugin:sqlite|load",{dbKey:t.key}).then(e=>{t.path=e}),_.set(t,e)),await e}class o{constructor(t,e){this._dbKey=t,this._transactionId=e}async read(t,e){return await u("plugin:sqlite|transaction_read",{token:{dbKey:this._dbKey,transactionId:this._transactionId},query:t,values:e??[]})}async continueWith(t){const e=await u("plugin:sqlite|transaction_continue",{token:{dbKey:this._dbKey,transactionId:this._transactionId},action:{type:"Continue",statements:t.map(([t,e])=>({query:t,values:e??[]}))}});return new o(e.dbKey,e.transactionId)}async commit(){await u("plugin:sqlite|transaction_continue",{token:{dbKey:this._dbKey,transactionId:this._transactionId},action:{type:"Commit"}})}async rollback(){await u("plugin:sqlite|transaction_continue",{token:{dbKey:this._dbKey,transactionId:this._transactionId},action:{type:"Rollback"}})}}class d{constructor(t){this._subscriptionId=t}get id(){return this._subscriptionId}async unsubscribe(){return await u("plugin:sqlite|unsubscribe",{subscriptionId:this._subscriptionId})}}class b{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|fetch_all",{dbKey:this._db.key,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class y{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|fetch_one",{dbKey:this._db.key,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null})}}class w{constructor(t,e,s,i,a){this._db=t,this._query=e,this._bindValues=s,this._keyset=i,this._pageSize=a,this._after=null,this._before=null,this._attached=[]}after(t){return this._after=t,this}before(t){return this._before=t,this}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|fetch_page",{dbKey:this._db.key,query:this._query,values:this._bindValues,keyset:this._keyset,pageSize:this._pageSize,after:this._after,before:this._before,attached:this._attached.length>0?this._attached:null})}}class p{constructor(t,e,s,i=[]){this._db=t,this._query=e,this._bindValues=s,this._attached=i}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){await l(this._db);const[t,e]=await u("plugin:sqlite|execute",{dbKey:this._db.key,query:this._query,values:this._bindValues,attached:this._attached.length>0?this._attached:null});return{lastInsertId:e,rowsAffected:t}}}class f{constructor(t,e,s=[]){this._db=t,this._initialStatements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){await l(this._db);const t=await u("plugin:sqlite|begin_interruptible_transaction",{dbKey:this._db.key,initialStatements:this._initialStatements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null});return new o(t.dbKey,t.transactionId)}}class g{constructor(t,e,s=[]){this._db=t,this._statements=e,this._attached=s}attach(t){return this._attached=t,this}then(t,e){return this._execute().then(t,e)}async _execute(){return await l(this._db),await u("plugin:sqlite|execute_transaction",{dbKey:this._db.key,statements:this._statements.map(([t,e])=>({query:t,values:e??[]})),attached:this._attached.length>0?this._attached:null})}}class q{constructor(t,e){this.key=t,this.path=e}static async load(t,e){const s=await u("plugin:sqlite|load",{dbKey:t,customConfig:e});return new q(t,s)}static get(t){return new q(t,"")}static async close_all(){await u("plugin:sqlite|close_all")}execute(t,e){return new p(this,t,e??[])}executeTransaction(t){return new g(this,t)}fetchAll(t,e){return new b(this,t,e??[])}fetchOne(t,e){return new y(this,t,e??[])}fetchPage(t,e,s,i){return new w(this,t,e,s,i)}async observe(t,e){await l(this),await u("plugin:sqlite|observe",{dbKey:this.key,tables:t,config:e??null})}async subscribe(t,e){await l(this);const s=new c;s.onmessage=e;const i=await u("plugin:sqlite|subscribe",{dbKey:this.key,tables:t,onEvent:s});return new d(i)}async unobserve(){await l(this),await u("plugin:sqlite|unobserve",{dbKey:this.key})}async close(){return await u("plugin:sqlite|close",{dbKey:this.key})}async remove(){return await u("plugin:sqlite|remove",{dbKey:this.key})}beginInterruptibleTransaction(t){return new f(this,t)}async getMigrationEvents(){return await u("plugin:sqlite|get_migration_events",{dbKey:this.key})}}return t.InterruptibleTransaction=o,t.Subscription=d,t.default=q,Object.defineProperty(t,"__esModule",{value:!0}),t}({});Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} diff --git a/crates/sqlx-sqlite-conn-mgr/README.md b/crates/sqlx-sqlite-conn-mgr/README.md index 578dddf..2b2c087 100644 --- a/crates/sqlx-sqlite-conn-mgr/README.md +++ b/crates/sqlx-sqlite-conn-mgr/README.md @@ -89,7 +89,8 @@ Migrations are tracked in `_sqlx_migrations` — calling `run_migrations()` mult times is safe (already-applied migrations are skipped). > **Note:** When using the Tauri plugin, migrations are handled automatically via -> `Builder::add_migrations()`. The plugin starts migrations at setup and waits for +> `Builder::register_database(..., Some(migrator))`. The plugin starts migrations at setup +> and waits for > completion when `load()` is called. ### Attached Databases diff --git a/crates/sqlx-sqlite-conn-mgr/src/config.rs b/crates/sqlx-sqlite-conn-mgr/src/config.rs index a8ab964..4baf501 100644 --- a/crates/sqlx-sqlite-conn-mgr/src/config.rs +++ b/crates/sqlx-sqlite-conn-mgr/src/config.rs @@ -25,6 +25,7 @@ use serde::{Deserialize, Serialize}; /// }; /// ``` #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct SqliteDatabaseConfig { /// Maximum number of concurrent read connections /// diff --git a/crates/sqlx-sqlite-conn-mgr/src/lib.rs b/crates/sqlx-sqlite-conn-mgr/src/lib.rs index c4a15b7..d3bc2b3 100644 --- a/crates/sqlx-sqlite-conn-mgr/src/lib.rs +++ b/crates/sqlx-sqlite-conn-mgr/src/lib.rs @@ -80,5 +80,7 @@ pub use write_guard::WriteGuard; // Re-export sqlx migrate types for convenience pub use sqlx::migrate::Migrator; +pub use registry::{canonicalize_database_path, is_memory_database}; + /// A type alias for Results with our custom Error type pub type Result = std::result::Result; diff --git a/crates/sqlx-sqlite-conn-mgr/src/registry.rs b/crates/sqlx-sqlite-conn-mgr/src/registry.rs index fb98009..f3915ff 100644 --- a/crates/sqlx-sqlite-conn-mgr/src/registry.rs +++ b/crates/sqlx-sqlite-conn-mgr/src/registry.rs @@ -16,14 +16,23 @@ fn registry() -> &'static RwLock>> { DATABASE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new())) } +/// Returns true when a `file:` URI query string contains an exact `mode=memory` parameter. +fn file_uri_has_mode_memory(path_str: &str) -> bool { + let Some(query) = path_str.split_once('?').map(|(_, query)| query) else { + return false; + }; + query.split('&').any(|param| param == "mode=memory") +} + /// Check if a path represents an in-memory SQLite database /// -/// Returns true for `:memory:` and `file::memory:*` URIs +/// Returns true for `:memory:` and `file::memory:*` URIs, and for `file:` URIs whose +/// query string includes a `mode=memory` parameter (not merely a substring match). pub fn is_memory_database(path: &Path) -> bool { let path_str = path.to_str().unwrap_or(""); path_str == ":memory:" || path_str.starts_with("file::memory:") - || path_str.contains("mode=memory") + || (path_str.starts_with("file:") && file_uri_has_mode_memory(path_str)) } /// Get or open a SQLite database connection @@ -44,7 +53,7 @@ where } // Canonicalize the path for consistent lookups - let canonical_path = canonicalize_path(path)?; + let canonical_path = canonicalize_database_path(path)?; // Try to get existing database with read lock (allows concurrent reads) { @@ -82,10 +91,9 @@ where Ok(arc_db) } -/// Helper to canonicalize a database path +/// Canonicalize a database file path for consistent registry lookups. /// -/// This function attempts to resolve paths to their canonical form to ensure -/// consistent cache lookups. It handles: +/// This function attempts to resolve paths to their canonical form. It handles: /// - Absolute path resolution /// - Symlink resolution (when file exists) /// - Parent directory canonicalization (when file doesn't exist yet) @@ -97,7 +105,7 @@ where /// least until the file is created and can be canonicalized properly. /// - Symlinks in filename: If the filename itself will be a symlink (rare for SQLite), /// different symlink names won't be resolved until the file exists. -fn canonicalize_path(path: &Path) -> std::io::Result { +pub fn canonicalize_database_path(path: &Path) -> std::io::Result { match path.canonicalize() { Ok(p) => Ok(p), Err(_) => { @@ -132,7 +140,7 @@ pub async fn uncache_database(path: &Path) -> std::io::Result<()> { } // Canonicalize path - let canonical_path = canonicalize_path(path)?; + let canonical_path = canonicalize_database_path(path)?; let mut registry = registry().write().await; registry.remove(&canonical_path); @@ -149,12 +157,12 @@ mod tests { let test_path = temp_dir.join("test.db"); // Test that path is canonicalized to absolute path - let canonical = canonicalize_path(&test_path).unwrap(); + let canonical = canonicalize_database_path(&test_path).unwrap(); assert!(canonical.is_absolute()); // Test relative path let relative_path = Path::new("./test_relative.db"); - let canonical_relative = canonicalize_path(relative_path).unwrap(); + let canonical_relative = canonicalize_database_path(relative_path).unwrap(); assert!(canonical_relative.is_absolute()); } @@ -164,7 +172,22 @@ mod tests { let nonexistent = temp_dir.join("nonexistent_dir").join("test.db"); // Should fail if parent directory doesn't exist - let result = canonicalize_path(&nonexistent); + let result = canonicalize_database_path(&nonexistent); assert!(result.is_err()); } + + #[test] + fn test_mode_memory_query_param() { + assert!(is_memory_database(Path::new("file:test?mode=memory"))); + assert!(is_memory_database(Path::new( + "file:/data/db?cache=shared&mode=memory" + ))); + } + + #[test] + fn test_mode_memory_substring_in_value_is_not_memory() { + assert!(!is_memory_database(Path::new( + "file:/home/user/real.db?x=mode=memory" + ))); + } } diff --git a/crates/sqlx-sqlite-toolkit/src/lib.rs b/crates/sqlx-sqlite-toolkit/src/lib.rs index 201ba2d..627801c 100644 --- a/crates/sqlx-sqlite-toolkit/src/lib.rs +++ b/crates/sqlx-sqlite-toolkit/src/lib.rs @@ -46,7 +46,7 @@ pub use error::{Error, Result}; pub use pagination::{KeysetColumn, KeysetPage, SortDirection}; pub use transactions::{ ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions, - Statement, TransactionWriter, cleanup_all_transactions, + Statement, TransactionWriter, cleanup_all_transactions, cleanup_transactions_for_db, }; pub use wrapper::{ DatabaseWrapper, InterruptibleTransaction, InterruptibleTransactionBuilder, diff --git a/crates/sqlx-sqlite-toolkit/src/transactions.rs b/crates/sqlx-sqlite-toolkit/src/transactions.rs index 559d9eb..de91965 100644 --- a/crates/sqlx-sqlite-toolkit/src/transactions.rs +++ b/crates/sqlx-sqlite-toolkit/src/transactions.rs @@ -389,6 +389,24 @@ impl ActiveInterruptibleTransactions { } } + /// Roll back and remove the interruptible transaction for a single database, if any. + pub async fn abort_for_db(&self, db_key: &str) { + let maybe_tx = { + let mut txs = self.inner.lock().await; + txs.remove(db_key) + }; + + if let Some(tx) = maybe_tx { + debug!( + "Rolling back interruptible transaction for database: {}", + db_key + ); + if let Err(err) = tx.rollback().await { + warn!("rollback during abort_for_db failed (db: {db_key}): {err}"); + } + } + } + /// Remove and return transaction for commit/rollback. /// /// Returns `Err(Error::TransactionTimedOut)` if the transaction has exceeded the @@ -461,6 +479,26 @@ impl ActiveRegularTransactions { txs.clear(); } + + /// Abort in-flight regular transactions for a single database. + /// + /// Tracking entries are left in place until the cancelled task finishes and + /// removes itself. + pub async fn abort_for_db(&self, db_key: &str) { + let prefix = format!("{db_key}:"); + let handles: Vec<(String, AbortHandle)> = { + let txs = self.0.read().await; + txs.iter() + .filter(|(key, _)| key.starts_with(&prefix)) + .map(|(key, handle)| (key.clone(), handle.clone())) + .collect() + }; + + for (key, abort_handle) in handles { + debug!("Aborting regular transaction: {}", key); + abort_handle.abort(); + } + } } /// Cleanup all transactions on app exit. @@ -475,3 +513,13 @@ pub async fn cleanup_all_transactions( debug!("Transaction cleanup initiated"); } + +pub async fn cleanup_transactions_for_db( + db_key: &str, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result<()> { + interruptible_txs.abort_for_db(db_key).await; + regular_txs.abort_for_db(db_key).await; + Ok(()) +} diff --git a/crates/sqlx-sqlite-toolkit/tests/memory_tests.rs b/crates/sqlx-sqlite-toolkit/tests/memory_tests.rs new file mode 100644 index 0000000..291c491 --- /dev/null +++ b/crates/sqlx-sqlite-toolkit/tests/memory_tests.rs @@ -0,0 +1,81 @@ +use std::path::Path; +use std::sync::Arc; + +use serde_json::json; +use sqlx_sqlite_toolkit::DatabaseWrapper; + +/// These tests verify the correct behavior of `DatabaseWrapper::connect` with an +/// in-memory database. +#[tokio::test] +async fn connect_memory_runs_ddl_and_dml() { + let db = DatabaseWrapper::connect(Path::new(":memory:"), None) + .await + .expect("Failed to connect to in-memory database"); + + db.execute( + "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)".into(), + vec![], + ) + .execute() + .await + .expect("CREATE TABLE should succeed"); + + let result = db + .execute( + "INSERT INTO t (name) VALUES ($1)".into(), + vec![json!("Alice")], + ) + .execute() + .await + .expect("INSERT should succeed"); + + assert_eq!((result.rows_affected, result.last_insert_id), (1, 1)); + + // Read back on the same write connection via an interruptible transaction. + let mut tx = db + .begin_interruptible_transaction() + .execute(vec![( + "INSERT INTO t (name) VALUES (?)", + vec![json!("Bob")], + )]) + .await + .expect("transaction should start"); + + let rows = tx + .read("SELECT name FROM t ORDER BY id".into(), vec![]) + .await + .expect("SELECT within transaction should succeed"); + + assert_eq!(rows.len(), 2); + tx.commit().await.expect("commit should succeed"); +} + +#[tokio::test] +async fn connect_memory_instances_are_independent() { + let db1 = DatabaseWrapper::connect(Path::new(":memory:"), None) + .await + .expect("Failed to connect first in-memory database"); + let db2 = DatabaseWrapper::connect(Path::new(":memory:"), None) + .await + .expect("Failed to connect second in-memory database"); + + assert!( + !Arc::ptr_eq(db1.inner(), db2.inner()), + ":memory: databases should not share the same SqliteDatabase instance" + ); + + db1.execute("CREATE TABLE test (id INTEGER)".into(), vec![]) + .execute() + .await + .expect("CREATE TABLE on first database should succeed"); + + let result = db2 + .fetch_all("SELECT * FROM test".into(), vec![]) + .execute() + .await; + + assert!( + result.is_err(), + "Second :memory: database should not see tables from the first" + ); +} diff --git a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs index 7ac675d..72fdc69 100644 --- a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs +++ b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs @@ -141,6 +141,34 @@ async fn test_abort_all_clears_transactions() { assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); } +#[tokio::test] +async fn test_abort_for_db_clears_only_matching_interruptible() { + let (db1, _temp1) = create_test_db("main.db").await; + let (db2, _temp2) = create_test_db("other.db").await; + + for db in [&db1, &db2] { + db.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)".into(), vec![]) + .await + .unwrap(); + } + + let state = ActiveInterruptibleTransactions::default(); + let main_tx = begin_transaction(&db1, "main").await; + let main_tx_id = main_tx.transaction_id().to_string(); + let other_tx = begin_transaction(&db2, "other").await; + let other_tx_id = other_tx.transaction_id().to_string(); + + state.insert("main".into(), main_tx).await.unwrap(); + state.insert("other".into(), other_tx).await.unwrap(); + + state.abort_for_db("main").await; + + let err = expect_err(state.remove("main", &main_tx_id).await); + assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); + + assert!(state.remove("other", &other_tx_id).await.is_ok()); +} + #[tokio::test] async fn test_abort_all_auto_rollbacks_uncommitted_writes() { let (db, _temp) = create_test_db("rollback.db").await; @@ -329,6 +357,32 @@ async fn test_regular_abort_all_clears_state() { state.insert("a".into(), h3.abort_handle()).await; } +#[tokio::test] +async fn test_regular_abort_for_db_only_matching_prefix() { + let state = ActiveRegularTransactions::default(); + + let main_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + let other_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + state + .insert("main:one".into(), main_handle.abort_handle()) + .await; + state + .insert("other:two".into(), other_handle.abort_handle()) + .await; + + state.abort_for_db("main").await; + + assert!(main_handle.await.unwrap_err().is_cancelled()); + assert!(!other_handle.is_finished()); + + state.remove("main:one").await; +} + // ============================================================================ // cleanup_all_transactions tests // ============================================================================ diff --git a/guest-js/index.test.ts b/guest-js/index.test.ts index 39d1ed9..54a0cb4 100644 --- a/guest-js/index.test.ts +++ b/guest-js/index.test.ts @@ -16,12 +16,49 @@ import Database, { let lastCmd = '', lastArgs: Record = {}; +const DATABASE_KEYS = { + T: 'T', + TEST: 'TEST', + ARCHIVE: 'ARCHIVE', + STATS: 'STATS', + ORDERS: 'ORDERS', + MAIN: 'MAIN', +}; + +function getDatabasePath(key: string): string { + if (key === DATABASE_KEYS.ARCHIVE) { + return 'archive.db'; + } + + if (key === DATABASE_KEYS.STATS) { + return 'stats.db'; + } + + if (key === DATABASE_KEYS.ORDERS) { + return 'orders.db'; + } + + if (key === DATABASE_KEYS.T) { + return 't.db'; + } + + if (key === DATABASE_KEYS.TEST) { + return 'test.db'; + } + + if (key === DATABASE_KEYS.MAIN) { + return 'main.db'; + } + + throw new Error(`Unknown database key: ${key}`); +} + beforeEach(() => { mockIPC((cmd, args) => { lastCmd = cmd; lastArgs = args as Record; if (cmd === 'plugin:sqlite|load') { - return (args as { db: string }).db; + return getDatabasePath((args as { dbKey: string }).dbKey); } if (cmd === 'plugin:sqlite|execute') { return [ 1, 1 ]; @@ -30,13 +67,13 @@ beforeEach(() => { return []; } if (cmd === 'plugin:sqlite|begin_interruptible_transaction') { - return { dbPath: (args as { db: string }).db, transactionId: 'test-tx-id' }; + return { dbKey: (args as { dbKey: string }).dbKey, transactionId: 'test-tx-id' }; } if (cmd === 'plugin:sqlite|transaction_continue') { const action = (args as { action: { type: string } }).action; if (action.type === 'Continue') { - return { dbPath: 'test.db', transactionId: 'test-tx-id' }; + return { dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }; } return undefined; } @@ -84,32 +121,41 @@ afterEach(() => { return clearMocks(); }); describe('Database commands', () => { it('load', async () => { - await Database.load('test.db'); + await Database.load(DATABASE_KEYS.TEST); expect(lastCmd).toBe('plugin:sqlite|load'); - expect(lastArgs.db).toBe('test.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.TEST); + }); + + it('get is lazy until first operation', async () => { + const db = Database.get(DATABASE_KEYS.TEST); + + expect(db.path).toBe(''); + + await db.fetchAll('SELECT 1'); + expect(db.path).toBe(getDatabasePath(DATABASE_KEYS.TEST)); }); it('execute', async () => { - await Database.get('t.db').execute('INSERT INTO t VALUES ($1)', [ 1 ]); + await Database.get(DATABASE_KEYS.T).execute('INSERT INTO t VALUES ($1)', [ 1 ]); expect(lastCmd).toBe('plugin:sqlite|execute'); - expect(lastArgs).toMatchObject({ db: 't.db', query: 'INSERT INTO t VALUES ($1)', values: [ 1 ], attached: null }); + expect(lastArgs).toMatchObject({ dbKey: DATABASE_KEYS.T, query: 'INSERT INTO t VALUES ($1)', values: [ 1 ], attached: null }); }); it('execute with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .execute('UPDATE todos SET status = $1 WHERE id IN (SELECT todo_id FROM archive.completed)', [ 'archived' ]) .attach([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|execute'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'archive.db', + databaseKey: 'ARCHIVE', schemaName: 'archive', mode: 'readOnly', }, @@ -117,30 +163,30 @@ describe('Database commands', () => { }); it('execute_transaction', async () => { - await Database.get('t.db').executeTransaction([ [ 'DELETE FROM t' ] ]); + await Database.get(DATABASE_KEYS.T).executeTransaction([ [ 'DELETE FROM t' ] ]); expect(lastCmd).toBe('plugin:sqlite|execute_transaction'); expect(lastArgs.statements).toEqual([ { query: 'DELETE FROM t', values: [] } ]); expect(lastArgs.attached).toBe(null); }); it('execute_transaction with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .executeTransaction([ [ 'INSERT INTO orders (user_id, total) VALUES ($1, $2)', [ 1, 99.99 ] ], [ 'UPDATE stats.order_stats SET order_count = order_count + 1', [] ], ]) .attach([ { - databasePath: 'stats.db', + databaseKey: DATABASE_KEYS.STATS, schemaName: 'stats', mode: 'readWrite', }, ]); expect(lastCmd).toBe('plugin:sqlite|execute_transaction'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'stats.db', + databaseKey: DATABASE_KEYS.STATS, schemaName: 'stats', mode: 'readWrite', }, @@ -148,26 +194,26 @@ describe('Database commands', () => { }); it('fetch_all', async () => { - await Database.get('t.db').fetchAll('SELECT * FROM t'); + await Database.get(DATABASE_KEYS.T).fetchAll('SELECT * FROM t'); expect(lastCmd).toBe('plugin:sqlite|fetch_all'); - expect(lastArgs).toMatchObject({ db: 't.db', query: 'SELECT * FROM t', attached: null }); + expect(lastArgs).toMatchObject({ dbKey: DATABASE_KEYS.T, query: 'SELECT * FROM t', attached: null }); }); it('fetch_all with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .fetchAll('SELECT u.name, o.total FROM users u JOIN orders.orders o ON u.id = o.user_id', []) .attach([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|fetch_all'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, @@ -175,26 +221,26 @@ describe('Database commands', () => { }); it('fetch_one', async () => { - await Database.get('t.db').fetchOne('SELECT * FROM t WHERE id = $1', [ 1 ]); + await Database.get(DATABASE_KEYS.T).fetchOne('SELECT * FROM t WHERE id = $1', [ 1 ]); expect(lastCmd).toBe('plugin:sqlite|fetch_one'); - expect(lastArgs).toMatchObject({ db: 't.db', values: [ 1 ], attached: null }); + expect(lastArgs).toMatchObject({ dbKey: DATABASE_KEYS.T, values: [ 1 ], attached: null }); }); it('fetch_one with attached databases', async () => { - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .fetchOne('SELECT COUNT(*) as total FROM users u JOIN orders.orders o ON u.id = o.user_id', []) .attach([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|fetch_one'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.attached).toEqual([ { - databasePath: 'orders.db', + databaseKey: DATABASE_KEYS.ORDERS, schemaName: 'orders', mode: 'readOnly', }, @@ -206,10 +252,10 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db').fetchPage('SELECT * FROM posts', [], keyset, 25); + await Database.get(DATABASE_KEYS.T).fetchPage('SELECT * FROM posts', [], keyset, 25); expect(lastCmd).toBe('plugin:sqlite|fetch_page'); expect(lastArgs).toMatchObject({ - db: 't.db', + dbKey: DATABASE_KEYS.T, query: 'SELECT * FROM posts', values: [], keyset: [ { name: 'id', direction: 'asc' } ], @@ -225,7 +271,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts', [], keyset, 25) .after([ 100 ]); @@ -239,7 +285,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts', [], keyset, 25) .before([ 50 ]); @@ -253,22 +299,22 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('main.db') + await Database.get(DATABASE_KEYS.MAIN) .fetchPage('SELECT * FROM posts', [], keyset, 25) .attach([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|fetch_page'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.after).toBeNull(); expect(lastArgs.before).toBeNull(); expect(lastArgs.attached).toEqual([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, @@ -282,7 +328,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts WHERE active = $1', [ true ], keyset, 50) .after([ 'tech', 95, 42 ]); @@ -305,7 +351,7 @@ describe('Database commands', () => { { name: 'id', direction: 'asc' }, ]; - await Database.get('t.db') + await Database.get(DATABASE_KEYS.T) .fetchPage('SELECT * FROM posts WHERE active = $1', [ true ], keyset, 50) .before([ 'tech', 95, 42 ]); @@ -322,9 +368,9 @@ describe('Database commands', () => { }); it('close', async () => { - await Database.get('t.db').close(); + await Database.get(DATABASE_KEYS.T).close(); expect(lastCmd).toBe('plugin:sqlite|close'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); }); it('close_all', async () => { @@ -333,15 +379,15 @@ describe('Database commands', () => { }); it('remove', async () => { - await Database.get('t.db').remove(); + await Database.get(DATABASE_KEYS.T).remove(); expect(lastCmd).toBe('plugin:sqlite|remove'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); }); it('getMigrationEvents', async () => { const mockEvents: MigrationEvent[] = [ - { dbPath: 't.db', status: 'running' }, - { dbPath: 't.db', status: 'completed', migrationCount: 5 }, + { dbKey: DATABASE_KEYS.T, dbPath: 't.db', status: 'running' }, + { dbKey: DATABASE_KEYS.T, dbPath: 't.db', status: 'completed', migrationCount: 5 }, ]; mockIPC((cmd, args) => { @@ -353,28 +399,28 @@ describe('Database commands', () => { return undefined; }); - const events = await Database.get('t.db').getMigrationEvents(); + const events = await Database.get(DATABASE_KEYS.T).getMigrationEvents(); expect(lastCmd).toBe('plugin:sqlite|get_migration_events'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(events).toEqual(mockEvents); }); it('getMigrationEvents - empty array', async () => { - const events = await Database.get('test.db').getMigrationEvents(); + const events = await Database.get(DATABASE_KEYS.TEST).getMigrationEvents(); expect(lastCmd).toBe('plugin:sqlite|get_migration_events'); - expect(lastArgs.db).toBe('test.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.TEST); expect(events).toEqual([]); }); it('beginInterruptibleTransaction', async () => { - const tx = await Database.get('t.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.T).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); expect(lastCmd).toBe('plugin:sqlite|begin_interruptible_transaction'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(lastArgs.initialStatements).toEqual([ { query: 'INSERT INTO users (name) VALUES ($1)', values: [ 'Alice' ] }, ]); @@ -383,26 +429,26 @@ describe('Database commands', () => { }); it('beginInterruptibleTransaction with attached databases', async () => { - const tx = await Database.get('main.db') + const tx = await Database.get(DATABASE_KEYS.MAIN) .beginInterruptibleTransaction([ [ 'DELETE FROM users WHERE id IN (SELECT user_id FROM archive.archived_users)' ], ]) .attach([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, ]); expect(lastCmd).toBe('plugin:sqlite|begin_interruptible_transaction'); - expect(lastArgs.db).toBe('main.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.MAIN); expect(lastArgs.initialStatements).toEqual([ { query: 'DELETE FROM users WHERE id IN (SELECT user_id FROM archive.archived_users)', values: [] }, ]); expect(lastArgs.attached).toEqual([ { - databasePath: 'archive.db', + databaseKey: DATABASE_KEYS.ARCHIVE, schemaName: 'archive', mode: 'readOnly', }, @@ -411,7 +457,7 @@ describe('Database commands', () => { }); it('InterruptibleTransaction.continueWith()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); @@ -420,67 +466,70 @@ describe('Database commands', () => { ]); expect(lastCmd).toBe('plugin:sqlite|transaction_continue'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect((lastArgs.action as { type: string }).type).toBe('Continue'); expect(tx2).toBeInstanceOf(Object); }); it('InterruptibleTransaction.commit()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); await tx.commit(); expect(lastCmd).toBe('plugin:sqlite|transaction_continue'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect((lastArgs.action as { type: string }).type).toBe('Commit'); }); it('InterruptibleTransaction.rollback()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); await tx.rollback(); expect(lastCmd).toBe('plugin:sqlite|transaction_continue'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect((lastArgs.action as { type: string }).type).toBe('Rollback'); }); it('InterruptibleTransaction.read()', async () => { - const tx = await Database.get('test.db').beginInterruptibleTransaction([ + const tx = await Database.get(DATABASE_KEYS.TEST).beginInterruptibleTransaction([ [ 'INSERT INTO users (name) VALUES ($1)', [ 'Alice' ] ], ]); await tx.read('SELECT * FROM users WHERE name = $1', [ 'Alice' ]); expect(lastCmd).toBe('plugin:sqlite|transaction_read'); - expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' }); + expect(lastArgs.token).toEqual({ dbKey: DATABASE_KEYS.TEST, transactionId: 'test-tx-id' }); expect(lastArgs.query).toBe('SELECT * FROM users WHERE name = $1'); expect(lastArgs.values).toEqual([ 'Alice' ]); }); it('handles errors from backend', async () => { + const db = new Database(DATABASE_KEYS.T, getDatabasePath(DATABASE_KEYS.T)); + mockIPC(() => { throw new Error('Database error'); }); - await expect(Database.get('t.db').execute('SELECT 1', [])).rejects.toThrow('Database error'); + + await expect(db.execute('SELECT 1', [])).rejects.toThrow('Database error'); }); }); describe('Database.load with customConfig', () => { it('passes customConfig to backend', async () => { - await Database.load('test.db', { maxReadConnections: 10, idleTimeoutSecs: 60 }); + await Database.load(DATABASE_KEYS.TEST, { maxReadConnections: 10, idleTimeoutSecs: 60 }); expect(lastCmd).toBe('plugin:sqlite|load'); - expect(lastArgs.db).toBe('test.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.TEST); expect(lastArgs.customConfig).toEqual({ maxReadConnections: 10, idleTimeoutSecs: 60 }); }); }); describe('Observer commands', () => { it('observe', async () => { - await Database.get('t.db').observe([ 'users', 'posts' ]); + await Database.get(DATABASE_KEYS.T).observe([ 'users', 'posts' ]); expect(lastCmd).toBe('plugin:sqlite|observe'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(lastArgs.tables).toEqual([ 'users', 'posts' ]); expect(lastArgs.config).toBe(null); }); @@ -488,7 +537,7 @@ describe('Observer commands', () => { it('observe with config', async () => { const config: ObserverConfig = { channelCapacity: 512, captureValues: false }; - await Database.get('t.db').observe([ 'users' ], config); + await Database.get(DATABASE_KEYS.T).observe([ 'users' ], config); expect(lastCmd).toBe('plugin:sqlite|observe'); expect(lastArgs.tables).toEqual([ 'users' ]); expect(lastArgs.config).toEqual({ channelCapacity: 512, captureValues: false }); @@ -497,10 +546,10 @@ describe('Observer commands', () => { it('subscribe', async () => { const events: TableChangeEvent[] = []; - const sub = await Database.get('t.db').subscribe([ 'users' ], (e) => { events.push(e); }); + const sub = await Database.get(DATABASE_KEYS.T).subscribe([ 'users' ], (e) => { events.push(e); }); expect(lastCmd).toBe('plugin:sqlite|subscribe'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); expect(lastArgs.tables).toEqual([ 'users' ]); expect(lastArgs.onEvent).toBeDefined(); expect(sub).toBeInstanceOf(Subscription); @@ -520,15 +569,16 @@ describe('Observer commands', () => { }); it('unobserve', async () => { - await Database.get('t.db').unobserve(); + await Database.get(DATABASE_KEYS.T).unobserve(); expect(lastCmd).toBe('plugin:sqlite|unobserve'); - expect(lastArgs.db).toBe('t.db'); + expect(lastArgs.dbKey).toBe(DATABASE_KEYS.T); }); }); describe('MigrationEvent type', () => { it('accepts running status', () => { const event: MigrationEvent = { + dbKey: 'TEST', dbPath: 'test.db', status: 'running', }; @@ -540,6 +590,7 @@ describe('MigrationEvent type', () => { it('accepts completed status with migrationCount', () => { const event: MigrationEvent = { + dbKey: 'TEST', dbPath: 'test.db', status: 'completed', migrationCount: 3, @@ -551,6 +602,7 @@ describe('MigrationEvent type', () => { it('accepts failed status with error', () => { const event: MigrationEvent = { + dbKey: 'TEST', dbPath: 'test.db', status: 'failed', error: 'Migration failed: syntax error', diff --git a/guest-js/index.ts b/guest-js/index.ts index 3537fae..2f3cb2e 100644 --- a/guest-js/index.ts +++ b/guest-js/index.ts @@ -1,5 +1,25 @@ import { Channel, invoke } from '@tauri-apps/api/core'; +const databaseLoadPromises = new WeakMap>(); + +async function ensureDatabaseLoaded(db: Database): Promise { + if (db.path !== '') { + return; + } + + let loadPromise = databaseLoadPromises.get(db); + + if (!loadPromise) { + loadPromise = invoke('plugin:sqlite|load', { dbKey: db.key }) + .then((resolvedPath) => { + db.path = resolvedPath; + }); + databaseLoadPromises.set(db, loadPromise); + } + + await loadPromise; +} + /** * Valid SQLite parameter binding value types. * @@ -23,9 +43,9 @@ export type AttachedDatabaseMode = 'readOnly' | 'readWrite'; export interface AttachedDatabaseSpec { /** - * Path to the database to attach (must be loaded via `Database.load()` first) + * Database key for identifying the database to attach */ - databasePath: string; + databaseKey: string; /** * Schema name to use for the attached database in queries @@ -77,11 +97,11 @@ export interface SqliteError { * Provides methods to read uncommitted data and execute additional statements. */ export class InterruptibleTransaction { - private readonly _dbPath: string; + private readonly _dbKey: string; private readonly _transactionId: string; - public constructor(dbPath: string, transactionId: string) { - this._dbPath = dbPath; + public constructor(dbKey: string, transactionId: string) { + this._dbKey = dbKey; this._transactionId = transactionId; } @@ -112,7 +132,7 @@ export class InterruptibleTransaction { */ public async read(query: string, bindValues?: SqlValue[]): Promise { return await invoke('plugin:sqlite|transaction_read', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, query, values: bindValues ?? [], }); @@ -137,10 +157,10 @@ export class InterruptibleTransaction { * ``` */ public async continueWith(statements: Array<[string, SqlValue[]?]>): Promise { - const token = await invoke<{ dbPath: string; transactionId: string }>( + const token = await invoke<{ dbKey: string; transactionId: string }>( 'plugin:sqlite|transaction_continue', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, action: { type: 'Continue', statements: statements.map(([ query, values ]) => { @@ -153,7 +173,7 @@ export class InterruptibleTransaction { } ); - return new InterruptibleTransaction(token.dbPath, token.transactionId); + return new InterruptibleTransaction(token.dbKey, token.transactionId); } /** @@ -170,7 +190,7 @@ export class InterruptibleTransaction { */ public async commit(): Promise { await invoke('plugin:sqlite|transaction_continue', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, action: { type: 'Commit' }, }); } @@ -189,7 +209,7 @@ export class InterruptibleTransaction { */ public async rollback(): Promise { await invoke('plugin:sqlite|transaction_continue', { - token: { dbPath: this._dbPath, transactionId: this._transactionId }, + token: { dbKey: this._dbKey, transactionId: this._transactionId }, action: { type: 'Rollback' }, }); } @@ -220,7 +240,7 @@ export interface CustomConfig { * // Get all migration events (including ones emitted before registering listener) * const events = await db.getMigrationEvents() * for (const event of events) { - * console.log(`${event.status}: ${event.dbPath}`) + * console.log(`${event.status}: ${event.dbKey}`) * if (event.status === 'failed') { * console.error(`Migration error: ${event.error}`) * } @@ -230,17 +250,17 @@ export interface CustomConfig { * * // Listen for real-time events (may miss early events) * await listen('sqlite:migration', (event) => { - * const { dbPath, status, migrationCount, error } = event.payload + * const { dbKey, dbPath, status, migrationCount, error } = event.payload * * switch (status) { * case 'running': - * console.log(`Running migrations for ${dbPath}`) + * console.log(`Running migrations for ${dbKey}`) * break * case 'completed': - * console.log(`Completed ${migrationCount} migrations for ${dbPath}`) + * console.log(`Completed ${migrationCount} migrations for ${dbKey}`) * break * case 'failed': - * console.error(`Migration failed for ${dbPath}: ${error}`) + * console.error(`Migration failed for ${dbKey}: ${error}`) * break * } * }) @@ -248,7 +268,12 @@ export interface CustomConfig { */ export interface MigrationEvent { - /** Database path (relative, as registered with the plugin) */ + /** Database registration key (such as `MAIN`) */ + dbKey: string; + + /** Database path, the absolute path to the database file, such as + * `/var/lib/myapp/main.db`. + */ dbPath: string; /** Status: "running", "completed", "failed" */ @@ -448,8 +473,10 @@ class FetchAllBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + return await invoke('plugin:sqlite|fetch_all', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, attached: this._attached.length > 0 ? this._attached : null, @@ -497,8 +524,10 @@ class FetchOneBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + return await invoke('plugin:sqlite|fetch_one', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, attached: this._attached.length > 0 ? this._attached : null, @@ -577,8 +606,10 @@ class FetchPageBuilder implements PromiseLike> { } private async _execute(): Promise> { + await ensureDatabaseLoaded(this._db); + return await invoke>('plugin:sqlite|fetch_page', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, keyset: this._keyset, @@ -630,10 +661,12 @@ class ExecuteBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + const [ rowsAffected, lastInsertId ] = await invoke<[number, number]>( 'plugin:sqlite|execute', { - db: this._db.path, + dbKey: this._db.key, query: this._query, values: this._bindValues, attached: this._attached.length > 0 ? this._attached : null, @@ -684,10 +717,12 @@ class InterruptibleTransactionBuilder implements PromiseLike { - const token = await invoke<{ dbPath: string; transactionId: string }>( + await ensureDatabaseLoaded(this._db); + + const token = await invoke<{ dbKey: string; transactionId: string }>( 'plugin:sqlite|begin_interruptible_transaction', { - db: this._db.path, + dbKey: this._db.key, initialStatements: this._initialStatements.map(([ query, values ]) => { return { query, @@ -698,7 +733,7 @@ class InterruptibleTransactionBuilder implements PromiseLike { } private async _execute(): Promise { + await ensureDatabaseLoaded(this._db); + return await invoke('plugin:sqlite|execute_transaction', { - db: this._db.path, + dbKey: this._db.key, statements: this._statements.map(([ query, values ]) => { return { query, @@ -765,9 +802,17 @@ class TransactionBuilder implements PromiseLike { * visible to reads in another, and closing a database affects all windows. */ export default class Database { + + /** Database key for identifying the database */ + public key: string; + + /** Database path, the absolute path to the database file, such as + * `/var/lib/myapp/main.db`. + */ public path: string; - public constructor(path: string) { + public constructor(key: string, path: string) { + this.key = key; this.path = path; } @@ -777,51 +822,58 @@ export default class Database { * A static initializer which connects to the underlying SQLite database and * returns a `Database` instance once a connection is established. * - * The path is relative to `tauri::path::BaseDirectory::AppConfig`. + * The key must be a key that the Rust side has registered with the plugin + * (via `Builder::register_database` / `SetupRegistrar::register_database`), + * or an in-memory database such as `:memory:`. * - * @param path - Database file path (relative to AppConfig directory) + * @param key - The key of the database to load * @param customConfig - Optional custom configuration for connection pools * * @example * ```ts + * const dbKey = 'MAIN'; + * * // Use default configuration - * const db = await Database.load("test.db"); + * const db = await Database.load(dbKey); * * // Use custom configuration - * const db = await Database.load("test.db", { + * const db2 = await Database.load(dbKey, { * maxReadConnections: 10, * idleTimeoutSecs: 60 * }); * ``` */ public static async load( - path: string, + dbKey: string, customConfig?: CustomConfig ): Promise { const resolvedPath = await invoke('plugin:sqlite|load', { - db: path, + dbKey, customConfig, }); - return new Database(resolvedPath); + return new Database(dbKey, resolvedPath); } /** * **get** * - * A static initializer which synchronously returns an instance of - * the Database class while deferring the actual database connection - * until the first invocation or selection on the database. + * Synchronously returns a `Database` handle for a registered key while deferring + * the actual connection until the first query or other operation that requires a + * loaded connection. Use {@link Database.load} when you need to pass + * `customConfig` or connect eagerly. * - * The path is relative to `tauri::path::BaseDirectory::AppConfig`. + * The key must be registered with the plugin on the Rust side + * (via `Builder::register_database` / `SetupRegistrar::register_database`). * * @example * ```ts - * const db = Database.get("test.db"); + * const db = Database.get('MAIN'); + * await db.fetchAll('SELECT 1'); * ``` */ - public static get(path: string): Database { - return new Database(path); + public static get(dbKey: string): Database { + return new Database(dbKey, ''); } /** @@ -866,7 +918,7 @@ export default class Database { * "(SELECT todo_id FROM archive.completed)", * [ "archived" ] * ).attach([{ - * databasePath: "archive.db", + * databaseKey: "ARCHIVE", * schemaName: "archive", * mode: "readOnly" * }]); @@ -918,7 +970,7 @@ export default class Database { * ['INSERT INTO main.orders (user_id, total) VALUES ($1, $2)', [userId, total]], * ['UPDATE archive.stats SET order_count = order_count + 1', []] * ]).attach([{ - * databasePath: "archive.db", + * databaseKey: "ARCHIVE", * schemaName: "archive", * mode: "readWrite" * }]); @@ -957,7 +1009,7 @@ export default class Database { * "SELECT u.name, o.total FROM users u JOIN orders.orders o ON u.id = o.user_id", * [] * ).attach([{ - * databasePath: "orders.db", + * databaseKey: "ORDERS", * schemaName: "orders", * mode: "readOnly" * }]); @@ -992,7 +1044,7 @@ export default class Database { * "SELECT COUNT(*) as total FROM users u JOIN orders.orders o ON u.id = o.user_id", * [] * ).attach([{ - * databasePath: "orders.db", + * databaseKey: "ORDERS", * schemaName: "orders", * mode: "readOnly" * }]); @@ -1061,7 +1113,7 @@ export default class Database { * keyset, * 25, * ).attach([{ - * databasePath: 'archive.db', + * databaseKey: 'ARCHIVE', * schemaName: 'archive', * mode: 'readOnly', * }]); @@ -1106,8 +1158,10 @@ export default class Database { * ``` */ public async observe(tables: string[], config?: ObserverConfig): Promise { + await ensureDatabaseLoaded(this); + await invoke('plugin:sqlite|observe', { - db: this.path, + dbKey: this.key, tables, config: config ?? null, }); @@ -1148,12 +1202,14 @@ export default class Database { tables: string[], onEvent: (event: TableChangeEvent) => void ): Promise { + await ensureDatabaseLoaded(this); + const channel = new Channel(); channel.onmessage = onEvent; const subscriptionId = await invoke('plugin:sqlite|subscribe', { - db: this.path, + dbKey: this.key, tables, onEvent: channel, }); @@ -1174,8 +1230,10 @@ export default class Database { * ``` */ public async unobserve(): Promise { + await ensureDatabaseLoaded(this); + await invoke('plugin:sqlite|unobserve', { - db: this.path, + dbKey: this.key, }); } @@ -1199,7 +1257,7 @@ export default class Database { */ public async close(): Promise { const success = await invoke('plugin:sqlite|close', { - db: this.path, + dbKey: this.key, }); return success; @@ -1229,7 +1287,7 @@ export default class Database { */ public async remove(): Promise { const success = await invoke('plugin:sqlite|remove', { - db: this.path, + dbKey: this.key, }); return success; @@ -1292,7 +1350,7 @@ export default class Database { * let tx = await db.beginInterruptibleTransaction([ * ['DELETE FROM users WHERE archived = 1'] * ]).attach([{ - * databasePath: 'archive.db', + * databaseKey: 'ARCHIVE', * schemaName: 'archive', * mode: 'readWrite' * }]); @@ -1328,7 +1386,7 @@ export default class Database { * // Get all migration events (including ones that happened before we could listen) * const events = await db.getMigrationEvents() * for (const event of events) { - * console.log(`${event.status}: ${event.dbPath}`) + * console.log(`${event.status}: ${event.dbKey}`) * if (event.status === 'failed') { * console.error(`Migration error: ${event.error}`) * } @@ -1337,7 +1395,7 @@ export default class Database { */ public async getMigrationEvents(): Promise { return await invoke('plugin:sqlite|get_migration_events', { - db: this.path, + dbKey: this.key, }); } } diff --git a/package.json b/package.json index fdfe120..e19c7c2 100644 --- a/package.json +++ b/package.json @@ -19,6 +19,7 @@ ], "scripts": { "build": "rollup -c", + "check:iife": "rollup -c && git diff --exit-code api-iife.js", "check-node-version": "check-node-version --npm 10.5.0", "commitlint": "commitlint --from ${COMMITLINT_FROM:-002bcc8} --to ${COMMITLINT_TO:-HEAD}", "eslint": "eslint .", @@ -27,7 +28,7 @@ "type-check": "run-p type-check:ts", "markdownlint": "markdownlint-cli2", "prepare": "rollup -c", - "standards": "npm run eslint && npm run type-check && npm run markdownlint && npm run rust:lint && npm run commitlint", + "standards": "npm run eslint && npm run type-check && npm run markdownlint && npm run check:iife && npm run rust:lint && npm run commitlint", "rust:lint": "cargo lint-clippy && cargo lint-fmt", "rust:lint:fix": "cargo fix-clippy && cargo fix-fmt", "test": "vitest run && cargo test --workspace --lib --test '*'", diff --git a/src/commands.rs b/src/commands.rs index 2222bda..b315618 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -12,6 +12,7 @@ use sqlx_sqlite_toolkit::{ ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions, DatabaseWrapper, Statement, TransactionWriter, WriteQueryResult, }; +use std::path::PathBuf; use std::sync::Arc; use tauri::ipc::Channel; use tauri::{AppHandle, Runtime, State}; @@ -19,17 +20,18 @@ use tracing::debug; use uuid::Uuid; use crate::{ - DbInstances, Error, MigrationEvent, MigrationStates, MigrationStatus, Result, + DbInstances, Error, MigrationEvent, MigrationStates, Result, subscriptions::{ ActiveSubscriptions, ObserverConfigParams, TableChangePayload, event_to_payload, }, }; +use crate::{close_database, connect_to_database}; /// Token representing an active interruptible transaction #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TransactionToken { - pub db_path: String, + pub db_key: String, pub transaction_id: String, } @@ -46,8 +48,8 @@ pub enum TransactionAction { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct AttachedDatabaseSpec { - /// Path to the database to attach (must be loaded via `load()` first) - pub database_path: String, + /// Key of the database to attach (must be loaded via `load()` first) + pub database_key: String, /// Schema name to use for the attached database in queries pub schema_name: String, /// Access mode: "readOnly" or "readWrite" @@ -71,8 +73,8 @@ fn resolve_attached_specs( for spec in specs { let wrapper = db_instances - .get(&spec.database_path) - .ok_or_else(|| Error::DatabaseNotLoaded(spec.database_path.clone()))?; + .get(&spec.database_key) + .ok_or_else(|| Error::DatabaseNotLoaded(spec.database_key.clone()))?; let mode = match spec.mode { AttachedDatabaseMode::ReadOnly => sqlx_sqlite_conn_mgr::AttachedMode::ReadOnly, @@ -94,6 +96,10 @@ fn resolve_attached_specs( /// If the database is already loaded, returns the existing connection. /// Otherwise, creates a new connection with optional custom configuration. /// +/// `db_key` must be a registration key from +/// [`Builder::register_database`] / [`SetupRegistrar::register_database`]. +/// Unregistered keys are rejected with `PATH_NOT_REGISTERED`. +/// /// # Migration Timing /// /// If migrations are registered for this database, this function waits for them @@ -104,92 +110,18 @@ fn resolve_attached_specs( #[tauri::command] pub async fn load( app: AppHandle, - db_instances: State<'_, DbInstances>, - migration_states: State<'_, MigrationStates>, - db: String, + db_key: String, custom_config: Option, -) -> Result { - // Wait for migrations to complete if registered for this database - await_migrations(&migration_states, &db).await?; - - let instances = db_instances.inner.read().await; - - // Return cached if db was already loaded - if instances.contains_key(&db) { - return Ok(db); - } - - drop(instances); // Release read lock before acquiring write lock - - let mut instances = db_instances.inner.write().await; - - // Check database count limit before creating a new connection. - // This check is before entry() to avoid borrow conflicts, and the write lock - // prevents races between the len() check and the insert. - if !instances.contains_key(&db) && instances.len() >= db_instances.max { - return Err(Error::TooManyDatabases(db_instances.max)); - } - - // Use entry API to atomically check and insert, avoiding race conditions - // where two callers could both create wrappers - use std::collections::hash_map::Entry; - match instances.entry(db.clone()) { - Entry::Occupied(_) => { - // Another caller won the race and inserted while we waited for write lock - Ok(db) - } - Entry::Vacant(entry) => { - // We won the race, create and insert the wrapper - let wrapper = crate::resolve::connect(&db, &app, custom_config).await?; - entry.insert(wrapper); - Ok(db) - } - } -} - -/// Wait for migrations to complete for a database, if any are registered. -/// -/// Returns Ok(()) if: -/// - No migrations are registered for this database -/// - Migrations completed successfully -/// -/// Returns Err if migrations failed. -async fn await_migrations(migration_states: &State<'_, MigrationStates>, db: &str) -> Result<()> { - loop { - // Get notify handle before checking status - let notify = { - let states = migration_states.0.read().await; - match states.get(db) { - // No migrations registered for this database - None => return Ok(()), - - Some(state) => match &state.status { - // Migrations completed successfully - MigrationStatus::Complete => return Ok(()), - - // Migrations failed - return the error - MigrationStatus::Failed(error) => { - return Err(Error::Migration(sqlx::migrate::MigrateError::Source( - error.clone().into(), - ))); - } - - // Migrations still pending or running - wait for notification - MigrationStatus::Pending | MigrationStatus::Running => state.notify.clone(), - }, - } - }; - - // Wait for migration state change - notify.notified().await; - } +) -> Result { + let response = connect_to_database(&app, &db_key, custom_config).await?; + Ok(response.path) } /// Execute a write query (INSERT, UPDATE, DELETE, etc.) #[tauri::command] pub async fn execute( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, attached: Option>, @@ -197,8 +129,8 @@ pub async fn execute( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.execute(query, values); @@ -217,15 +149,15 @@ pub async fn execute( pub async fn execute_transaction( db_instances: State<'_, DbInstances>, regular_txs: State<'_, ActiveRegularTransactions>, - db: String, + db_key: String, statements: Vec, attached: Option>, ) -> Result> { let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; // Convert Statement structs to tuples for wrapper let stmt_tuples: Vec<(String, Vec)> = statements @@ -234,7 +166,7 @@ pub async fn execute_transaction( .collect(); // Generate unique key for tracking this transaction - let tx_key = format!("{}:{}", db, Uuid::new_v4()); + let tx_key = format!("{}:{}", db_key, Uuid::new_v4()); // Resolve attached specs if provided let resolved_specs = if let Some(specs) = attached { @@ -297,7 +229,7 @@ pub async fn execute_transaction( #[tauri::command] pub async fn fetch_all( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, attached: Option>, @@ -305,8 +237,8 @@ pub async fn fetch_all( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.fetch_all(query, values); @@ -324,7 +256,7 @@ pub async fn fetch_all( #[tauri::command] pub async fn fetch_one( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, attached: Option>, @@ -332,8 +264,8 @@ pub async fn fetch_one( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.fetch_one(query, values); @@ -352,7 +284,7 @@ pub async fn fetch_one( #[tauri::command] pub async fn fetch_page( db_instances: State<'_, DbInstances>, - db: String, + db_key: String, query: String, values: Vec, keyset: Vec, @@ -370,8 +302,8 @@ pub async fn fetch_page( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut builder = wrapper.fetch_page(query, values, keyset, page_size); @@ -391,27 +323,28 @@ pub async fn fetch_page( Ok(result) } -/// Close a specific database connection +/// Close the loaded instance for a registered database key. /// /// Returns `true` if the database was loaded and successfully closed. /// Returns `false` if the database was not loaded (nothing to close). -/// Any active subscriptions for this database are aborted before closing. +/// Active subscriptions and in-flight transactions for this key are aborted +/// before closing (10 second deadline for transactions). #[tauri::command] pub async fn close( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + interruptible_txs: State<'_, ActiveInterruptibleTransactions>, + regular_txs: State<'_, ActiveRegularTransactions>, + db_key: String, ) -> Result { - active_subs.remove_for_db(&db).await; - - let mut instances = db_instances.inner.write().await; - - if let Some(wrapper) = instances.remove(&db) { - wrapper.close().await?; - Ok(true) - } else { - Ok(false) // Database wasn't loaded - } + close_database( + &db_key, + &db_instances, + &active_subs, + &interruptible_txs, + ®ular_txs, + ) + .await } /// Close all database connections @@ -453,13 +386,13 @@ pub async fn close_all( pub async fn remove( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, ) -> Result { - active_subs.remove_for_db(&db).await; + active_subs.remove_for_db(&db_key).await; let mut instances = db_instances.inner.write().await; - if let Some(wrapper) = instances.remove(&db) { + if let Some(wrapper) = instances.remove(&db_key) { wrapper.remove().await?; Ok(true) } else { @@ -476,11 +409,11 @@ pub async fn remove( #[tauri::command] pub async fn get_migration_events( migration_states: State<'_, MigrationStates>, - db: String, + db_key: String, ) -> Result> { let states = migration_states.0.read().await; - match states.get(&db) { + match states.get(&db_key) { Some(state) => Ok(state.events.clone()), None => Ok(Vec::new()), } @@ -495,15 +428,15 @@ pub async fn get_migration_events( pub async fn begin_interruptible_transaction( db_instances: State<'_, DbInstances>, active_txs: State<'_, ActiveInterruptibleTransactions>, - db: String, + db_key: String, initial_statements: Vec, attached: Option>, ) -> Result { let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; // Generate unique transaction ID let transaction_id = Uuid::new_v4().to_string(); @@ -524,15 +457,15 @@ pub async fn begin_interruptible_transaction( // Execute initial statements let mut active_tx = - ActiveInterruptibleTransaction::new(db.clone(), transaction_id.clone(), writer); + ActiveInterruptibleTransaction::new(db_key.clone(), transaction_id.clone(), writer); active_tx.continue_with(initial_statements).await?; // Store transaction state - active_txs.insert(db.clone(), active_tx).await?; + active_txs.insert(db_key.clone(), active_tx).await?; Ok(TransactionToken { - db_path: db, + db_key, transaction_id, }) } @@ -550,14 +483,14 @@ pub async fn transaction_continue( TransactionAction::Continue { statements } => { // Remove transaction to get mutable access let mut tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; // Execute statements on the transaction match tx.continue_with(statements).await { Ok(_results) => { // Re-insert transaction - if this fails, tx is dropped and auto-rolled back - match active_txs.insert(token.db_path.clone(), tx).await { + match active_txs.insert(token.db_key.clone(), tx).await { Ok(()) => Ok(Some(token)), Err(e) => { // Transaction lost but will auto-rollback via Drop @@ -576,7 +509,7 @@ pub async fn transaction_continue( TransactionAction::Commit => { // Remove transaction and commit let tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; tx.commit().await?; @@ -586,7 +519,7 @@ pub async fn transaction_continue( TransactionAction::Rollback => { // Remove transaction and rollback let tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; tx.rollback().await?; @@ -608,14 +541,14 @@ pub async fn transaction_read( ) -> Result>> { // Remove transaction to get mutable access let mut tx = active_txs - .remove(&token.db_path, &token.transaction_id) + .remove(&token.db_key, &token.transaction_id) .await?; // Execute read on the transaction match tx.read(query, values).await { Ok(results) => { // Re-insert transaction - if this fails, tx is dropped and auto-rolled back - match active_txs.insert(token.db_path.clone(), tx).await { + match active_txs.insert(token.db_key.clone(), tx).await { Ok(()) => Ok(results), Err(e) => { // Transaction lost but will auto-rollback via Drop @@ -643,7 +576,7 @@ pub async fn transaction_read( pub async fn observe( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, tables: Vec, config: Option, ) -> Result<()> { @@ -659,13 +592,13 @@ pub async fn observe( // Abort plugin-level subscription tasks before the crate-level // enable_observation() drops the old broker - active_subs.remove_for_db(&db).await; + active_subs.remove_for_db(&db_key).await; let mut instances = db_instances.inner.write().await; let wrapper = instances - .get_mut(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get_mut(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let mut observer_config = sqlx_sqlite_observer::ObserverConfig::new().with_tables(tables); @@ -697,13 +630,13 @@ pub async fn observe( pub async fn subscribe( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, tables: Vec, on_event: Channel, ) -> Result { const MAX_SUBSCRIPTIONS_PER_DATABASE: usize = 100; - let sub_count = active_subs.count_for_db(&db).await; + let sub_count = active_subs.count_for_db(&db_key).await; if sub_count >= MAX_SUBSCRIPTIONS_PER_DATABASE { return Err(Error::TooManySubscriptions(MAX_SUBSCRIPTIONS_PER_DATABASE)); } @@ -711,12 +644,12 @@ pub async fn subscribe( let instances = db_instances.inner.read().await; let wrapper = instances - .get(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; let observable = wrapper .observable() - .ok_or_else(|| Error::ObservationNotEnabled(db.clone()))?; + .ok_or_else(|| Error::ObservationNotEnabled(db_key.clone()))?; // Create subscription stream let mut stream = observable.subscribe_stream(tables); @@ -726,7 +659,7 @@ pub async fn subscribe( // Spawn task to forward stream events to the Tauri Channel let sub_id = subscription_id.clone(); - let db_path = db.clone(); + let db_key_clone = db_key.clone(); let handle = tokio::spawn(async move { while let Some(event) = stream.next().await { @@ -738,12 +671,12 @@ pub async fn subscribe( } } - debug!("Subscription {} for db {} ended", sub_id, db_path); + debug!("Subscription {} for db {} ended", sub_id, &db_key_clone); }); // Track subscription active_subs - .insert(subscription_id.clone(), db.clone(), handle.abort_handle()) + .insert(subscription_id.clone(), db_key, handle.abort_handle()) .await; Ok(subscription_id) @@ -767,16 +700,16 @@ pub async fn unsubscribe( pub async fn unobserve( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, - db: String, + db_key: String, ) -> Result<()> { // Abort all subscriptions for this database first - active_subs.remove_for_db(&db).await; + active_subs.remove_for_db(&db_key).await; let mut instances = db_instances.inner.write().await; let wrapper = instances - .get_mut(&db) - .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?; + .get_mut(&db_key) + .ok_or_else(|| Error::DatabaseNotLoaded(db_key.clone()))?; wrapper.disable_observation(); Ok(()) diff --git a/src/error.rs b/src/error.rs index 4fc5e1e..4b1fd77 100644 --- a/src/error.rs +++ b/src/error.rs @@ -31,6 +31,10 @@ pub enum Error { #[error("path traversal not allowed: {0}")] PathTraversal(String), + /// Database key is not registered with the plugin. + #[error("database key not registered: {0}")] + PathNotRegistered(String), + /// Attempted to access a database that hasn't been loaded. #[error("database {0} not loaded")] DatabaseNotLoaded(String), @@ -84,6 +88,7 @@ impl Error { Error::Migration(_) => "MIGRATION_ERROR".to_string(), Error::InvalidPath(_) => "INVALID_PATH".to_string(), Error::PathTraversal(_) => "PATH_TRAVERSAL".to_string(), + Error::PathNotRegistered(_) => "PATH_NOT_REGISTERED".to_string(), Error::DatabaseNotLoaded(_) => "DATABASE_NOT_LOADED".to_string(), Error::ObservationNotEnabled(_) => "OBSERVATION_NOT_ENABLED".to_string(), Error::TooManyDatabases(_) => "TOO_MANY_DATABASES".to_string(), @@ -123,6 +128,24 @@ mod tests { assert_eq!(err.error_code(), "INVALID_PATH"); } + #[test] + fn test_error_code_path_not_registered() { + let err = Error::PathNotRegistered("MAIN".into()); + assert_eq!(err.error_code(), "PATH_NOT_REGISTERED"); + } + + #[test] + fn test_error_serialization_path_not_registered() { + let err = Error::PathNotRegistered("MAIN".into()); + let json = serde_json::to_value(&err).unwrap(); + + assert_eq!(json["code"], "PATH_NOT_REGISTERED"); + assert_eq!( + json["message"].as_str().unwrap(), + "database key not registered: MAIN" + ); + } + #[test] fn test_error_code_unsupported_datatype() { let err = Error::Toolkit(sqlx_sqlite_toolkit::Error::UnsupportedDatatype( diff --git a/src/lib.rs b/src/lib.rs index f1ff209..c72d209 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,18 @@ use std::collections::HashMap; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::atomic::{AtomicU8, Ordering}; use serde::Serialize; use sqlx_sqlite_conn_mgr::Migrator; -use tauri::{Emitter, Manager, RunEvent, Runtime, plugin::Builder as PluginBuilder}; +use tauri::{AppHandle, Emitter, Manager, RunEvent, Runtime, plugin::Builder as PluginBuilder}; use tokio::sync::{Notify, RwLock}; use tracing::{debug, error, info, trace, warn}; mod commands; mod error; -mod resolve; mod subscriptions; +mod validate; pub use error::{Error, Result}; pub use sqlx_sqlite_conn_mgr::{ @@ -23,6 +24,8 @@ pub use sqlx_sqlite_toolkit::{ TransactionExecutionBuilder, WriteQueryResult, }; +use crate::subscriptions::ActiveSubscriptions; + /// Default maximum number of concurrently loaded databases. const DEFAULT_MAX_DATABASES: usize = 50; @@ -54,6 +57,8 @@ impl Drop for ExitGuard { /// This struct maintains a thread-safe map of database paths to their corresponding /// connection wrappers, with a configurable upper limit on how many databases can be /// loaded simultaneously. +/// +/// The string key is the registered database key. #[derive(Clone)] pub struct DbInstances { pub(crate) inner: Arc>>, @@ -79,6 +84,64 @@ impl DbInstances { } } +/// Tracks the paths of all registered databases. +/// The String value of the key is the database identifier, not the path. +/// For example, the value of the key `MAIN` would be something like +/// `/var/lib/myapp/main.db`. +/// +/// This key value is what will be used by the caller to interact with the database. +/// For example, when calling `load()` or `execute()`, the caller will pass the key value +/// to identify the database to which they want to connect. +#[derive(Clone, Default)] +pub struct RegisteredDatabases { + pub(crate) database_path_by_key: Arc>, +} + +/// Contains the information required for registering a database. +/// +/// When initializing or setting up the plugin, the caller will pass the path to the database +/// file and the migrator to use for the database. +/// +/// This information is then stored in the `RegisteredDatabases` struct, which is used to +/// track the paths of all registered databases. +/// +/// The `migrator` is not held by the app state, but rather is only used after +/// initialization to run the migrations for the database. +#[derive(Debug, Clone)] +struct DatabaseInfo { + path: PathBuf, + migrator: Option>, +} + +fn validated_database_info( + path: impl Into, + migrator: Option, +) -> Result { + let path = path.into(); + Ok(DatabaseInfo { + path: validate::validate_database_path(&path)?, + migrator: migrator.map(Arc::new), + }) +} + +/// Ensure each registration key maps to a distinct database path. +fn ensure_distinct_database_paths( + database_info_by_key: &HashMap, +) -> Result<()> { + let mut path_to_key = HashMap::new(); + + for (key, info) in database_info_by_key { + if let Some(existing_key) = path_to_key.insert(info.path.clone(), key.as_str()) { + return Err(Error::InvalidConfig(format!( + "database keys {existing_key} and {key} both register the same path: {}", + info.path.display() + ))); + } + } + + Ok(()) +} + /// Migration status for a database. #[derive(Debug, Clone)] pub enum MigrationStatus { @@ -119,6 +182,7 @@ impl MigrationState { } /// Tracks migration state for all databases. +/// The String value of the key is the database identifier, not the path. #[derive(Default)] pub struct MigrationStates(pub RwLock>); @@ -126,8 +190,12 @@ pub struct MigrationStates(pub RwLock>); #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct MigrationEvent { - /// Database path (relative, as registered) - pub db_path: String, + /// Database key, meant to be human readable (such as `MAIN`). + /// This is what is to be used by the client to interact with the database. + pub db_key: String, + /// Database path, the absolute path to the database file, such as + /// `/var/lib/myapp/main.db`. + pub db_path: PathBuf, /// Status: "running", "completed", "failed" pub status: String, /// Total number of migrations defined in the migrator (on "completed"), not just newly applied @@ -142,6 +210,17 @@ pub struct MigrationEvent { /// /// Use this to configure the plugin and build the plugin instance. /// +/// # Database registration +/// +/// Every database must be **registered** with a stable key and filesystem path (or +/// in-memory URI) before it can be opened. The frontend and Rust callers open databases +/// by **key** via `load()` / [`Connection::connect`]. Paths are validated and +/// canonicalized at registration time. +/// +/// Because legitimate paths usually depend on runtime values (for example +/// `app.path().app_data_dir()`), registration normally happens in the [`Builder::on_setup`] +/// hook. Static paths can be registered up front with [`Builder::register_database`]. +/// /// # Example /// /// ```ignore @@ -151,9 +230,9 @@ pub struct MigrationEvent { /// use tauri_plugin_sqlite::Builder; /// /// # fn main() { -/// // Basic setup (no migrations): +/// // Basic setup (no databases registered yet — register them in `on_setup`): /// tauri::Builder::default() -/// .plugin(Builder::new().build()) +/// .plugin(Builder::new().build().expect("failed to build sqlite plugin")) /// .run(tauri::generate_context!()) /// .expect("error while running tauri application"); /// # } @@ -166,64 +245,169 @@ pub struct MigrationEvent { /// // tauri::generate_context!() requires tauri.conf.json at compile time, /// // which cannot be provided in doc test environments. /// use tauri_plugin_sqlite::Builder; +/// use tauri::Manager; /// /// # fn main() { -/// // Setup with migrations: +/// // Resolve the database path from the app instance and register it with migrations. +/// // The frontend then calls `Database.load("MAIN")` with the registration key. /// tauri::Builder::default() /// .plugin( /// Builder::new() -/// .add_migrations("main.db", sqlx::migrate!("./migrations/main")) -/// .add_migrations("cache.db", sqlx::migrate!("./migrations/cache")) +/// .on_setup(|app, reg| { +/// let db = app.path().app_data_dir()?.join("main.db"); +/// reg.register_database( +/// "MAIN", +/// db, +/// Some(sqlx::migrate!("./migrations/main")), +/// )?; +/// Ok(()) +/// }) /// .build() +/// .expect("failed to build sqlite plugin") /// ) /// .run(tauri::generate_context!()) /// .expect("error while running tauri application"); /// # } /// ``` -#[derive(Debug, Default)] -pub struct Builder { - /// Migrations registered per database path - migrations: HashMap>, +/// +/// Collects database registrations from the [`Builder::on_setup`] hook. +/// +/// Passed to the `on_setup` closure during plugin setup, where the `app` instance is +/// available. Use it to register values that can only be computed at runtime (for example, +/// paths derived from `app.path().app_data_dir()`). +#[derive(Default)] +pub struct SetupRegistrar { + database_info_by_key: HashMap, +} + +impl SetupRegistrar { + /// Register a database path, optionally with migrations. See [`Builder::register_database`]. + /// + /// This invocation is to be used when the database path is known at runtime (such as + /// a path dependent on `app.path().app_data_dir()`). + /// + /// For a path that is known at compile time, use [`Builder::register_database`] + /// instead. + /// + /// The `key` is the identifier for the database. It is used to identify the database + /// when calling `load()` or `execute()`. + /// + /// The `path` is the absolute filesystem path or in-memory URI. It is validated and + /// canonicalized at registration time. + /// + /// Returns `Err` if the path fails validation (relative, traversal, or canonicalization). + /// + /// The `migrator` runs automatically at plugin initialization when provided. + /// + /// If the same key is registered more than once, the last registration will override + /// all previous ones. + /// + /// Distinct keys must map to distinct database paths. Duplicate paths are rejected + /// when the plugin initializes (see [`Builder::build`]). + pub fn register_database( + &mut self, + key: &str, + path: impl Into, + migrator: Option, + ) -> Result<()> { + self + .database_info_by_key + .insert(key.to_string(), validated_database_info(path, migrator)?); + Ok(()) + } +} + +/// Closure type for the deferred [`Builder::on_setup`] hook. +type OnSetupHook = Box, &mut SetupRegistrar) -> Result<()> + Send>; + +pub struct Builder { + /// Migrations registered per database path, keyed by the database key. + database_info_by_key: HashMap, /// Timeout for interruptible transactions. Defaults to 5 minutes. transaction_timeout: Option, /// Maximum number of concurrently loaded databases. Defaults to 50. max_databases: Option, + /// Deferred hook run during plugin setup with the app handle. Lets callers register + /// paths/migrations computed from `app`. Returning `Err` aborts app startup. + on_setup: Option>, } -impl Builder { +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for Builder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Builder") + .field("database_info_by_key", &self.database_info_by_key) + .field("transaction_timeout", &self.transaction_timeout) + .field("max_databases", &self.max_databases) + .field("on_setup", &self.on_setup.is_some()) + .finish() + } +} + +impl Builder { /// Create a new builder instance. pub fn new() -> Self { Self { - migrations: HashMap::new(), + database_info_by_key: HashMap::new(), transaction_timeout: None, max_databases: None, + on_setup: None, } } - /// Register migrations for a database path. + /// Register a database by key and path, optionally with migrations. /// - /// Migrations will be run automatically at plugin initialization. - /// Multiple databases can have their own migrations. + /// Pass `None` for `migrator` when the database has no migrations. Migrations run + /// automatically at plugin initialization when provided. /// - /// # Arguments + /// Use this when the path is known at compile time. For paths derived from the `app` + /// instance (for example `app.path().app_data_dir()`), use [`on_setup`](Self::on_setup) + /// and [`SetupRegistrar::register_database`] instead. /// - /// * `path` - Database path (relative to app config directory) - /// * `migrator` - Migrator instance, typically from `sqlx::migrate!()` + /// The frontend must call `load()` with the registration **key**. /// /// # Example /// /// ```no_run /// use tauri_plugin_sqlite::Builder; + /// use std::path::PathBuf; + /// + /// const MAIN_DB_KEY: &str = "MAIN"; /// - /// # fn example() { - /// Builder::new() - /// .add_migrations("main.db", sqlx::migrate!("./doc-test-fixtures/migrations")) - /// .build::(); + /// # fn example() -> tauri_plugin_sqlite::Result<()> { + /// Builder::::new() + /// .register_database( + /// MAIN_DB_KEY, + /// PathBuf::from("/var/lib/myapp/main.db"), + /// Some(sqlx::migrate!("./doc-test-fixtures/migrations")), + /// )? + /// .build()?; + /// # Ok(()) /// # } /// ``` - pub fn add_migrations(mut self, path: &str, migrator: Migrator) -> Self { - self.migrations.insert(path.to_string(), Arc::new(migrator)); + /// + /// If the same key is registered more than once, the last registration will override + /// all previous ones. + /// + /// Distinct keys must map to distinct database paths. If two distinct keys register + /// the same path, plugin initialization returns [`Error::InvalidConfig`]. Registrations + /// from [`on_setup`](Self::on_setup) are validated when the merged map is initialized. + pub fn register_database( + mut self, + key: &str, + path: impl Into, + migrator: Option, + ) -> Result { self + .database_info_by_key + .insert(key.to_string(), validated_database_info(path, migrator)?); + + Ok(self) } /// Set the timeout for interruptible transactions. @@ -258,13 +442,62 @@ impl Builder { Ok(self) } + /// Register a hook that runs during plugin setup, once the `app` instance exists. + /// + /// This is the primary way to register database paths, because the legitimate absolute + /// paths usually depend on runtime values — for example paths derived from + /// `app.path().app_data_dir()`. The closure receives the app handle and a + /// [`SetupRegistrar`] on which you call [`register_database`](SetupRegistrar::register_database). + /// + /// Entries registered here are merged with those registered statically via + /// [`register_database`](Self::register_database); a later registration for the same + /// key overrides an earlier one. + /// + /// Returning `Err` from the hook aborts app startup (fail-fast). + /// + /// # Example + /// + /// ```no_run + /// use tauri_plugin_sqlite::Builder; + /// use tauri::Manager; + /// + /// const MAIN_DB_KEY: &str = "MAIN"; + /// + /// # fn example() -> tauri_plugin_sqlite::Result<()> { + /// Builder::::new() + /// .on_setup(|app, reg| { + /// let db = app.path().app_data_dir()?.join("main.db"); + /// reg.register_database( + /// MAIN_DB_KEY, + /// db, + /// Some(sqlx::migrate!("./doc-test-fixtures/migrations")) + /// )?; + /// Ok(()) + /// }) + /// .build()?; + /// # Ok(()) + /// # } + /// ``` + pub fn on_setup( + mut self, + f: impl FnOnce(&AppHandle, &mut SetupRegistrar) -> Result<()> + Send + 'static, + ) -> Self { + self.on_setup = Some(Box::new(f)); + self + } + /// Build the plugin with command registration and state management. - pub fn build(self) -> tauri::plugin::TauriPlugin { - let migrations = Arc::new(self.migrations); + /// + /// Duplicate paths across distinct registration keys are rejected during plugin + /// initialization (the setup hook), after [`on_setup`](Self::on_setup) registrations + /// are merged with static ones. + pub fn build(self) -> Result> { + let database_info_by_key = self.database_info_by_key; let transaction_timeout = self.transaction_timeout; let max_databases = self.max_databases; + let on_setup = self.on_setup; - PluginBuilder::::new("sqlite") + Ok(PluginBuilder::::new("sqlite") .invoke_handler(tauri::generate_handler![ commands::load, commands::execute, @@ -297,26 +530,40 @@ impl Builder { app.manage(ActiveRegularTransactions::default()); app.manage(subscriptions::ActiveSubscriptions::default()); - // Initialize migration states as Pending for all registered databases + // Run the deferred setup hook (if any), merge with static registrations. + // Paths are validated and canonicalized at registration time. Hook errors + // abort startup (fail-fast). + let mut database_info_by_key = database_info_by_key; + if let Some(on_setup_action) = on_setup { + let mut registrar = SetupRegistrar::default(); + on_setup_action(app, &mut registrar)?; + database_info_by_key.extend(registrar.database_info_by_key); + } + + ensure_distinct_database_paths(&database_info_by_key)?; + + app.manage(RegisteredDatabases { + database_path_by_key: Arc::new(database_info_by_key.iter().map(|(key, info)| (key.clone(), info.path.clone())).collect()), + }); + let migration_states = app.state::(); { let mut states = migration_states.0.blocking_write(); - for path in migrations.keys() { - states.insert(path.clone(), MigrationState::new()); + for key in database_info_by_key.keys() { + states.insert(key.clone(), MigrationState::new()); } } - // Spawn parallel migration tasks for each registered database - if !migrations.is_empty() { - info!("Starting migrations for {} database(s)", migrations.len()); + for (key, info) in &database_info_by_key { + if let Some(migrator) = &info.migrator { + info!("Starting migrations for database {}", key); - for (path, migrator) in migrations.iter() { + let key = key.clone(); + let migrator = migrator.clone(); + let path = info.path.clone(); let app_handle = app.clone(); - let path = path.clone(); - let migrator = Arc::clone(migrator); - tauri::async_runtime::spawn(async move { - run_migrations_for_database(app_handle, path, migrator).await; + run_migrations_for_database(app_handle, &key, &path, &migrator).await; }); } } @@ -446,13 +693,15 @@ impl Builder { } } }) - .build() + .build()) } } /// Initializes the plugin with default configuration. pub fn init() -> tauri::plugin::TauriPlugin { - Builder::new().build() + Builder::::new() + .build() + .expect("failed to build sqlite plugin") } /// Run migrations for a single database and emit events. @@ -472,38 +721,39 @@ pub fn init() -> tauri::plugin::TauriPlugin { /// global registry and is reused when `load` creates its own wrapper. async fn run_migrations_for_database( app: tauri::AppHandle, - path: String, - migrator: Arc, + key: &str, + path: &Path, + migrator: &Arc, ) { let migration_states = app.state::(); // Update state to Running { let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Running); } } // Emit running event - emit_migration_event(&app, &path, "running", None, None); + emit_migration_event(&app, key, path, "running", None, None); // Resolve absolute path and connect - let abs_path = match resolve_migration_path(&path, &app) { + let abs_path = match resolve_database_path(key, &app) { Ok(p) => p, Err(e) => { let error_msg = e.to_string(); error!( "Failed to resolve migration path for {}: {}", - path, error_msg + key, error_msg ); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Failed(error_msg.clone())); } - emit_migration_event(&app, &path, "failed", None, Some(error_msg)); + emit_migration_event(&app, key, path, "failed", None, Some(error_msg)); return; } }; @@ -513,14 +763,14 @@ async fn run_migrations_for_database( Ok(wrapper) => wrapper, Err(e) => { let error_msg = e.to_string(); - error!("Failed to connect for migrations {}: {}", path, error_msg); + error!("Failed to connect for migrations {}: {}", key, error_msg); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Failed(error_msg.clone())); } - emit_migration_event(&app, &path, "failed", None, Some(error_msg)); + emit_migration_event(&app, key, path, "failed", None, Some(error_msg)); return; } }; @@ -529,30 +779,30 @@ async fn run_migrations_for_database( // Note: SQLx's migrator.run() doesn't provide per-migration callbacks, // so we can only report start and finish. For detailed per-migration events, // we would need to iterate migrations manually. - trace!("Running migrations for {}", path); + trace!("Running migrations for {}", key); - match db.run_migrations(&migrator).await { + match db.run_migrations(migrator).await { Ok(()) => { - info!("Migrations completed successfully for {}", path); + info!("Migrations completed successfully for {}", key); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Complete); } let migration_count = migrator.iter().count(); - emit_migration_event(&app, &path, "completed", Some(migration_count), None); + emit_migration_event(&app, key, path, "completed", Some(migration_count), None); } Err(e) => { let error_msg = e.to_string(); - error!("Migration failed for {}: {}", path, error_msg); + error!("Migration failed for {}: {}", key, error_msg); let mut states = migration_states.0.write().await; - if let Some(state) = states.get_mut(&path) { + if let Some(state) = states.get_mut(key) { state.update_status(MigrationStatus::Failed(error_msg.clone())); } - emit_migration_event(&app, &path, "failed", None, Some(error_msg)); + emit_migration_event(&app, key, path, "failed", None, Some(error_msg)); } } } @@ -560,13 +810,15 @@ async fn run_migrations_for_database( /// Emit a migration event to the frontend and cache it. fn emit_migration_event( app: &tauri::AppHandle, - db_path: &str, + db_key: &str, + db_path: &Path, status: &str, migration_count: Option, error: Option, ) { let event = MigrationEvent { - db_path: db_path.to_string(), + db_key: db_key.to_string(), + db_path: db_path.to_path_buf(), status: status.to_string(), migration_count, error, @@ -575,7 +827,7 @@ fn emit_migration_event( // Cache event in migration state let migration_states = app.state::(); if let Ok(mut states) = migration_states.0.try_write() - && let Some(state) = states.get_mut(db_path) + && let Some(state) = states.get_mut(db_key) { state.cache_event(event.clone()); } @@ -585,36 +837,373 @@ fn emit_migration_event( } } -/// Resolve database path for migrations. +/// Connect to a registered database by its registration key. /// -/// Delegates to `resolve::resolve_database_path` to ensure consistent path validation -/// across all entry points. -fn resolve_migration_path( - path: &str, - app: &tauri::AppHandle, -) -> Result { - crate::resolve::resolve_database_path(path, app) +/// Opens the database through the same path as the frontend `load` IPC command +/// ([`connect_to_database`]): awaits migrations, enforces max-database limits, and +/// stores the wrapper in [`DbInstances`]. Returns a [`DatabaseWrapper`] for direct +/// toolkit use. +/// +/// The `database_key` must match a key registered via +/// [`Builder::register_database`] or [`SetupRegistrar::register_database`]. +/// +/// # Why use a key? +/// +/// Database paths are usually resolved once during plugin setup — for example +/// `app.path().app_data_dir()?.join("main.db")` in [`Builder::on_setup`]. Without +/// registration keys, every call site would repeat that path discovery or keep its own +/// `PathBuf`. Registration stores the key-to-path mapping once; `connect` reuses the key +/// so callers do not supply a filesystem path on every open. +/// +/// On mobile, path discovery is not a cheap string join. Resolvers such as +/// [tauri-plugin-fs-resolver](https://github.com/silvermine/tauri-plugin-fs-resolver) +/// call platform-native APIs so paths match OS sandbox rules. On Android that means a +/// JNI call into Kotlin `Context` (e.g. `getFilesDir()`) on each resolve — noticeably +/// more expensive than a local HashMap lookup, and a different kind of boundary than +/// TypeScript-to-Rust IPC (in-process JNI vs webview bridge). Register the resolved +/// `PathBuf` once in `on_setup`; every later `connect(database_key)` only looks up that +/// key in [`RegisteredDatabases`] — no repeat native or JNI work. +/// +/// For webview/frontend access, use `Database.load(dbKey)` instead. +/// +/// # Example +/// +/// ```ignore +/// use tauri::{Manager, Runtime}; +/// use tauri_plugin_sqlite::Connection; +/// +/// // During setup (on_setup): +/// // reg.register_database("MAIN", app.path().app_data_dir()?.join("main.db"), None); +/// +/// async fn read_users(app: tauri::AppHandle) -> tauri_plugin_sqlite::Result<()> { +/// let db = app.connect("MAIN").await?; +/// let rows = db.fetch_all("SELECT * FROM users".into(), vec![]).execute().await?; +/// Ok(()) +/// } +/// ``` +pub trait Connection { + /// Connect with default pool configuration. + fn connect(&self, database_key: &str) -> impl Future> + Send; + + /// Connect with custom [`SqliteDatabaseConfig`] (pool sizes, idle timeout). + fn connect_with_config( + &self, + database_key: &str, + config: SqliteDatabaseConfig, + ) -> impl Future> + Send; + + /// Close the loaded instance for a registered database key. + /// + /// Returns `true` if the database was loaded and successfully closed. + /// Returns `false` if the database was not loaded (nothing to close). + /// Active subscriptions and in-flight transactions for this key are aborted + /// before closing (10 second deadline for transactions). + fn close(&self, database_key: &str) -> impl Future> + Send; +} + +/// Delegates to [`connect_to_database`]: same open path as the `load` IPC command. +impl Connection for AppHandle { + async fn connect(&self, database_key: &str) -> Result { + let response = connect_to_database(self, database_key, None).await?; + Ok(response.wrapper) + } + + async fn connect_with_config( + &self, + database_key: &str, + config: SqliteDatabaseConfig, + ) -> Result { + let response = connect_to_database(self, database_key, Some(config)).await?; + Ok(response.wrapper) + } + + async fn close(&self, database_key: &str) -> Result { + let instances = self + .try_state::() + .ok_or(Error::Other("DbInstances state not found".to_string()))?; + let subs = self.try_state::().ok_or(Error::Other( + "ActiveSubscriptions state not found".to_string(), + ))?; + let interruptible_txs = + self + .try_state::() + .ok_or(Error::Other( + "ActiveInterruptibleTransactions state not found".to_string(), + ))?; + let regular_txs = self + .try_state::() + .ok_or(Error::Other( + "ActiveRegularTransactions state not found".to_string(), + ))?; + close_database( + database_key, + &instances, + &subs, + &interruptible_txs, + ®ular_txs, + ) + .await + } +} + +struct ConnectionResponse { + path: PathBuf, + wrapper: DatabaseWrapper, +} + +async fn connect_to_database( + app: &AppHandle, + db_key: &str, + custom_config: Option, +) -> Result { + let migration_states = app.state::(); + let db_instances = app.state::(); + + // Wait for migrations to complete if registered for this database + await_migrations(&migration_states, db_key).await?; + + let path = resolve_database_path(db_key, app)?; + + let instances = db_instances.inner.read().await; + + // Return cached if db was already loaded + if let Some(wrapper) = instances.get(db_key) { + return Ok(ConnectionResponse { + path, + wrapper: wrapper.clone(), + }); + } + + drop(instances); // Release read lock before acquiring write lock + + let mut instances = db_instances.inner.write().await; + + // Check database count limit before creating a new connection. + // This check is before entry() to avoid borrow conflicts, and the write lock + // prevents races between the len() check and the insert. + if !instances.contains_key(db_key) && instances.len() >= db_instances.max { + return Err(Error::TooManyDatabases(db_instances.max)); + } + + // Use entry API to atomically check and insert, avoiding race conditions + // where two callers could both create wrappers + use std::collections::hash_map::Entry; + match instances.entry(db_key.to_string()) { + Entry::Occupied(entry) => { + // Another caller won the race and inserted while we waited for write lock + Ok(ConnectionResponse { + path, + wrapper: entry.get().clone(), + }) + } + Entry::Vacant(entry) => { + // We won the race, create and insert the wrapper + let wrapper = DatabaseWrapper::connect(&path, custom_config).await?; + entry.insert(wrapper.clone()); + Ok(ConnectionResponse { path, wrapper }) + } + } +} + +/// Wait for migrations to complete for a database, if any are registered. +/// +/// Returns Ok(()) if: +/// - No migrations are registered for this database +/// - Migrations completed successfully +/// +/// Returns Err if migrations failed. +async fn await_migrations(migration_states: &MigrationStates, db_key: &str) -> Result<()> { + loop { + // Get notify handle before checking status + let notify = { + match migration_states.0.read().await.get(db_key) { + // No migrations registered for this database + None => return Ok(()), + + Some(state) => match &state.status { + // Migrations completed successfully + MigrationStatus::Complete => return Ok(()), + + // Migrations failed - return the error + MigrationStatus::Failed(error) => { + return Err(Error::Migration(sqlx::migrate::MigrateError::Source( + error.clone().into(), + ))); + } + + // Migrations still pending or running - wait for notification + MigrationStatus::Pending | MigrationStatus::Running => state.notify.clone(), + }, + } + }; + + // Wait for migration state change + notify.notified().await; + } +} + +pub(crate) async fn close_database( + db_key: &str, + db_instances: &DbInstances, + active_subs: &ActiveSubscriptions, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result { + active_subs.remove_for_db(db_key).await; + + sqlx_sqlite_toolkit::cleanup_transactions_for_db(db_key, interruptible_txs, regular_txs).await?; + + let mut instances = db_instances.inner.write().await; + + if let Some(wrapper) = instances.remove(db_key) { + wrapper.close().await?; + Ok(true) + } else { + Ok(false) // Database wasn't loaded + } +} + +/// Resolve a registered database path by key. +/// +/// The `db_key` must match a key registered via +/// [`crate::Builder::register_database`] / [`crate::SetupRegistrar::register_database`]. +/// +/// Returns `Err(Error::PathNotRegistered)` if the key is not registered. +fn resolve_database_path(db_key: &str, app: &AppHandle) -> Result { + let registered_databases = app.state::(); + + if let Some(path) = registered_databases.database_path_by_key.get(db_key) { + return Ok(path.clone()); + } + + Err(Error::PathNotRegistered(db_key.to_string())) } #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; + use tauri::plugin::Plugin; + use tauri::test::{MockRuntime, mock_app, mock_builder, mock_context, noop_assets}; + + fn builder_with_duplicate_paths(temp_dir: &tempfile::TempDir) -> Builder { + let path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + + Builder::::new() + .register_database("MAIN", &path, None) + .unwrap() + .register_database("BACKUP", &path, None) + .unwrap() + } + + fn mock_app_with_registrations( + database_path_by_key: HashMap, + ) -> tauri::App { + let app = tauri::test::mock_app(); + app.manage(RegisteredDatabases { + database_path_by_key: Arc::new(database_path_by_key), + }); + app.manage(DbInstances::default()); + app.manage(MigrationStates::default()); + app + } + + #[tokio::test] + async fn test_connect_to_database_registered_key() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let mut registrations = HashMap::new(); + registrations.insert("MAIN".to_string(), db_path.clone()); + + let app = mock_app_with_registrations(registrations); + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .unwrap(); + + assert_eq!(response.path, db_path); + } + + #[tokio::test] + async fn test_connect_unregistered_key_returns_path_not_registered() { + let app = mock_app_with_registrations(HashMap::new()); + let err = match connect_to_database(app.handle(), "MAIN", None).await { + Err(error) => error, + Ok(_) => panic!("expected unregistered key to fail"), + }; + + assert!(matches!(err, Error::PathNotRegistered(_))); + assert_eq!(err.to_string(), "database key not registered: MAIN"); + } + + #[test] + fn test_register_database_last_registration_wins() { + let temp_dir = tempfile::tempdir().unwrap(); + let first_path = validate::validate_database_path(temp_dir.path().join("first.db")).unwrap(); + let second_path = + validate::validate_database_path(temp_dir.path().join("second.db")).unwrap(); + + let mut database_path_by_key = HashMap::new(); + database_path_by_key.insert("MAIN".to_string(), first_path); + database_path_by_key.insert("MAIN".to_string(), second_path.clone()); + + let app = mock_app_with_registrations(database_path_by_key); + let path = resolve_database_path("MAIN", app.handle()).unwrap(); + + assert_eq!(path, second_path); + } + + #[test] + fn test_setup_rejects_duplicate_paths_for_distinct_keys() { + let temp_dir = tempfile::tempdir().unwrap(); + let mut plugin = builder_with_duplicate_paths(&temp_dir).build().unwrap(); + let app = mock_app(); + + let err = plugin + .initialize(app.handle(), serde_json::Value::default()) + .unwrap_err(); + + let err_msg = err.to_string(); + assert!(err_msg.contains("MAIN")); + assert!(err_msg.contains("BACKUP")); + assert!(err_msg.contains("invalid configuration")); + } + + #[test] + fn test_app_build_rejects_duplicate_paths_for_distinct_keys() { + let temp_dir = tempfile::tempdir().unwrap(); + let plugin = builder_with_duplicate_paths(&temp_dir).build().unwrap(); + + let err = mock_builder() + .plugin(plugin) + .build(mock_context(noop_assets())) + .unwrap_err(); + + match err { + tauri::Error::PluginInitialization(name, message) => { + assert_eq!(name, "sqlite"); + assert!(message.contains("MAIN")); + assert!(message.contains("BACKUP")); + assert!(message.contains("invalid configuration")); + } + other => panic!("expected PluginInitialization, got {other:?}"), + } + } #[test] fn test_max_databases_rejects_zero() { - let err = Builder::new().max_databases(0).unwrap_err(); + let err = Builder::::new().max_databases(0).unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_))); } #[test] fn test_max_databases_accepts_positive() { - let builder = Builder::new().max_databases(1).unwrap(); + let builder = Builder::::new().max_databases(1).unwrap(); assert_eq!(builder.max_databases, Some(1)); } #[test] fn test_transaction_timeout_rejects_zero() { - let err = Builder::new() + let err = Builder::::new() .transaction_timeout(std::time::Duration::ZERO) .unwrap_err(); assert!(matches!(err, Error::InvalidConfig(_))); @@ -622,7 +1211,7 @@ mod tests { #[test] fn test_transaction_timeout_accepts_positive() { - let builder = Builder::new() + let builder = Builder::::new() .transaction_timeout(std::time::Duration::from_secs(1)) .unwrap(); assert_eq!( diff --git a/src/resolve.rs b/src/resolve.rs deleted file mode 100644 index 08d3b25..0000000 --- a/src/resolve.rs +++ /dev/null @@ -1,221 +0,0 @@ -use std::fs::create_dir_all; -use std::path::{Component, Path, PathBuf}; - -use sqlx_sqlite_conn_mgr::SqliteDatabaseConfig; -use sqlx_sqlite_toolkit::DatabaseWrapper; -use tauri::{AppHandle, Manager, Runtime}; - -use crate::Error; - -/// Connect to a SQLite database via the connection manager, resolving -/// the path relative to the app config directory. -/// -/// This is the Tauri-specific connection method that resolves relative paths -/// before delegating to the toolkit's `DatabaseWrapper::connect()`. -pub async fn connect( - path: &str, - app: &AppHandle, - custom_config: Option, -) -> Result { - let abs_path = resolve_database_path(path, app)?; - Ok(DatabaseWrapper::connect(&abs_path, custom_config).await?) -} - -/// Resolve database file path relative to app config directory. -/// -/// Paths are joined to `app_config_dir()` (e.g., `Library/Application Support/${bundleIdentifier}` -/// on iOS). Special paths like `:memory:` are passed through unchanged. -/// -/// Returns `Err(Error::PathTraversal)` if the path attempts to escape the app config directory -/// via absolute paths, `..` segments, or null bytes. -pub fn resolve_database_path(path: &str, app: &AppHandle) -> Result { - let app_path = app - .path() - .app_config_dir() - .map_err(|_| Error::InvalidPath("No app config path found".to_string()))?; - - create_dir_all(&app_path)?; - - validate_and_resolve(path, &app_path) -} - -/// Validate a user-supplied path and resolve it against a base directory. -/// -/// In-memory database paths are passed through unchanged. All other paths are validated -/// to ensure they cannot escape the base directory. -fn validate_and_resolve(path: &str, base: &Path) -> Result { - // Pass through in-memory database paths unchanged — they don't touch the filesystem. - // Matches the same patterns as `is_memory_database` in sqlx-sqlite-conn-mgr. - if is_memory_path(path) { - return Ok(PathBuf::from(path)); - } - - // Reject null bytes — these can truncate paths in C-level filesystem calls - if path.contains('\0') { - return Err(Error::PathTraversal("path contains null byte".to_string())); - } - - let rel = Path::new(path); - - // Reject absolute paths — PathBuf::join replaces the base when given an absolute path - if rel.is_absolute() { - return Err(Error::PathTraversal( - "absolute paths are not allowed".to_string(), - )); - } - - // Reject parent directory components — prevents escaping the base via `../` - for component in rel.components() { - if matches!(component, Component::ParentDir) { - return Err(Error::PathTraversal( - "parent directory references are not allowed".to_string(), - )); - } - } - - // Join and canonicalize to verify containment. The parent directory is canonicalized - // because the file may not exist yet. - let joined = base.join(rel); - let canonical_base = base - .canonicalize() - .map_err(|e| Error::InvalidPath(format!("cannot canonicalize base path: {e}")))?; - - let canonical_resolved = if joined.exists() { - joined.canonicalize() - } else { - // Ensure intermediate directories exist so that canonicalize can resolve the - // parent. This matches the caller's expectation that nested relative paths like - // "subdir/mydb.db" work without pre-creating "subdir/". - let parent = joined - .parent() - .ok_or_else(|| Error::InvalidPath("path has no parent".to_string()))?; - - create_dir_all(parent)?; - - parent - .canonicalize() - .map(|p| p.join(joined.file_name().unwrap_or_default())) - } - .map_err(|e| Error::InvalidPath(format!("cannot canonicalize path: {e}")))?; - - if !canonical_resolved.starts_with(&canonical_base) { - return Err(Error::PathTraversal( - "resolved path escapes the base directory".to_string(), - )); - } - - // Return the original (non-canonicalized) joined path for consistency with how the - // rest of the codebase references database paths. - Ok(joined) -} - -/// Check if a path string represents an in-memory SQLite database. -/// -/// Matches the same patterns as `is_memory_database` in `sqlx-sqlite-conn-mgr`: -/// `:memory:`, `file::memory:*` URIs, and `mode=memory` query parameters. -fn is_memory_path(path: &str) -> bool { - path == ":memory:" - || path.starts_with("file::memory:") - || (path.starts_with("file:") && path.contains("mode=memory")) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::fs; - - /// Helper that creates a temporary base directory for testing. - fn make_temp_base() -> PathBuf { - let dir = std::env::temp_dir().join(format!("tauri_sqlite_test_{}", std::process::id())); - fs::create_dir_all(&dir).unwrap(); - dir - } - - #[test] - fn test_simple_filename() { - let base = make_temp_base(); - let result = validate_and_resolve("mydb.db", &base).unwrap(); - assert_eq!(result, base.join("mydb.db")); - } - - #[test] - fn test_subdirectory_path() { - let base = make_temp_base(); - // Intermediate directories are auto-created — no manual setup needed - let result = validate_and_resolve("subdir/mydb.db", &base).unwrap(); - assert_eq!(result, base.join("subdir/mydb.db")); - assert!(base.join("subdir").is_dir()); - } - - #[test] - fn test_nested_subdirectory_path() { - let base = make_temp_base(); - let result = validate_and_resolve("a/b/c/mydb.db", &base).unwrap(); - assert_eq!(result, base.join("a/b/c/mydb.db")); - assert!(base.join("a/b/c").is_dir()); - } - - #[test] - fn test_memory_passthrough() { - let base = make_temp_base(); - assert_eq!( - validate_and_resolve(":memory:", &base).unwrap(), - PathBuf::from(":memory:"), - ); - } - - #[test] - fn test_file_memory_uri_passthrough() { - let base = make_temp_base(); - assert_eq!( - validate_and_resolve("file::memory:?cache=shared", &base).unwrap(), - PathBuf::from("file::memory:?cache=shared"), - ); - } - - #[test] - fn test_mode_memory_passthrough() { - let base = make_temp_base(); - assert_eq!( - validate_and_resolve("file:test?mode=memory", &base).unwrap(), - PathBuf::from("file:test?mode=memory"), - ); - } - - #[test] - fn test_rejects_parent_traversal() { - let base = make_temp_base(); - let err = validate_and_resolve("../../../etc/passwd", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_absolute_path() { - let base = make_temp_base(); - let err = validate_and_resolve("/etc/passwd", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_embedded_traversal() { - let base = make_temp_base(); - let err = validate_and_resolve("foo/../../bar", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_null_byte() { - let base = make_temp_base(); - let err = validate_and_resolve("path\0evil", &base).unwrap_err(); - assert!(matches!(err, Error::PathTraversal(_))); - } - - #[test] - fn test_rejects_non_uri_mode_memory() { - let base = make_temp_base(); - // A bare filename containing "mode=memory" is not a valid SQLite URI — - // it should go through normal path validation, not be passed through. - let result = validate_and_resolve("evil.db?mode=memory", &base).unwrap(); - assert_eq!(result, base.join("evil.db?mode=memory")); - } -} diff --git a/src/subscriptions.rs b/src/subscriptions.rs index 2805d6c..21207a2 100644 --- a/src/subscriptions.rs +++ b/src/subscriptions.rs @@ -113,8 +113,8 @@ pub struct ObserverConfigParams { struct ActiveSubscription { /// Abort handle for the subscription forwarding task. abort_handle: tokio::task::AbortHandle, - /// Database path this subscription is for. - db_path: String, + /// Database key this subscription is for. + db_key: String, } /// Global state tracking all active observer subscriptions. @@ -123,13 +123,13 @@ pub struct ActiveSubscriptions(Arc>>) impl ActiveSubscriptions { /// Insert a new subscription. - pub async fn insert(&self, id: String, db_path: String, abort_handle: tokio::task::AbortHandle) { + pub async fn insert(&self, id: String, db_key: String, abort_handle: tokio::task::AbortHandle) { let mut subs = self.0.write().await; subs.insert( id, ActiveSubscription { abort_handle, - db_path, + db_key, }, ); } @@ -146,11 +146,11 @@ impl ActiveSubscriptions { } /// Remove and abort all subscriptions for a specific database. - pub async fn remove_for_db(&self, db_path: &str) { + pub async fn remove_for_db(&self, db_key: &str) { let mut subs = self.0.write().await; let keys_to_remove: Vec = subs .iter() - .filter(|(_, sub)| sub.db_path == db_path) + .filter(|(_, sub)| sub.db_key == db_key) .map(|(k, _)| k.clone()) .collect(); @@ -162,9 +162,9 @@ impl ActiveSubscriptions { } /// Count active subscriptions for a specific database. - pub async fn count_for_db(&self, db_path: &str) -> usize { + pub async fn count_for_db(&self, db_key: &str) -> usize { let subs = self.0.read().await; - subs.values().filter(|sub| sub.db_path == db_path).count() + subs.values().filter(|sub| sub.db_key == db_key).count() } /// Abort all subscriptions (for cleanup on app exit). diff --git a/src/validate.rs b/src/validate.rs new file mode 100644 index 0000000..ee24103 --- /dev/null +++ b/src/validate.rs @@ -0,0 +1,171 @@ +use std::fs; +use std::path::{Component, Path, PathBuf}; + +use sqlx_sqlite_conn_mgr::{canonicalize_database_path, is_memory_database}; + +use crate::{Error, Result}; + +/// Validate and normalize a database path at registration time. +/// +/// In-memory databases (`:memory:`, `file::memory:*`, and `file:` URIs with an exact +/// `mode=memory` query parameter) are returned unchanged. File paths must not contain +/// null bytes or `..` components, must be absolute, and are canonicalized for consistent +/// lookups (symlink-safe when the path or its parent exists). +pub fn validate_database_path(path: impl AsRef) -> Result { + let path = path.as_ref(); + + if is_memory_database(path) { + return Ok(path.to_path_buf()); + } + + let path_str = path.to_str().ok_or_else(|| { + Error::InvalidPath(format!( + "database path is not valid UTF-8: {}", + path.display() + )) + })?; + + if path_str.contains('\0') { + return Err(Error::PathTraversal(format!( + "database path contains null byte: {path_str}" + ))); + } + + if path + .components() + .any(|component| matches!(component, Component::ParentDir)) + { + return Err(Error::PathTraversal(format!( + "database path contains parent traversal: {path_str}" + ))); + } + + if !path.is_absolute() { + return Err(Error::InvalidPath(format!( + "database path must be absolute: {path_str}" + ))); + } + + if let Some(parent) = path.parent() + && !parent.as_os_str().is_empty() + { + fs::create_dir_all(parent).map_err(|error| { + Error::InvalidPath(format!( + "failed to create parent directory for database path {path_str}: {error}" + )) + })?; + } + + canonicalize_database_path(path).map_err(|error| { + Error::InvalidPath(format!( + "failed to canonicalize database path {path_str}: {error}" + )) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::sync::atomic::{AtomicU64, Ordering}; + + fn make_temp_dir() -> PathBuf { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let n = COUNTER.fetch_add(1, Ordering::Relaxed); + let dir = + std::env::temp_dir().join(format!("tauri_sqlite_test_{}_{}", std::process::id(), n)); + fs::create_dir_all(&dir).unwrap(); + dir + } + + #[test] + fn test_memory_passthrough() { + assert_eq!( + validate_database_path(":memory:").unwrap(), + PathBuf::from(":memory:"), + ); + } + + #[test] + fn test_file_memory_uri_passthrough() { + assert_eq!( + validate_database_path("file::memory:?cache=shared").unwrap(), + PathBuf::from("file::memory:?cache=shared"), + ); + } + + #[test] + fn test_mode_memory_passthrough() { + assert_eq!( + validate_database_path("file:test?mode=memory").unwrap(), + PathBuf::from("file:test?mode=memory"), + ); + } + + #[test] + fn test_mode_memory_substring_in_value_is_not_treated_as_memory() { + let err = validate_database_path("file:/home/user/real.db?x=mode=memory").unwrap_err(); + assert!(matches!(err, Error::InvalidPath(_))); + } + + #[test] + fn test_creates_missing_parent_directory() { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let n = COUNTER.fetch_add(1, Ordering::Relaxed); + let base = std::env::temp_dir().join(format!( + "tauri_sqlite_test_missing_parent_{}_{}", + std::process::id(), + n + )); + let db_path = base.join("nested").join("main.db"); + + let result = validate_database_path(&db_path).unwrap(); + + assert!(base.join("nested").is_dir()); + assert_eq!( + result, + base.canonicalize().unwrap().join("nested").join("main.db") + ); + + fs::remove_dir_all(&base).unwrap(); + } + + #[test] + fn test_accepts_absolute_path() { + let dir = make_temp_dir(); + let abs = dir.join("exact.db"); + + let result = validate_database_path(&abs).unwrap(); + assert_eq!(result, dir.canonicalize().unwrap().join("exact.db")); + } + + #[test] + fn test_rejects_relative_path() { + let err = validate_database_path("relative.db").unwrap_err(); + assert!(matches!(err, Error::InvalidPath(_))); + } + + #[test] + fn test_rejects_absolute_path_with_parent_traversal() { + let dir = make_temp_dir(); + let abs_str = format!("{}/../escape.db", dir.to_str().unwrap()); + + let err = validate_database_path(&abs_str).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_absolute_path_with_embedded_traversal() { + let dir = make_temp_dir(); + let abs_str = format!("{}/sub/../../escape.db", dir.to_str().unwrap()); + + let err = validate_database_path(&abs_str).unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } + + #[test] + fn test_rejects_null_byte() { + let err = validate_database_path("path\0evil").unwrap_err(); + assert!(matches!(err, Error::PathTraversal(_))); + } +}