From d61d65b6525811209b41d2e3a0b7caee8a38e5f1 Mon Sep 17 00:00:00 2001 From: Philipp Sippl Date: Fri, 27 Feb 2026 10:14:53 +0100 Subject: [PATCH 01/76] initial implementation --- Cargo.lock | 9 + docs/specs/rerandomization.md | 244 ++++++++++ .../continuous-rerand-local.sh | 112 +++++ .../bin/iris-mpc-upgrade/rerandomize_db.rs | 104 ++--- .../iris-mpc-upgrade/run-rerand-e2e-tests.sh | 53 +++ iris-mpc-bins/bin/iris-mpc/server.rs | 14 +- iris-mpc-common/src/helpers/sync.rs | 20 + iris-mpc-store/src/lib.rs | 1 + iris-mpc-store/src/rerand.rs | 441 ++++++++++++++++++ iris-mpc-upgrade/Cargo.toml | 12 + iris-mpc-upgrade/src/config.rs | 34 ++ iris-mpc-upgrade/src/continuous_rerand.rs | 295 ++++++++++++ iris-mpc-upgrade/src/epoch.rs | 223 +++++++++ iris-mpc-upgrade/src/lib.rs | 3 + iris-mpc-upgrade/src/rerandomization.rs | 59 ++- iris-mpc-upgrade/src/s3_coordination.rs | 279 +++++++++++ .../tests/continuous_rerand_e2e.rs | 312 +++++++++++++ iris-mpc-upgrade/tests/test_utils.rs | 371 +++++++++++++++ iris-mpc/src/server/mod.rs | 13 + .../20260226000001_add_rerand_epoch.down.sql | 14 + .../20260226000001_add_rerand_epoch.up.sql | 15 + ...0226000002_create_rerand_progress.down.sql | 1 + ...260226000002_create_rerand_progress.up.sql | 8 + 23 files changed, 2566 insertions(+), 71 deletions(-) create mode 100644 docs/specs/rerandomization.md create mode 100755 iris-mpc-bins/bin/iris-mpc-upgrade/continuous-rerand-local.sh create mode 100755 iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh create mode 100644 iris-mpc-store/src/rerand.rs create mode 100644 iris-mpc-upgrade/src/continuous_rerand.rs create mode 100644 iris-mpc-upgrade/src/epoch.rs create mode 100644 iris-mpc-upgrade/src/s3_coordination.rs create mode 100644 iris-mpc-upgrade/tests/continuous_rerand_e2e.rs create mode 100644 iris-mpc-upgrade/tests/test_utils.rs create mode 100644 migrations/20260226000001_add_rerand_epoch.down.sql create mode 100644 migrations/20260226000001_add_rerand_epoch.up.sql create mode 
100644 migrations/20260226000002_create_rerand_progress.down.sql create mode 100644 migrations/20260226000002_create_rerand_progress.up.sql diff --git a/Cargo.lock b/Cargo.lock index c0aff9899b..b11660e34b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3183,12 +3183,18 @@ dependencies = [ "ark-ec", "ark-ff", "ark-serialize", + "aws-config", + "aws-sdk-s3", + "aws-sdk-secretsmanager", "axum 0.7.7", + "base64", "blake3", "bytemuck", "clap", "criterion", + "dotenvy", "eyre", + "futures", "iris-mpc-common", "iris-mpc-store", "itertools 0.13.0", @@ -3198,9 +3204,12 @@ dependencies = [ "rayon", "serde", "serde-big-array", + "serde_json", "sha2", + "sqlx", "thiserror 1.0.65", "tokio", + "tokio-util", "tonic", "tonic-build", "tracing", diff --git a/docs/specs/rerandomization.md b/docs/specs/rerandomization.md new file mode 100644 index 0000000000..d541aeef29 --- /dev/null +++ b/docs/specs/rerandomization.md @@ -0,0 +1,244 @@ +# Continuous Rerandomization Plan + +## Overview + +Replaces the existing, one-off rerandomization protocol by a continuous, online process that rerandomizes shares while the system is running. No downtime or restart required. + +Key design decision: in-memory shares are less likely to be exfiltrated, so only the DB (at-rest persistence) is rerandomized. The actor is completely unmodified. The rerand server handles everything, writing to a staging schema and then copying to live once all parties confirm. + +## Architecture + +1. **Rerand Server** (modified `iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs`, separate process, one per party) — rerandomizes shares, writes to staging, coordinates with peers via S3 markers, copies confirmed chunks to live DB. Replaces the existing one-off `RerandomizeDb` subcommand with a new `RerandomizeContinuous` subcommand. Core rerandomization logic in `iris-mpc-upgrade/src/rerandomization.rs` is reused; the new subcommand adds the continuous loop, S3 coordination, and staging management. +2. 
**Main Server** (existing, minimal changes) — at startup, syncs rerand progress with peers and catches up any missing chunks from staging before loading the DB into memory. + +The GPU actor, batch processing, and result processor are completely untouched. + +## Seed & Randomness + +One epoch is active at a time. At the start of each epoch: + +1. Each rerand server generates a fresh BLS12-381 keypair +2. Private key is saved to Secrets Manager at `{env}/iris-mpc-db-rerandomization/epoch-{E}/private-key-party-{P}` +3. Public key is uploaded to S3 at `s3://bucket/rerand/epoch-{E}/party-{P}/public-key` +4. Each rerand server downloads the other two parties' public keys from S3 (polling until all present) +5. Each derives the same 32-byte `shared_secret` via the BLS12-381 pairing + +Only the rerand server needs access to the key. The main server never touches it. + +### Keygen is idempotent on restart + +When starting an epoch, the rerand server: + +1. Checks if an epoch-scoped private key already exists in Secrets Manager at `{env}/iris-mpc-db-rerandomization/epoch-{E}/private-key-party-{P}` +2. If yes: loads it, derives the public key, and uploads the public key to S3 if not already present (covers crash-after-SM-write-before-S3-upload) +3. If no: generates a new keypair, saves the private key to Secrets Manager first, then uploads the public key to S3 + +Secrets Manager is checked first because the private key is written to SM before the public key is uploaded to S3. If we crash between the two writes, on restart we find the key in SM and re-upload to S3. + +### Epoch transition + +One epoch at a time, no overlap: + +1. All three rerand servers finish processing all chunks for epoch E +2. Each server uploads a completion marker: `s3://bucket/rerand/epoch-{E}/party-{P}/complete` +3. Each server polls until all three completion markers exist +4. 
Keys for epoch E are deleted from Secrets Manager — old secret is destroyed, old shares (overwritten in live DB) are unrecoverable +5. Epoch E+1 begins: create/publish `manifest.json`, keygen, derive new `shared_secret`, start processing + +Old S3 markers under `epoch-{E}/` are left in place (no active cleanup). Use S3 lifecycle policies to reap old epoch prefixes after a retention period. + +On restart mid-epoch: private key is still in SM, public keys and markers are still in S3, `rerand_progress` table tells you the current epoch and which chunk to resume from. Re-derive `shared_secret`, continue. + +## S3 Coordination Bus + +All cross-party coordination uses S3 markers in a shared bucket. Each party writes to its own prefixed paths. Marker layout: + +``` +s3://bucket/rerand/epoch-{E}/party-{P}/public-key # public key for DH +s3://bucket/rerand/epoch-{E}/party-{P}/max-id # party P watermark for manifest (MAX(id)) +s3://bucket/rerand/epoch-{E}/party-0/manifest.json # epoch chunking manifest (party 0 writes under its own prefix, others read) +s3://bucket/rerand/epoch-{E}/party-{P}/chunk-{K}/staged # chunk K staging committed +s3://bucket/rerand/epoch-{E}/party-{P}/complete # epoch E fully done +``` + +Coordination is polling-based: a rerand server checks for peer markers by listing the S3 prefix. A few seconds of polling latency is fine for background work. + +Authentication: the shared bucket uses IAM prefix policies to scope write access per party. Each party can only write to `s3://bucket/rerand/epoch-*/party-{P}/*`. All parties can read/list the full `s3://bucket/rerand/epoch-{E}/` prefix to observe peer markers. The manifest is written by the designated writer (party 0) under its own prefix (`party-0/manifest.json`) and is read-only for others. 
+ +## Schema Changes + +### New column on `irises` + +```sql +ALTER TABLE irises ADD COLUMN rerand_epoch INTEGER NOT NULL DEFAULT 0; +``` + +### Modified `increment_version_id` trigger + +```sql +CREATE OR REPLACE FUNCTION increment_version_id() +RETURNS TRIGGER AS $$ +BEGIN + IF (OLD.left_code IS DISTINCT FROM NEW.left_code OR + OLD.left_mask IS DISTINCT FROM NEW.left_mask OR + OLD.right_code IS DISTINCT FROM NEW.right_code OR + OLD.right_mask IS DISTINCT FROM NEW.right_mask) + AND NEW.rerand_epoch IS NOT DISTINCT FROM OLD.rerand_epoch THEN + NEW.version_id = COALESCE(OLD.version_id, 0) + 1; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; +``` + +When `rerand_epoch` changes (rerandomization), share data changes but `version_id` stays the same. When `rerand_epoch` stays the same (user-facing modification), `version_id` bumps as before. + +### Staging schema + +Each party has a staging schema (e.g. `SMPC_rerand_staging`) with: + +```sql +CREATE TABLE irises ( + epoch INTEGER NOT NULL, + id BIGINT NOT NULL, + chunk_id INTEGER NOT NULL, + left_code BYTEA, + left_mask BYTEA, + right_code BYTEA, + right_mask BYTEA, + original_version_id SMALLINT, + rerand_epoch INTEGER, + PRIMARY KEY (epoch, id) +); +``` + +### Coordination table + +A `rerand_progress` table in each party's DB: + +```sql +CREATE TABLE rerand_progress ( + epoch INTEGER NOT NULL, + chunk_id INTEGER NOT NULL, + staging_written BOOLEAN NOT NULL DEFAULT FALSE, + all_confirmed BOOLEAN NOT NULL DEFAULT FALSE, + live_applied BOOLEAN NOT NULL DEFAULT FALSE, + PRIMARY KEY (epoch, chunk_id) +); +``` + +Chunk ranges are derived from the manifest (`chunk_size`, `max_id_inclusive`) and `chunk_id`, so they are not stored here. + +Lifecycle: `staging_written` → `all_confirmed` → `live_applied`. + +## Flow + +### Step 1: Rerand Server (per party, separate process) + +Runs continuously: + +1. 
Determine the active epoch E and load its manifest (the highest epoch with a manifest at `s3://bucket/rerand/epoch-{E}/party-0/manifest.json` but without all three completion markers). If no manifest exists for the next epoch, create it (party 0 only): collect watermarks, compute `max_id_inclusive`, write `manifest.json`. +2. Derive `shared_secret` for epoch E (keygen or resume — see above) +3. Pick next chunk range `[start, end)` for chunk K from the manifest +4. Read entries from live schema, recording each entry's `version_id` +5. Rerandomize shares using `BLAKE3(shared_secret || iris_id)` XOF +6. Write rerandomized shares to staging schema with `epoch = E`, `original_version_id`, `chunk_id = K`, and `rerand_epoch = E + 1` +7. Set `staging_written = TRUE` in local `rerand_progress` for `(epoch = E, chunk_id = K)` +8. Upload S3 marker after staging commit: `s3://bucket/rerand/epoch-{E}/party-{P}/chunk-{K}/staged` +9. Poll S3 until all 3 party markers exist for chunk K +10. Set `all_confirmed = TRUE` in local `rerand_progress` for `(epoch = E, chunk_id = K)` +11. Acquire `pg_advisory_lock(RERAND_APPLY_LOCK)` on a dedicated connection, then copy from staging to live DB, delete staging, and mark applied — all in one transaction (scoped to epoch and chunk): + ```sql + SELECT pg_advisory_lock(RERAND_APPLY_LOCK); -- on dedicated connection + BEGIN; + UPDATE irises SET + left_code = staging.left_code, + left_mask = staging.left_mask, + right_code = staging.right_code, + right_mask = staging.right_mask, + rerand_epoch = staging.rerand_epoch + FROM staging_schema.irises AS staging + WHERE irises.id = staging.id + AND staging.epoch = E + AND staging.chunk_id = K + AND irises.version_id = staging.original_version_id; + DELETE FROM staging_schema.irises WHERE epoch = E AND chunk_id = K; + UPDATE rerand_progress SET live_applied = TRUE WHERE epoch = E AND chunk_id = K; + COMMIT; + SELECT pg_advisory_unlock(RERAND_APPLY_LOCK); -- release after commit + ``` +12. 
Proceed to next chunk (or start epoch transition if all chunks done) + +### Step 2: Main Server Startup (minimal changes) + +At startup, before `load_iris_db`: + +1. **Existing**: modification sync (`sync_modifications`) — all parties catch up on modifications, producing identical `version_id` values +2. **New**: rerand sync — parties exchange a compact rerand watermark during the existing startup sync (`SyncState` exchange): + - Each party computes `(epoch, max_confirmed_chunk)` from its local `rerand_progress` table: the active epoch E and the highest `chunk_id` where `all_confirmed = TRUE`. Since chunks are processed in strictly increasing order, all chunks `0..=max_confirmed_chunk` (inclusive) are implicitly confirmed. + - Each party sends this single `(epoch, max_confirmed_chunk)` pair as part of `SyncState`. + - Each party computes `safe_up_to = max(max_confirmed_chunk_party_0, max_confirmed_chunk_party_1, max_confirmed_chunk_party_2)` for the agreed epoch E, then locally applies all chunks `0..=safe_up_to` (inclusive) where `live_applied = FALSE`. + - This is safe because `all_confirmed = TRUE` at any party means that party observed all three S3 `staged` markers, which means all three parties successfully committed the chunk to their staging schemas. A slower party may not have polled S3 yet, but its staging data is already there. Using `max` ensures all parties converge to the same applied set, preventing cross-party desync where one party loads rerandomized shares and another loads stale shares. + - Edge case: if no chunks have been confirmed yet (fresh epoch or very start), `max_confirmed_chunk` is -1 / None. `safe_up_to` becomes -1 / None and the catch-up step is skipped entirely. +3. **New (DB-only catch-up)**: acquire `pg_advisory_lock(RERAND_APPLY_LOCK)` on a dedicated connection. Then for every chunk K in `0..=safe_up_to` (inclusive) where locally `live_applied = FALSE` (in increasing order): run the same apply transaction as Step 1.11. **Keep the lock held** through step 4. +4. 
**Existing**: `load_iris_db` — loads from live DB into GPU memory. The advisory lock is still held, so the rerand server cannot apply new chunks while the DB is being read into memory. +5. Release the advisory lock: `SELECT pg_advisory_unlock(RERAND_APPLY_LOCK)` on the dedicated connection, then drop the connection. + +### Advisory lock: startup vs rerand server concurrency + +Both the rerand server (Step 1.11) and the main server startup (Steps 2.3–2.4) acquire `pg_advisory_lock(RERAND_APPLY_LOCK)` before applying chunks. This ensures: + +- Only one process applies chunks at a time (no interleaving). +- The main server holds the lock from catch-up through `load_iris_db`, so the rerand server cannot sneak in applies between catch-up and memory load. +- If either process crashes, the connection drops and Postgres automatically releases the session-level lock. No stale locks. + +**Implementation with connection pools (sqlx)**: session-level advisory locks are tied to a specific Postgres connection. When using a connection pool, acquire a **dedicated connection** (`pool.acquire()`) and hold it (do not drop/return it) for the entire lock window. The catch-up queries and `load_iris_db` can use the pool normally — the dedicated connection just sits idle holding the lock. Release with `pg_advisory_unlock(...)` on the same connection after `load_iris_db` completes, then drop the connection. + +```rust +let mut lock_conn = pool.acquire().await?; +sqlx::query("SELECT pg_advisory_lock($1)") + .bind(RERAND_APPLY_LOCK) + .execute(&mut *lock_conn).await?; + +apply_catchup_chunks(&pool).await?; // uses pool +load_iris_db(&pool).await?; // uses pool + +sqlx::query("SELECT pg_advisory_unlock($1)") + .bind(RERAND_APPLY_LOCK) + .execute(&mut *lock_conn).await?; +drop(lock_conn); +``` + +### Why modification sync before rerand sync matters + +Modification sync ensures all parties have the same `version_id` values before the rerand staging copy runs. 
This guarantees the optimistic lock (`WHERE version_id = original_version_id`) produces the same skip set on all parties — the same entries are updated, the same entries are skipped. + +## Conflict Resolution: Rerandomization vs Modifications + +### Why the optimistic lock is needed + +The rerand server reads entry X at time T with `version_id = V`. A modification (reauth/deletion) may happen later, bumping `version_id` to V+1. The staging still has `original_version_id = V`. The optimistic lock prevents overwriting the modification: + +```sql +UPDATE irises SET ... WHERE version_id = original_version_id; +-- V ≠ V+1 → entry X skipped +``` + +### Why `rerand_epoch` and the trigger are needed + +Without the trigger change, the staging copy would bump `version_id` (because share data changed). The trigger change keeps `version_id` as a pure "user-facing modification counter," separate from rerandomization. + +## Chunking + +Chunk boundaries must be identical across parties for chunk K to be meaningful. Define them via an epoch manifest object in S3: + +- `s3://bucket/rerand/epoch-{E}/party-0/manifest.json`: `{ epoch: E, chunk_size: N, max_id_inclusive: M }` +- Party 0 writes the manifest once at epoch start under its own prefix (IAM-compliant); other parties poll until it exists and treat it as immutable. +- **Watermark sync**: before the manifest is written, each party P uploads its local watermark `max_id_party_P = SELECT MAX(id) FROM irises` to `s3://bucket/rerand/epoch-{E}/party-{P}/max-id`. +- The manifest writer waits until all three `max-id` markers exist, then sets `max_id_inclusive` as: + - `M = min(max_id_party_0, max_id_party_1, max_id_party_2) - safety_buffer_ids` + - `safety_buffer_ids` is configurable (default 0 or one chunk) to avoid rerandomizing the “tip” where replication/ingest lag could differ across parties. +- New inserts with `id > M` are left for a future epoch. 
+- Chunk K corresponds to `[start, end)` where `start = 1 + K * N` and `end = min(start + N, M + 1)`. + +A configurable delay (`--chunk-delay`, default e.g. 5s) is inserted between chunks to avoid sustained DB load. The rerand server should not stress the live DB with continuous writes — the delay spreads the I/O over time. The delay, chunk size, and number of parallel DB connections should all be configurable via CLI flags or environment variables. \ No newline at end of file diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/continuous-rerand-local.sh b/iris-mpc-bins/bin/iris-mpc-upgrade/continuous-rerand-local.sh new file mode 100755 index 0000000000..aeb3a0aa28 --- /dev/null +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/continuous-rerand-local.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +rm -rf "*.log" + +docker-compose -f "$SCRIPT_DIR/docker-compose.rand.yaml" down --remove-orphans -v +docker-compose -f "$SCRIPT_DIR/docker-compose.rand.yaml" up -d + +sleep 10 + +aws_local() { + AWS_ACCESS_KEY_ID=test AWS_SECRET_ACCESS_KEY=test AWS_DEFAULT_REGION=us-east-1 \ + aws --endpoint-url=http://${LOCALSTACK_HOST:-localhost}:4566 "$@" +} + +# Create S3 bucket for rerand coordination markers +BUCKET_NAME=wf-smpcv2-rerand-testing +aws_local s3api create-bucket --bucket $BUCKET_NAME --region us-east-1 + +# Build binaries +cargo build -p iris-mpc-bins --release --bin seed-v2-dbs --bin rerandomize-db + +TARGET_DIR=$(cargo metadata --format-version 1 | jq ".target_directory" -r) + +# Set AWS env vars for localstack +export AWS_ACCESS_KEY_ID=test +export AWS_SECRET_ACCESS_KEY=test +export AWS_DEFAULT_REGION=us-east-1 +export AWS_ENDPOINT_URL="http://127.0.0.1:4566" + +export ENVIRONMENT="testing" + +# Seed DBs with initial data (using first 3 new-db containers as live DBs) +echo "=== Seeding DBs ===" +$TARGET_DIR/release/seed-v2-dbs \ + --db-url-party-0 postgres://postgres:postgres@localhost:6200 \ + 
--db-url-party-1 postgres://postgres:postgres@localhost:6201 \ + --db-url-party-2 postgres://postgres:postgres@localhost:6202 \ + --schema-name-party-0 SMPC_testing_0 \ + --schema-name-party-1 SMPC_testing_1 \ + --schema-name-party-2 SMPC_testing_2 \ + --fill-to 1000 \ + --batch-size 100 +echo "Seeding complete" + +# Run continuous rerandomization for all 3 parties in parallel +echo "=== Starting continuous rerandomization ===" +COMMON_ARGS="--chunk-size 200 --chunk-delay-secs 1 --s3-poll-interval-ms 2000 --safety-buffer-ids 0" + +$TARGET_DIR/release/rerandomize-db rerandomize-continuous \ + --party-id 0 \ + --db-url postgres://postgres:postgres@localhost:6200 \ + --schema-name SMPC_testing_0 \ + --s3-bucket $BUCKET_NAME \ + --healthcheck-port 3010 \ + $COMMON_ARGS & +PID_0=$! + +$TARGET_DIR/release/rerandomize-db rerandomize-continuous \ + --party-id 1 \ + --db-url postgres://postgres:postgres@localhost:6201 \ + --schema-name SMPC_testing_1 \ + --s3-bucket $BUCKET_NAME \ + --healthcheck-port 3011 \ + $COMMON_ARGS & +PID_1=$! + +$TARGET_DIR/release/rerandomize-db rerandomize-continuous \ + --party-id 2 \ + --db-url postgres://postgres:postgres@localhost:6202 \ + --schema-name SMPC_testing_2 \ + --s3-bucket $BUCKET_NAME \ + --healthcheck-port 3012 \ + $COMMON_ARGS & +PID_2=$! + +echo "Rerand servers started: PIDs $PID_0, $PID_1, $PID_2" +echo "Waiting for one epoch to complete (watching for completion markers in S3)..." + +# Poll until epoch 0 completion markers exist for all parties +MAX_WAIT=300 +ELAPSED=0 +while [ $ELAPSED -lt $MAX_WAIT ]; do + COMPLETE=true + for P in 0 1 2; do + KEY="rerand/epoch-0/party-${P}/complete" + if ! aws_local s3api head-object --bucket $BUCKET_NAME --key "$KEY" >/dev/null 2>&1; then + COMPLETE=false + break + fi + done + if [ "$COMPLETE" = true ]; then + echo "=== Epoch 0 completed! ===" + break + fi + sleep 5 + ELAPSED=$((ELAPSED + 5)) + echo "Waiting... 
($ELAPSED s)" +done + +if [ $ELAPSED -ge $MAX_WAIT ]; then + echo "ERROR: Epoch 0 did not complete within ${MAX_WAIT}s" +fi + +# Stop the rerand servers +kill $PID_0 $PID_1 $PID_2 2>/dev/null || true +wait $PID_0 $PID_1 $PID_2 2>/dev/null || true + +echo "=== Continuous rerandomization test finished ===" diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs b/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs index 58910d38b3..5cbe5c51e8 100644 --- a/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs @@ -11,19 +11,17 @@ use base64::Engine; use clap::Parser; use eyre::Result; use futures::TryStreamExt; -use iris_mpc_common::galois; -use iris_mpc_common::galois::degree4::basis::Monomial; -use iris_mpc_common::galois::degree4::GaloisRingElement; use iris_mpc_common::galois_engine::degree4::{ GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare, }; -use iris_mpc_common::id::PartyID; use iris_mpc_common::postgres::{AccessMode, PostgresClient}; use iris_mpc_store::{DbStoredIris, Store, StoredIrisRef}; use iris_mpc_upgrade::config::{ KeyGenConfig, ReRandomizeCheckConfig, ReRandomizeConfig, ReRandomizeDbSubCommand, + RerandomizeContinuousConfig, }; -use iris_mpc_upgrade::rerandomization::randomize_iris; +use iris_mpc_upgrade::continuous_rerand; +use iris_mpc_upgrade::rerandomization::{randomize_iris, reconstruct_shares}; use iris_mpc_upgrade::tripartite_dh; use iris_mpc_upgrade::{ config::ReRandomizeDbConfig, @@ -42,6 +40,9 @@ async fn main() -> Result<()> { ReRandomizeDbSubCommand::RerandomizeDb(config) => rerandomize_db_main(config).await, ReRandomizeDbSubCommand::KeyGen(config) => keygen_main(config).await, ReRandomizeDbSubCommand::RerandomizeCheck(config) => rerandomize_check_main(config).await, + ReRandomizeDbSubCommand::RerandomizeContinuous(config) => { + rerandomize_continuous_main(config).await + } } } @@ -531,6 +532,35 @@ async fn rerandomize_check_main(config: ReRandomizeCheckConfig) 
-> Result<()> { Ok(()) } +async fn rerandomize_continuous_main(config: RerandomizeContinuousConfig) -> Result<()> { + tracing::info!("Starting continuous rerandomization for party {}", config.party_id); + + let mut background_tasks = TaskMonitor::new(); + let healthcheck_port = config.healthcheck_port; + let _health_check_abort = background_tasks + .spawn(async move { spawn_healthcheck_server(healthcheck_port).await }); + background_tasks.check_tasks(); + + let sdk_config = aws_config::from_env().load().await; + let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config); + let sm_config = aws_sdk_secretsmanager::config::Builder::from(&sdk_config); + let s3_client = S3Client::from_conf(s3_config.build()); + let sm_client = SecretsManagerClient::from_conf(sm_config.build()); + + let postgres_client = PostgresClient::new( + &config.db_url, + &config.schema_name, + AccessMode::ReadWrite, + ) + .await?; + let store = Store::new(&postgres_client).await?; + + continuous_rerand::run_continuous_rerand(&config, &s3_client, &sm_client, &store, None).await?; + + background_tasks.abort_and_wait_for_finish().await; + Ok(()) +} + async fn download_public_key(config: &ReRandomizeConfig, party_id: u8) -> Result { if config.env == "testing" { let bucket = config.public_key_bucket_name.as_ref().ok_or_else(|| { @@ -570,70 +600,6 @@ async fn build_read_only_store(db_url: &str, schema_name: &str) -> Result Store::new(&postgres_client).await } -fn reconstruct_shares(share0: &[u16], share1: &[u16], share2: &[u16]) -> Vec { - let lag_01 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID0, - PartyID::ID1, - ); - let lag_10 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID1, - PartyID::ID0, - ); - let lag_02 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID0, - PartyID::ID2, - ); - let lag_20 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID2, - 
PartyID::ID0, - ); - let lag_12 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID1, - PartyID::ID2, - ); - let lag_21 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID2, - PartyID::ID1, - ); - - assert!(share0.len() == share1.len() && share1.len() == share2.len()); - - let recon01 = share0 - .chunks_exact(4) - .zip_eq(share1.chunks_exact(4)) - .flat_map(|(a, b)| { - let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); - let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); - let c = a * lag_01 + b * lag_10; - c.coefs - }) - .collect_vec(); - let recon12 = share1 - .chunks_exact(4) - .zip_eq(share2.chunks_exact(4)) - .flat_map(|(a, b)| { - let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); - let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); - let c = a * lag_12 + b * lag_21; - c.coefs - }) - .collect_vec(); - let recon02 = share0 - .chunks_exact(4) - .zip_eq(share2.chunks_exact(4)) - .flat_map(|(a, b)| { - let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); - let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); - let c = a * lag_02 + b * lag_20; - c.coefs - }) - .collect_vec(); - - assert_eq!(recon01, recon12); - assert_eq!(recon01, recon02); - recon01 -} - async fn download_public_key_from_localstack(bucket: &str, party_id: u8) -> Result { let key = format!("{}-{}", PUBLIC_KEY_S3_KEY_NAME_PREFIX, party_id); let request_url = format!("http://localhost:4566/{}/{}", bucket, key); diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh b/iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh new file mode 100755 index 0000000000..167f2682ae --- /dev/null +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# +# Run the continuous rerandomization e2e chaos tests. 
+# Starts Postgres + localstack via docker-compose, runs the Rust tests, then +# tears everything down. +# +# Usage: +# ./run-rerand-e2e-tests.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" +COMPOSE_FILE="$SCRIPT_DIR/docker-compose.rand.yaml" + +cleanup() { + echo "=== Tearing down containers ===" + docker-compose -f "$COMPOSE_FILE" down --remove-orphans -v 2>/dev/null || true +} +trap cleanup EXIT + +echo "=== Starting Postgres + localstack ===" +docker-compose -f "$COMPOSE_FILE" down --remove-orphans -v 2>/dev/null || true +docker-compose -f "$COMPOSE_FILE" up -d + +echo "Waiting for services to be ready..." +for i in $(seq 1 30); do + if docker exec iris-mpc-upgrade-new-db-1-1 pg_isready -U postgres -q 2>/dev/null; then + break + fi + sleep 1 +done +docker exec iris-mpc-upgrade-new-db-1-1 pg_isready -U postgres || { echo "Postgres not ready"; exit 1; } + +for i in $(seq 1 30); do + STATUS=$(docker inspect --format='{{.State.Health.Status}}' iris-mpc-upgrade-localstack-1 2>/dev/null || echo "unknown") + if [ "$STATUS" = "healthy" ]; then + break + fi + sleep 1 +done +echo "Infrastructure ready." 
+ +echo "=== Running e2e chaos tests ===" +cd "$REPO_ROOT" +AWS_ACCESS_KEY_ID=test \ +AWS_SECRET_ACCESS_KEY=test \ +AWS_DEFAULT_REGION=us-east-1 \ +AWS_ENDPOINT_URL=http://127.0.0.1:4566 \ +ENVIRONMENT=testing \ + cargo test -p iris-mpc-upgrade --test continuous_rerand_e2e --features db_dependent -- --nocapture + +echo "=== All tests passed ===" diff --git a/iris-mpc-bins/bin/iris-mpc/server.rs b/iris-mpc-bins/bin/iris-mpc/server.rs index 2e9e318e6a..82d49a76e7 100644 --- a/iris-mpc-bins/bin/iris-mpc/server.rs +++ b/iris-mpc-bins/bin/iris-mpc/server.rs @@ -58,6 +58,7 @@ use iris_mpc_common::{ }; use iris_mpc_gpu::server::ServerActor; use iris_mpc_store::loader::load_iris_db; +use iris_mpc_store::rerand as rerand_store; use iris_mpc_store::{ fetch_and_parse_chunks, last_snapshot_timestamp, DbStoredIris, ObjectStore, S3Store, S3StoredIris, Store, StoredIrisRef, @@ -981,11 +982,13 @@ async fn server_main(config: Config) -> Result<()> { let is_ready_flag = Arc::new(AtomicBool::new(false)); let is_ready_flag_cloned = Arc::clone(&is_ready_flag); + let rerand_state = rerand_store::build_rerand_sync_state(&store.pool).await.ok(); let my_state = SyncState { db_len: store_len as u64, modifications: store.last_modifications(max_modification_lookback).await?, next_sns_sequence_num: next_sns_seq_number_future.await?, common_config: CommonConfig::from(config.clone()), + rerand_state, }; tracing::info!("Sync state: {:?}", my_state); @@ -1296,7 +1299,7 @@ async fn server_main(config: Config) -> Result<()> { None, &aws_clients, &shares_encryption_key_pair, - sync_result, + sync_result.clone(), ) .await?; } @@ -1315,6 +1318,13 @@ async fn server_main(config: Config) -> Result<()> { } } + let rerand_lock_conn = rerand_store::rerand_catchup_and_lock( + &store.pool, + &store.schema_name, + &sync_result, + ) + .await?; + if download_shutdown_handler.is_shutting_down() { tracing::warn!("Shutting down has been triggered"); return Ok(()); @@ -1408,6 +1418,8 @@ async fn server_main(config: 
Config) -> Result<()> { let (mut handle, store) = rx.await??; + rerand_store::release_rerand_lock(rerand_lock_conn).await?; + background_tasks.check_tasks(); // Start thread that will be responsible for communicating back the results diff --git a/iris-mpc-common/src/helpers/sync.rs b/iris-mpc-common/src/helpers/sync.rs index 89c6238754..9719af4f6f 100644 --- a/iris-mpc-common/src/helpers/sync.rs +++ b/iris-mpc-common/src/helpers/sync.rs @@ -10,6 +10,15 @@ pub struct SyncState { pub modifications: Vec, pub next_sns_sequence_num: Option, pub common_config: CommonConfig, + #[serde(default)] + pub rerand_state: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RerandSyncState { + pub epoch: i32, + /// Highest chunk_id where all_confirmed = TRUE. -1 if none confirmed. + pub max_confirmed_chunk: i32, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -405,6 +414,7 @@ mod tests { modifications, next_sns_sequence_num: None, common_config: CommonConfig::from(config), + rerand_state: None, } } @@ -846,18 +856,21 @@ mod tests { modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(200), common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 30, modifications: vec![], next_sns_sequence_num: Some(150), common_config: CommonConfig::default(), + rerand_state: None, }, ]; @@ -870,6 +883,7 @@ mod tests { modifications: vec![], next_sns_sequence_num: None, common_config: CommonConfig::default(), + rerand_state: None, }; let all_states = vec![ state_with_none_sequence_num.clone(), @@ -892,18 +906,21 @@ mod tests { modifications: vec![], next_sns_sequence_num: None, // NodeX - advanced but empty queue common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(123), // Other nodes still have messages 
common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 30, modifications: vec![], next_sns_sequence_num: Some(123), common_config: CommonConfig::default(), + rerand_state: None, }, ]; @@ -1023,18 +1040,21 @@ mod tests { modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::from(config1), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::from(config2), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::from(config3), + rerand_state: None, }, ]; diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs index ffddede11a..e351f2a9b9 100644 --- a/iris-mpc-store/src/lib.rs +++ b/iris-mpc-store/src/lib.rs @@ -1,4 +1,5 @@ pub mod loader; +pub mod rerand; mod s3_importer; use bytemuck::cast_slice; diff --git a/iris-mpc-store/src/rerand.rs b/iris-mpc-store/src/rerand.rs new file mode 100644 index 0000000000..557304c057 --- /dev/null +++ b/iris-mpc-store/src/rerand.rs @@ -0,0 +1,441 @@ +use eyre::Result; +use iris_mpc_common::helpers::sync::{RerandSyncState, SyncResult}; +use sqlx::{pool::PoolConnection, PgPool, Postgres}; + +pub const RERAND_APPLY_LOCK: i64 = 0x5245_5241_4E44; + +pub struct StagingIrisEntry { + pub epoch: i32, + pub id: i64, + pub chunk_id: i32, + pub left_code: Vec, + pub left_mask: Vec, + pub right_code: Vec, + pub right_mask: Vec, + pub original_version_id: i16, + pub rerand_epoch: i32, +} + +#[derive(sqlx::FromRow, Debug, Clone)] +pub struct RerandProgress { + pub epoch: i32, + pub chunk_id: i32, + pub staging_written: bool, + pub all_confirmed: bool, + pub live_applied: bool, +} + +pub fn staging_schema_name(live_schema: &str) -> String { + format!("{}_rerand_staging", live_schema) +} + +pub async fn ensure_staging_schema(pool: &PgPool, staging_schema: &str) -> Result<()> { + let create_schema = format!(r#"CREATE 
SCHEMA IF NOT EXISTS "{}""#, staging_schema); + sqlx::query(&create_schema).execute(pool).await?; + + let create_table = format!( + r#" + CREATE TABLE IF NOT EXISTS "{}".irises ( + epoch INTEGER NOT NULL, + id BIGINT NOT NULL, + chunk_id INTEGER NOT NULL, + left_code BYTEA, + left_mask BYTEA, + right_code BYTEA, + right_mask BYTEA, + original_version_id SMALLINT, + rerand_epoch INTEGER, + PRIMARY KEY (epoch, id) + ) + "#, + staging_schema, + ); + sqlx::query(&create_table).execute(pool).await?; + Ok(()) +} + +pub async fn insert_staging_irises( + pool: &PgPool, + staging_schema: &str, + entries: &[StagingIrisEntry], +) -> Result<()> { + if entries.is_empty() { + return Ok(()); + } + + let table = format!("\"{}\".irises", staging_schema); + let header = format!( + "INSERT INTO {} (epoch, id, chunk_id, left_code, left_mask, right_code, right_mask, original_version_id, rerand_epoch)", + table + ); + + let mut qb = sqlx::QueryBuilder::new(header); + qb.push_values(entries, |mut b, e| { + b.push_bind(e.epoch); + b.push_bind(e.id); + b.push_bind(e.chunk_id); + b.push_bind(&e.left_code); + b.push_bind(&e.left_mask); + b.push_bind(&e.right_code); + b.push_bind(&e.right_mask); + b.push_bind(e.original_version_id); + b.push_bind(e.rerand_epoch); + }); + + qb.push(" ON CONFLICT (epoch, id) DO NOTHING"); + qb.build().execute(pool).await?; + Ok(()) +} + +/// Apply a confirmed staging chunk to the live DB. +/// +/// Within a single transaction: +/// 1. UPDATE live irises from staging (optimistic lock on version_id) +/// 2. DELETE staging rows for this chunk +/// 3. 
Mark live_applied in rerand_progress +pub async fn apply_staging_chunk( + pool: &PgPool, + staging_schema: &str, + epoch: i32, + chunk_id: i32, +) -> Result { + let mut tx = pool.begin().await?; + + let update_sql = format!( + r#" + UPDATE irises SET + left_code = staging.left_code, + left_mask = staging.left_mask, + right_code = staging.right_code, + right_mask = staging.right_mask, + rerand_epoch = staging.rerand_epoch + FROM "{}".irises AS staging + WHERE irises.id = staging.id + AND staging.epoch = $1 + AND staging.chunk_id = $2 + AND irises.version_id = staging.original_version_id + "#, + staging_schema, + ); + let result = sqlx::query(&update_sql) + .bind(epoch) + .bind(chunk_id) + .execute(&mut *tx) + .await?; + let rows_updated = result.rows_affected(); + + let delete_sql = format!( + r#"DELETE FROM "{}".irises WHERE epoch = $1 AND chunk_id = $2"#, + staging_schema, + ); + sqlx::query(&delete_sql) + .bind(epoch) + .bind(chunk_id) + .execute(&mut *tx) + .await?; + + sqlx::query( + "UPDATE rerand_progress SET live_applied = TRUE WHERE epoch = $1 AND chunk_id = $2", + ) + .bind(epoch) + .bind(chunk_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(rows_updated) +} + +pub async fn upsert_rerand_progress(pool: &PgPool, epoch: i32, chunk_id: i32) -> Result<()> { + sqlx::query( + r#" + INSERT INTO rerand_progress (epoch, chunk_id) + VALUES ($1, $2) + ON CONFLICT (epoch, chunk_id) DO NOTHING + "#, + ) + .bind(epoch) + .bind(chunk_id) + .execute(pool) + .await?; + Ok(()) +} + +pub async fn set_staging_written(pool: &PgPool, epoch: i32, chunk_id: i32) -> Result<()> { + sqlx::query( + "UPDATE rerand_progress SET staging_written = TRUE WHERE epoch = $1 AND chunk_id = $2", + ) + .bind(epoch) + .bind(chunk_id) + .execute(pool) + .await?; + Ok(()) +} + +pub async fn set_all_confirmed(pool: &PgPool, epoch: i32, chunk_id: i32) -> Result<()> { + sqlx::query( + "UPDATE rerand_progress SET all_confirmed = TRUE WHERE epoch = $1 AND chunk_id = $2", + ) + 
.bind(epoch)
+    .bind(chunk_id)
+    .execute(pool)
+    .await?;
+    Ok(())
+}
+
+/// Fetch the `rerand_progress` row for `(epoch, chunk_id)`.
+///
+/// Returns `Ok(None)` when no such row exists.
+pub async fn get_rerand_progress(
+    pool: &PgPool,
+    epoch: i32,
+    chunk_id: i32,
+) -> Result<Option<RerandProgress>> {
+    let row = sqlx::query_as::<_, RerandProgress>(
+        "SELECT epoch, chunk_id, staging_written, all_confirmed, live_applied FROM rerand_progress WHERE epoch = $1 AND chunk_id = $2",
+    )
+    .bind(epoch)
+    .bind(chunk_id)
+    .fetch_optional(pool)
+    .await?;
+    Ok(row)
+}
+
+/// Returns the highest chunk_id where all_confirmed = TRUE for a given epoch,
+/// or None if no chunks are confirmed.
+pub async fn get_max_confirmed_chunk(pool: &PgPool, epoch: i32) -> Result<Option<i32>> {
+    // An aggregate without GROUP BY always yields exactly one row; when no
+    // chunk matches, that row holds NULL. Decode the column as `Option<i32>`
+    // and use `fetch_one`, otherwise the NULL surfaces as a decode error
+    // instead of `None`.
+    let row: (Option<i32>,) = sqlx::query_as(
+        "SELECT MAX(chunk_id) FROM rerand_progress WHERE epoch = $1 AND all_confirmed = TRUE",
+    )
+    .bind(epoch)
+    .fetch_one(pool)
+    .await?;
+    Ok(row.0)
+}
+
+/// Returns the highest epoch that has any rerand_progress rows.
+pub async fn get_current_epoch(pool: &PgPool) -> Result<Option<i32>> {
+    let row: (Option<i32>,) =
+        sqlx::query_as("SELECT MAX(epoch) FROM rerand_progress")
+            .fetch_one(pool)
+            .await?;
+    Ok(row.0)
+}
+
+/// Returns chunk_ids for a given epoch where live_applied = FALSE and
+/// chunk_id <= up_to_chunk, ordered ascending.
+pub async fn get_unapplied_chunks(
+    pool: &PgPool,
+    epoch: i32,
+    up_to_chunk: i32,
+) -> Result<Vec<i32>> {
+    let rows: Vec<(i32,)> = sqlx::query_as(
+        r#"
+        SELECT chunk_id FROM rerand_progress
+        WHERE epoch = $1 AND chunk_id <= $2 AND live_applied = FALSE
+        ORDER BY chunk_id ASC
+        "#,
+    )
+    .bind(epoch)
+    .bind(up_to_chunk)
+    .fetch_all(pool)
+    .await?;
+    Ok(rows.into_iter().map(|(id,)| id).collect())
+}
+
+// ---------------------------------------------------------------------------
+// Shared startup helpers (used by both HNSW and GPU servers)
+// ---------------------------------------------------------------------------
+
+/// Build the rerand sync state from the local `rerand_progress` table.
+pub async fn build_rerand_sync_state(pool: &PgPool) -> Result<RerandSyncState> {
+    // A missing epoch falls back to 0; a missing confirmed chunk falls back to -1.
+    let epoch = get_current_epoch(pool).await?.unwrap_or(0);
+    let max_confirmed_chunk = get_max_confirmed_chunk(pool, epoch).await?.unwrap_or(-1);
+    Ok(RerandSyncState {
+        epoch,
+        max_confirmed_chunk,
+    })
+}
+
+/// Derive the chunk watermark up to which staged data may be applied locally.
+///
+/// Every state in `sync_result.all_states` that carries a rerand state is
+/// compared against the local epoch:
+/// * same epoch: raises the watermark to that state's `max_confirmed_chunk`;
+/// * one epoch ahead: raises the watermark to `i32::MAX`;
+/// * one epoch behind: contributes nothing;
+/// * any other distance: the function fails with a desync error.
+///
+/// Returns `Ok(Some((local_epoch, watermark)))` when the watermark is
+/// non-negative, and `Ok(None)` otherwise (including when the local state has
+/// no rerand state at all).
+pub fn compute_rerand_safe_up_to(sync_result: &SyncResult) -> Result<Option<(i32, i32)>> {
+    let Some(local) = sync_result.my_state.rerand_state.as_ref() else {
+        return Ok(None);
+    };
+    let my_epoch = local.epoch;
+
+    let mut watermark: i32 = -1;
+    let peers = sync_result
+        .all_states
+        .iter()
+        .filter_map(|state| state.rerand_state.as_ref());
+    for peer in peers {
+        match peer.epoch - my_epoch {
+            0 => watermark = watermark.max(peer.max_confirmed_chunk),
+            1 => watermark = i32::MAX,
+            -1 => {}
+            _ => {
+                eyre::bail!("Fatal epoch desync: local epoch is {}, but peer is on epoch {}", my_epoch, peer.epoch);
+            }
+        }
+    }
+
+    Ok((watermark >= 0).then_some((my_epoch, watermark)))
+}
+
+/// Perform rerand catch-up and acquire the advisory lock.
+///
+/// 1. Computes the safe-to-apply watermark from `sync_result`.
+/// 2. If there are unapplied chunks, acquires `pg_advisory_lock(RERAND_APPLY_LOCK)`
+///    on a dedicated connection, then applies all unapplied chunks.
+/// 3. Returns the lock-holding connection (if the lock was acquired).
+///
+/// The caller **must** keep the returned connection alive until `load_iris_db`
+/// finishes, then call [`release_rerand_lock`] to release it.
+///
+/// # Errors
+///
+/// Fails if the watermark cannot be computed, if the lock connection cannot
+/// be acquired, or if listing/applying a chunk fails. When a failure happens
+/// after the advisory lock was taken, the lock is released (best effort)
+/// before the error is returned.
+pub async fn rerand_catchup_and_lock(
+    pool: &PgPool,
+    schema_name: &str,
+    sync_result: &SyncResult,
+) -> Result<Option<PoolConnection<Postgres>>> {
+    let (epoch, up_to_chunk) = match compute_rerand_safe_up_to(sync_result)? {
+        Some(v) => v,
+        None => return Ok(None),
+    };
+
+    let staging_schema = staging_schema_name(schema_name);
+    tracing::info!(
+        "Rerand catch-up: applying chunks up to {} for epoch {}",
+        up_to_chunk,
+        epoch
+    );
+
+    let mut conn = pool.acquire().await?;
+    sqlx::query("SELECT pg_advisory_lock($1)")
+        .bind(RERAND_APPLY_LOCK)
+        .execute(&mut *conn)
+        .await?;
+
+    // Run the catch-up inside an inner future so a failure can be intercepted:
+    // `pg_advisory_lock` is session-scoped, so bailing out with `?` here would
+    // hand the still-locked session back to the pool.
+    let catchup = async {
+        let unapplied = get_unapplied_chunks(pool, epoch, up_to_chunk).await?;
+        for chunk_id in unapplied {
+            let rows = apply_staging_chunk(pool, &staging_schema, epoch, chunk_id).await?;
+            tracing::info!(
+                "Rerand catch-up: applied epoch {} chunk {} ({} rows)",
+                epoch,
+                chunk_id,
+                rows
+            );
+        }
+        Ok::<(), eyre::Report>(())
+    }
+    .await;
+
+    if let Err(err) = catchup {
+        // Best effort: the original error is the one worth reporting.
+        if let Err(unlock_err) = sqlx::query("SELECT pg_advisory_unlock($1)")
+            .bind(RERAND_APPLY_LOCK)
+            .execute(&mut *conn)
+            .await
+        {
+            tracing::warn!(
+                "Rerand catch-up: failed to release advisory lock after error: {}",
+                unlock_err
+            );
+        }
+        return Err(err);
+    }
+
+    Ok(Some(conn))
+}
+
+/// Release the advisory lock acquired by [`rerand_catchup_and_lock`].
+///
+/// A `None` argument (no lock was taken) is a no-op.
+pub async fn release_rerand_lock(
+    lock_conn: Option<PoolConnection<Postgres>>,
+) -> Result<()> {
+    if let Some(mut conn) = lock_conn {
+        sqlx::query("SELECT pg_advisory_unlock($1)")
+            .bind(RERAND_APPLY_LOCK)
+            .execute(&mut *conn)
+            .await?;
+        drop(conn);
+        tracing::info!("Rerand advisory lock released after DB load");
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use iris_mpc_common::config::CommonConfig;
+    use iris_mpc_common::helpers::sync::SyncState;
+
+    fn dummy_sync_state(epoch: i32, max_confirmed_chunk: i32) -> SyncState {
+        SyncState {
+            db_len: 100,
+            modifications: vec![],
+            next_sns_sequence_num: None,
+            common_config: CommonConfig::default(),
+            rerand_state: Some(RerandSyncState {
+                epoch,
+                max_confirmed_chunk,
+            }),
+        }
+    }
+
+    #[test]
+    fn test_compute_rerand_safe_up_to_same_epoch() {
+        let p0 = dummy_sync_state(1, 5);
+        let p1 = dummy_sync_state(1, 4);
+        let p2 = dummy_sync_state(1, 6);
+        let sync_result = SyncResult {
+            my_state: p0.clone(),
+            all_states: vec![p0, p1, p2],
+        };
+
assert_eq!(compute_rerand_safe_up_to(&sync_result).unwrap(), Some((1, 6))); + } + + #[test] + fn test_compute_rerand_safe_up_to_peer_ahead() { + // I am on epoch 0, but peer is on epoch 1. + // This implies the peer has confirmed all my chunks for epoch 0. + let p0 = dummy_sync_state(0, 5); + let p1 = dummy_sync_state(1, 0); // ahead + let p2 = dummy_sync_state(0, 5); + let sync_result = SyncResult { + my_state: p0.clone(), + all_states: vec![p0, p1, p2], + }; + assert_eq!(compute_rerand_safe_up_to(&sync_result).unwrap(), Some((0, i32::MAX))); + } + + #[test] + fn test_compute_rerand_safe_up_to_peer_behind() { + // I am on epoch 1, but peer is on epoch 0. + // This implies the peer has not confirmed any chunks for epoch 1. + let p0 = dummy_sync_state(1, 2); + let p1 = dummy_sync_state(0, 10); // behind + let p2 = dummy_sync_state(1, 2); + let sync_result = SyncResult { + my_state: p0.clone(), + all_states: vec![p0, p1, p2], + }; + assert_eq!(compute_rerand_safe_up_to(&sync_result).unwrap(), Some((1, 2))); + } + + #[test] + fn test_compute_rerand_safe_up_to_fatal_desync() { + // I am on epoch 1, but peer is on epoch 3 (difference > 1). 
+ let p0 = dummy_sync_state(1, 2); + let p1 = dummy_sync_state(3, 10); // way ahead + let p2 = dummy_sync_state(1, 2); + let sync_result = SyncResult { + my_state: p0.clone(), + all_states: vec![p0, p1, p2], + }; + assert!(compute_rerand_safe_up_to(&sync_result).is_err()); + } +} diff --git a/iris-mpc-upgrade/Cargo.toml b/iris-mpc-upgrade/Cargo.toml index 93eeb19589..e9f4c61b19 100644 --- a/iris-mpc-upgrade/Cargo.toml +++ b/iris-mpc-upgrade/Cargo.toml @@ -12,18 +12,24 @@ ark-bls12-381 = "0.5.0" ark-ff = "0.5.0" ark-ec = "0.5.0" ark-serialize = "0.5.0" +aws-config.workspace = true +aws-sdk-s3.workspace = true +aws-sdk-secretsmanager.workspace = true axum.workspace = true iris-mpc-common = { path = "../iris-mpc-common" } iris-mpc-store = { path = "../iris-mpc-store" } clap = { workspace = true, features = ["env"] } eyre.workspace = true bytemuck.workspace = true +base64.workspace = true serde.workspace = true +serde_json.workspace = true serde-big-array = "0.5" tracing.workspace = true itertools.workspace = true rand.workspace = true rand_chacha = "0.3" +sqlx.workspace = true tokio.workspace = true tracing-subscriber.workspace = true @@ -36,11 +42,17 @@ prost = "0.13.3" sha2.workspace = true thiserror.workspace = true blake3 = "1.8.2" +futures.workspace = true +tokio-util.workspace = true [dev-dependencies] criterion = "0.5" rayon = "1.10.0" +dotenvy.workspace = true + +[features] +db_dependent = [] [build-dependencies] diff --git a/iris-mpc-upgrade/src/config.rs b/iris-mpc-upgrade/src/config.rs index 3444eb3011..82ad7b347d 100644 --- a/iris-mpc-upgrade/src/config.rs +++ b/iris-mpc-upgrade/src/config.rs @@ -240,6 +240,7 @@ pub enum ReRandomizeDbSubCommand { KeyGen(KeyGenConfig), RerandomizeDb(ReRandomizeConfig), RerandomizeCheck(ReRandomizeCheckConfig), + RerandomizeContinuous(RerandomizeContinuousConfig), } #[derive(Args)] @@ -341,3 +342,36 @@ pub struct ReRandomizeCheckConfig { #[clap(long, env = "NEW_SCHEMA_NAME_PARTY_2")] pub new_schema_name_party_2: String, } + 
+#[derive(Args, Debug)] +pub struct RerandomizeContinuousConfig { + #[clap(long, env = "PARTY_ID")] + pub party_id: u8, + + #[clap(long, env = "DB_URL")] + pub db_url: String, + + #[clap(long, env = "ENVIRONMENT")] + pub env: String, + + #[clap(long, env = "RERAND_S3_BUCKET")] + pub s3_bucket: String, + + #[clap(long, env = "SCHEMA_NAME")] + pub schema_name: String, + + #[clap(long, default_value = "10000", env = "CHUNK_SIZE")] + pub chunk_size: u64, + + #[clap(long, default_value = "5", env = "CHUNK_DELAY_SECS")] + pub chunk_delay_secs: u64, + + #[clap(long, default_value = "0", env = "SAFETY_BUFFER_IDS")] + pub safety_buffer_ids: u64, + + #[clap(long, default_value = "5000", env = "S3_POLL_INTERVAL_MS")] + pub s3_poll_interval_ms: u64, + + #[clap(long, default_value = "3000", env = "HEALTHCHECK_PORT")] + pub healthcheck_port: usize, +} diff --git a/iris-mpc-upgrade/src/continuous_rerand.rs b/iris-mpc-upgrade/src/continuous_rerand.rs new file mode 100644 index 0000000000..c948b5cd86 --- /dev/null +++ b/iris-mpc-upgrade/src/continuous_rerand.rs @@ -0,0 +1,295 @@ +use aws_sdk_s3::Client as S3Client; +use aws_sdk_secretsmanager::Client as SecretsManagerClient; +use bytemuck::cast_slice; +use eyre::Result; +use futures::TryStreamExt; +use iris_mpc_store::rerand::{ + apply_staging_chunk, ensure_staging_schema, get_rerand_progress, insert_staging_irises, + set_all_confirmed, set_staging_written, staging_schema_name, upsert_rerand_progress, + StagingIrisEntry, RERAND_APPLY_LOCK, +}; +use iris_mpc_store::Store; +use sqlx::PgPool; +use std::time::Duration; +use tokio::time::sleep; +use tokio_util::sync::CancellationToken; + +use crate::config::RerandomizeContinuousConfig; +use crate::epoch; +use crate::rerandomization::randomize_iris; +use crate::s3_coordination::{self, Manifest}; + +/// Run the continuous rerandomization loop. +/// +/// If `cancel` is provided, the loop checks for cancellation between chunk +/// stages and exits cleanly with `Ok(())` when cancelled. 
Pass `None` for +/// production use where the loop runs until the process is killed. +pub async fn run_continuous_rerand( + config: &RerandomizeContinuousConfig, + s3: &S3Client, + sm: &SecretsManagerClient, + store: &Store, + cancel: Option<&CancellationToken>, +) -> Result<()> { + let pool = &store.pool; + let staging_schema = staging_schema_name(&store.schema_name); + let poll_interval = Duration::from_millis(config.s3_poll_interval_ms); + let chunk_delay = Duration::from_secs(config.chunk_delay_secs); + + ensure_staging_schema(pool, &staging_schema).await?; + tracing::info!("Staging schema ensured: {}", staging_schema); + + loop { + if is_cancelled(cancel) { + return Ok(()); + } + + let active_epoch = epoch::determine_active_epoch(s3, &config.s3_bucket).await?; + tracing::info!("Active epoch: {}", active_epoch); + + let shared_secret = epoch::derive_shared_secret( + sm, + s3, + &config.s3_bucket, + &config.env, + active_epoch, + config.party_id, + poll_interval, + ) + .await?; + + let manifest = + get_or_create_manifest(s3, store, config, active_epoch, poll_interval).await?; + tracing::info!( + "Epoch {} manifest: chunk_size={}, max_id_inclusive={}", + active_epoch, + manifest.chunk_size, + manifest.max_id_inclusive + ); + + let mut chunk_id: u32 = 0; + loop { + if is_cancelled(cancel) { + return Ok(()); + } + + if manifest.chunk_is_empty(chunk_id) { + break; + } + + let progress = + get_rerand_progress(pool, active_epoch as i32, chunk_id as i32).await?; + + if progress.as_ref().is_some_and(|p| p.live_applied) { + chunk_id += 1; + continue; + } + + upsert_rerand_progress(pool, active_epoch as i32, chunk_id as i32).await?; + + if !progress.as_ref().is_some_and(|p| p.staging_written) { + process_chunk_staging( + pool, + store, + &staging_schema, + &shared_secret, + config.party_id, + active_epoch, + chunk_id, + &manifest, + ) + .await?; + + set_staging_written(pool, active_epoch as i32, chunk_id as i32).await?; + + s3_coordination::upload_chunk_staged( + s3, + 
&config.s3_bucket,
+                    active_epoch,
+                    config.party_id,
+                    chunk_id,
+                )
+                .await?;
+                tracing::info!(
+                    "Epoch {} chunk {}: staging written, S3 marker uploaded",
+                    active_epoch,
+                    chunk_id
+                );
+            }
+
+            if is_cancelled(cancel) {
+                return Ok(());
+            }
+
+            if !progress.as_ref().is_some_and(|p| p.all_confirmed) {
+                s3_coordination::poll_chunk_staged_all(
+                    s3,
+                    &config.s3_bucket,
+                    active_epoch,
+                    chunk_id,
+                    poll_interval,
+                )
+                .await?;
+
+                set_all_confirmed(pool, active_epoch as i32, chunk_id as i32).await?;
+                tracing::info!(
+                    "Epoch {} chunk {}: all parties confirmed",
+                    active_epoch,
+                    chunk_id
+                );
+            }
+
+            if is_cancelled(cancel) {
+                return Ok(());
+            }
+
+            let mut lock_conn = pool.acquire().await?;
+            sqlx::query("SELECT pg_advisory_lock($1)")
+                .bind(RERAND_APPLY_LOCK)
+                .execute(&mut *lock_conn)
+                .await?;
+
+            // Capture the apply result instead of using `?` right away: the
+            // advisory lock is session-scoped, so the unlock must run even if
+            // the apply failed. Otherwise the pooled connection keeps the lock.
+            let apply_result =
+                apply_staging_chunk(pool, &staging_schema, active_epoch as i32, chunk_id as i32)
+                    .await;
+            let unlock_result = sqlx::query("SELECT pg_advisory_unlock($1)")
+                .bind(RERAND_APPLY_LOCK)
+                .execute(&mut *lock_conn)
+                .await;
+            drop(lock_conn);
+
+            // Report the apply failure first; an unlock failure is surfaced next.
+            let rows = apply_result?;
+            unlock_result?;
+            tracing::info!(
+                "Epoch {} chunk {}: applied to live DB ({} rows updated)",
+                active_epoch,
+                chunk_id,
+                rows
+            );
+
+            chunk_id += 1;
+
+            if chunk_delay > Duration::ZERO {
+                sleep(chunk_delay).await;
+            }
+        }
+
+        if chunk_id == 0 && chunk_delay > Duration::ZERO {
+            tracing::info!(
+                "Epoch {} is empty, sleeping to avoid spinning",
+                active_epoch
+            );
+            sleep(chunk_delay).await;
+        }
+
+        epoch::complete_epoch(
+            sm,
+            s3,
+            &config.s3_bucket,
+            &config.env,
+            active_epoch,
+            config.party_id,
+            poll_interval,
+        )
+        .await?;
+        tracing::info!("Epoch {} completed, moving to next epoch", active_epoch);
+    }
+}
+
+/// Returns `true` only when a cancellation token was supplied and it has been
+/// cancelled.
+fn is_cancelled(cancel: Option<&CancellationToken>) -> bool {
+    cancel.is_some_and(|c| c.is_cancelled())
+}
+
+async fn get_or_create_manifest(
+    s3: &S3Client,
+    store: &Store,
+    config: &RerandomizeContinuousConfig,
+    epoch: u32,
+    poll_interval: Duration,
+) -> Result<Manifest> {
+    if s3_coordination::manifest_exists(s3,
&config.s3_bucket, epoch).await? { + return s3_coordination::download_manifest(s3, &config.s3_bucket, epoch, poll_interval) + .await; + } + + if config.party_id == 0 { + let local_max = store.get_max_serial_id().await? as u64; + s3_coordination::upload_max_id(s3, &config.s3_bucket, epoch, 0, local_max).await?; + + let all_max_ids = + s3_coordination::download_all_max_ids(s3, &config.s3_bucket, epoch, poll_interval) + .await?; + let min_max = *all_max_ids.iter().min().unwrap(); + let max_id_inclusive = min_max.saturating_sub(config.safety_buffer_ids); + + let manifest = Manifest { + epoch, + chunk_size: config.chunk_size, + max_id_inclusive, + }; + s3_coordination::upload_manifest(s3, &config.s3_bucket, epoch, &manifest).await?; + tracing::info!( + "Epoch {}: manifest created (max_id_inclusive={}, chunk_size={})", + epoch, + max_id_inclusive, + config.chunk_size + ); + Ok(manifest) + } else { + let local_max = store.get_max_serial_id().await? as u64; + s3_coordination::upload_max_id( + s3, + &config.s3_bucket, + epoch, + config.party_id, + local_max, + ) + .await?; + + s3_coordination::download_manifest(s3, &config.s3_bucket, epoch, poll_interval).await + } +} + +#[allow(clippy::too_many_arguments)] +async fn process_chunk_staging( + pool: &PgPool, + store: &Store, + staging_schema: &str, + shared_secret: &[u8; 32], + party_id: u8, + epoch: u32, + chunk_id: u32, + manifest: &Manifest, +) -> Result<()> { + let (start, end) = manifest.chunk_range(chunk_id); + + let entries: Vec<_> = store + .stream_irises_in_range(start..end) + .try_collect() + .await?; + + let staging_entries: Vec = entries + .into_iter() + .map(|iris| { + let version_id = iris.version_id(); + let iris_id = iris.id(); + let (_, lc, lm, rc, rm) = randomize_iris(iris, shared_secret, party_id as usize); + StagingIrisEntry { + epoch: epoch as i32, + id: iris_id, + chunk_id: chunk_id as i32, + left_code: cast_slice::(&lc.coefs).to_vec(), + left_mask: cast_slice::(&lm.coefs).to_vec(), + right_code: 
cast_slice::(&rc.coefs).to_vec(), + right_mask: cast_slice::(&rm.coefs).to_vec(), + original_version_id: version_id, + rerand_epoch: (epoch + 1) as i32, + } + }) + .collect(); + + const BATCH_SIZE: usize = 500; + for batch in staging_entries.chunks(BATCH_SIZE) { + insert_staging_irises(pool, staging_schema, batch).await?; + } + + Ok(()) +} diff --git a/iris-mpc-upgrade/src/epoch.rs b/iris-mpc-upgrade/src/epoch.rs new file mode 100644 index 0000000000..1cd79da13d --- /dev/null +++ b/iris-mpc-upgrade/src/epoch.rs @@ -0,0 +1,223 @@ +use aws_sdk_s3::Client as S3Client; +use aws_sdk_secretsmanager::Client as SecretsManagerClient; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use eyre::{eyre, Result}; +use std::time::Duration; + +use crate::s3_coordination; +use crate::tripartite_dh; + +fn secret_id(env: &str, epoch: u32, party_id: u8) -> String { + format!( + "{}/iris-mpc-db-rerandomization/epoch-{}/private-key-party-{}", + env, epoch, party_id + ) +} + +/// Check if a private key for this epoch already exists in Secrets Manager. 
+async fn load_private_key_from_sm( + sm: &SecretsManagerClient, + env: &str, + epoch: u32, + party_id: u8, +) -> Result> { + let sid = secret_id(env, epoch, party_id); + match sm + .get_secret_value() + .secret_id(&sid) + .version_stage("AWSCURRENT") + .send() + .await + { + Ok(output) => { + let b64 = output + .secret_string() + .ok_or_else(|| eyre!("Secret {} has no string value", sid))?; + let bytes = STANDARD.decode(b64)?; + let key = tripartite_dh::PrivateKey::deserialize(&bytes) + .map_err(|e| eyre!("Failed to deserialize private key from SM: {:?}", e))?; + Ok(Some(key)) + } + Err(e) => { + let svc = e.into_service_error(); + if svc.is_resource_not_found_exception() { + Ok(None) + } else { + Err(eyre!("SM GetSecretValue failed for {}: {}", sid, svc)) + } + } + } +} + +async fn save_private_key_to_sm( + sm: &SecretsManagerClient, + env: &str, + epoch: u32, + party_id: u8, + key: &tripartite_dh::PrivateKey, +) -> Result<()> { + let sid = secret_id(env, epoch, party_id); + let b64 = STANDARD.encode(key.serialize()); + + match sm + .create_secret() + .name(&sid) + .secret_string(&b64) + .send() + .await + { + Ok(_) => Ok(()), + Err(e) => { + let svc = e.into_service_error(); + if svc.is_resource_exists_exception() { + sm.put_secret_value() + .secret_id(&sid) + .secret_string(&b64) + .send() + .await + .map_err(|e| eyre!("SM PutSecretValue failed for {}: {}", sid, e))?; + Ok(()) + } else { + Err(eyre!("SM CreateSecret failed for {}: {}", sid, svc)) + } + } + } +} + +async fn delete_private_key_from_sm( + sm: &SecretsManagerClient, + env: &str, + epoch: u32, + party_id: u8, +) -> Result<()> { + let sid = secret_id(env, epoch, party_id); + sm.delete_secret() + .secret_id(&sid) + .force_delete_without_recovery(true) + .send() + .await + .map_err(|e| eyre!("SM DeleteSecret failed for {}: {}", sid, e))?; + tracing::info!("Deleted epoch {} private key from SM", epoch); + Ok(()) +} + +/// Idempotent key generation for an epoch. +/// +/// 1. 
Check SM for existing private key +/// 2. If found: load it, derive public key, re-upload to S3 (covers crash between SM write and S3 upload) +/// 3. If not found: generate new keypair, write to SM first, then upload public key to S3 +pub async fn idempotent_keygen( + sm: &SecretsManagerClient, + s3: &S3Client, + bucket: &str, + env: &str, + epoch: u32, + party_id: u8, +) -> Result { + if let Some(existing) = load_private_key_from_sm(sm, env, epoch, party_id).await? { + tracing::info!( + "Epoch {}: private key found in SM, re-uploading public key to S3", + epoch + ); + let public_key = existing.public_key(); + let pk_b64 = STANDARD.encode(public_key.serialize()); + s3_coordination::upload_public_key(s3, bucket, epoch, party_id, &pk_b64).await?; + return Ok(existing); + } + + tracing::info!( + "Epoch {}: generating fresh BLS12-381 keypair for party {}", + epoch, + party_id + ); + let mut rng = rand::rngs::OsRng; + let private_key = tripartite_dh::PrivateKey::random(&mut rng); + + save_private_key_to_sm(sm, env, epoch, party_id, &private_key).await?; + + let public_key = private_key.public_key(); + let pk_b64 = STANDARD.encode(public_key.serialize()); + s3_coordination::upload_public_key(s3, bucket, epoch, party_id, &pk_b64).await?; + + Ok(private_key) +} + +/// Derive the shared secret for an epoch: keygen + download peer keys + BLS pairing. +pub async fn derive_shared_secret( + sm: &SecretsManagerClient, + s3: &S3Client, + bucket: &str, + env: &str, + epoch: u32, + party_id: u8, + poll_interval: Duration, +) -> Result<[u8; 32]> { + let private_key = idempotent_keygen(sm, s3, bucket, env, epoch, party_id).await?; + + let next_id = (party_id + 1) % 3; + let prev_id = (party_id + 2) % 3; + + let pk_next_b64 = + s3_coordination::download_public_key_for_party(s3, bucket, epoch, next_id, poll_interval) + .await?; + let pk_next = tripartite_dh::PublicKeys::deserialize(&STANDARD.decode(&pk_next_b64)?) 
+ .map_err(|e| eyre!("Failed to deserialize public key for party {}: {:?}", next_id, e))?; + + let pk_prev_b64 = + s3_coordination::download_public_key_for_party(s3, bucket, epoch, prev_id, poll_interval) + .await?; + let pk_prev = tripartite_dh::PublicKeys::deserialize(&STANDARD.decode(&pk_prev_b64)?) + .map_err(|e| eyre!("Failed to deserialize public key for party {}: {:?}", prev_id, e))?; + + let shared_secret = private_key.derive_shared_secret(&pk_next, &pk_prev); + let hash = blake3::hash(&shared_secret); + tracing::info!( + "Epoch {}: derived shared secret (blake3 fingerprint: {})", + epoch, + hash.to_hex() + ); + Ok(shared_secret) +} + +/// Determine the active epoch by scanning S3 for the highest epoch with a +/// manifest but without all three `complete` markers. Falls back to 0 if +/// no epochs exist. +pub async fn determine_active_epoch(s3: &S3Client, bucket: &str) -> Result { + let mut epoch: u32 = 0; + loop { + if !s3_coordination::manifest_exists(s3, bucket, epoch).await? { + break; + } + if s3_coordination::all_parties_complete(s3, bucket, epoch).await? { + epoch += 1; + continue; + } + return Ok(epoch); + } + Ok(epoch) +} + +/// Upload completion marker, poll for all three, then delete the epoch key from SM. 
+pub async fn complete_epoch( + sm: &SecretsManagerClient, + s3: &S3Client, + bucket: &str, + env: &str, + epoch: u32, + party_id: u8, + poll_interval: Duration, +) -> Result<()> { + s3_coordination::upload_epoch_complete(s3, bucket, epoch, party_id).await?; + tracing::info!( + "Epoch {}: uploaded completion marker for party {}", + epoch, + party_id + ); + + s3_coordination::poll_epoch_complete_all(s3, bucket, epoch, poll_interval).await?; + tracing::info!("Epoch {}: all parties completed", epoch); + + delete_private_key_from_sm(sm, env, epoch, party_id).await?; + Ok(()) +} diff --git a/iris-mpc-upgrade/src/lib.rs b/iris-mpc-upgrade/src/lib.rs index 2ab2e2018d..7d32241d74 100644 --- a/iris-mpc-upgrade/src/lib.rs +++ b/iris-mpc-upgrade/src/lib.rs @@ -8,10 +8,13 @@ use std::{ }; pub mod config; +pub mod continuous_rerand; +pub mod epoch; pub mod packets; pub mod proto; pub mod rerandomization; pub mod reshare; +pub mod s3_coordination; pub mod tripartite_dh; pub mod utils; diff --git a/iris-mpc-upgrade/src/rerandomization.rs b/iris-mpc-upgrade/src/rerandomization.rs index 3b1e93ea3a..40db34136c 100644 --- a/iris-mpc-upgrade/src/rerandomization.rs +++ b/iris-mpc-upgrade/src/rerandomization.rs @@ -1,10 +1,12 @@ use std::io::Read; use iris_mpc_common::{ - galois::degree4::{basis::Monomial, GaloisRingElement}, + galois::degree4::{basis::Monomial, GaloisRingElement, ShamirGaloisRingShare}, galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare}, + id::PartyID, }; use iris_mpc_store::DbStoredIris; +use itertools::Itertools; pub fn randomize_iris( iris: DbStoredIris, @@ -87,6 +89,61 @@ fn randomize_galois_ring_coefs(coefs: &mut [u16], xof: &mut blake3::OutputReader } } +/// Reconstruct the plaintext from 3 Shamir shares using Lagrange interpolation. +/// Verifies consistency by reconstructing from all 3 pairs (0-1, 1-2, 0-2) and +/// asserting they agree. 
+pub fn reconstruct_shares(share0: &[u16], share1: &[u16], share2: &[u16]) -> Vec { + let lag_01 = + ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID0, PartyID::ID1); + let lag_10 = + ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID1, PartyID::ID0); + let lag_02 = + ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID0, PartyID::ID2); + let lag_20 = + ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID2, PartyID::ID0); + let lag_12 = + ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID1, PartyID::ID2); + let lag_21 = + ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID2, PartyID::ID1); + + assert!(share0.len() == share1.len() && share1.len() == share2.len()); + + let recon01 = share0 + .chunks_exact(4) + .zip_eq(share1.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + let c = a * lag_01 + b * lag_10; + c.coefs + }) + .collect_vec(); + let recon12 = share1 + .chunks_exact(4) + .zip_eq(share2.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + let c = a * lag_12 + b * lag_21; + c.coefs + }) + .collect_vec(); + let recon02 = share0 + .chunks_exact(4) + .zip_eq(share2.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + let c = a * lag_02 + b * lag_20; + c.coefs + }) + .collect_vec(); + + assert_eq!(recon01, recon12); + assert_eq!(recon01, recon02); + recon01 +} + #[cfg(test)] mod tests { use iris_mpc_common::{ diff --git a/iris-mpc-upgrade/src/s3_coordination.rs b/iris-mpc-upgrade/src/s3_coordination.rs new file mode 100644 index 0000000000..f8ccfe4713 --- /dev/null +++ b/iris-mpc-upgrade/src/s3_coordination.rs @@ -0,0 +1,279 @@ 
+use aws_sdk_s3::Client as S3Client; +use eyre::{eyre, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tokio::time::sleep; + +const NUM_PARTIES: u8 = 3; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Manifest { + pub epoch: u32, + pub chunk_size: u64, + pub max_id_inclusive: u64, +} + +impl Manifest { + pub fn num_chunks(&self) -> u32 { + if self.max_id_inclusive == 0 { + return 0; + } + self.max_id_inclusive.div_ceil(self.chunk_size) as u32 + } + + /// Returns (start_id_inclusive, end_id_exclusive) for a given chunk_id. + /// IDs are 1-based. + pub fn chunk_range(&self, chunk_id: u32) -> (u64, u64) { + let start = 1 + (chunk_id as u64) * self.chunk_size; + let end = std::cmp::min(start + self.chunk_size, self.max_id_inclusive + 1); + (start, end) + } + + pub fn chunk_is_empty(&self, chunk_id: u32) -> bool { + let (start, end) = self.chunk_range(chunk_id); + start >= end + } +} + +fn epoch_party_prefix(epoch: u32, party: u8) -> String { + format!("rerand/epoch-{}/party-{}", epoch, party) +} + +pub async fn upload_marker( + s3: &S3Client, + bucket: &str, + key: &str, + body: Vec, +) -> Result<()> { + s3.put_object() + .bucket(bucket) + .key(key) + .body(body.into()) + .send() + .await + .map_err(|e| eyre!("S3 PutObject failed for key {}: {}", key, e))?; + Ok(()) +} + +pub async fn marker_exists(s3: &S3Client, bucket: &str, key: &str) -> Result { + match s3.head_object().bucket(bucket).key(key).send().await { + Ok(_) => Ok(true), + Err(e) => { + let svc_err = e.into_service_error(); + if svc_err.is_not_found() { + Ok(false) + } else { + Err(eyre!("S3 HeadObject failed for key {}: {}", key, svc_err)) + } + } + } +} + +pub async fn download_marker(s3: &S3Client, bucket: &str, key: &str) -> Result> { + let resp = s3 + .get_object() + .bucket(bucket) + .key(key) + .send() + .await + .map_err(|e| eyre!("S3 GetObject failed for key {}: {}", key, e))?; + let bytes = resp + .body + .collect() + .await + .map_err(|e| eyre!("Failed 
to read S3 body for key {}: {}", key, e))?; + Ok(bytes.to_vec()) +} + +pub async fn poll_until_marker_exists( + s3: &S3Client, + bucket: &str, + key: &str, + poll_interval: Duration, +) -> Result<()> { + loop { + if marker_exists(s3, bucket, key).await? { + return Ok(()); + } + tracing::debug!("Waiting for S3 marker: {}", key); + sleep(poll_interval).await; + } +} + +/// Polls until all three parties have uploaded a given marker suffix for an epoch. +pub async fn poll_until_all_parties_marker( + s3: &S3Client, + bucket: &str, + epoch: u32, + marker_suffix: &str, + poll_interval: Duration, +) -> Result<()> { + loop { + let mut all_present = true; + for party in 0..NUM_PARTIES { + let key = format!("{}/{}", epoch_party_prefix(epoch, party), marker_suffix); + if !marker_exists(s3, bucket, &key).await? { + all_present = false; + break; + } + } + if all_present { + return Ok(()); + } + tracing::debug!( + "Waiting for all parties' {} markers for epoch {}", + marker_suffix, + epoch + ); + sleep(poll_interval).await; + } +} + +// ---- Public key ---- + +pub async fn upload_public_key( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + key_b64: &str, +) -> Result<()> { + let key = format!("{}/public-key", epoch_party_prefix(epoch, party)); + upload_marker(s3, bucket, &key, key_b64.as_bytes().to_vec()).await +} + +pub async fn download_public_key_for_party( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + poll_interval: Duration, +) -> Result { + let key = format!("{}/public-key", epoch_party_prefix(epoch, party)); + poll_until_marker_exists(s3, bucket, &key, poll_interval).await?; + let bytes = download_marker(s3, bucket, &key).await?; + Ok(String::from_utf8(bytes)?) 
+} + +// ---- Max ID watermark ---- + +pub async fn upload_max_id( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + max_id: u64, +) -> Result<()> { + let key = format!("{}/max-id", epoch_party_prefix(epoch, party)); + upload_marker(s3, bucket, &key, max_id.to_string().into_bytes()).await +} + +pub async fn download_all_max_ids( + s3: &S3Client, + bucket: &str, + epoch: u32, + poll_interval: Duration, +) -> Result<[u64; 3]> { + let mut ids = [0u64; 3]; + for party in 0..NUM_PARTIES { + let key = format!("{}/max-id", epoch_party_prefix(epoch, party)); + poll_until_marker_exists(s3, bucket, &key, poll_interval).await?; + let bytes = download_marker(s3, bucket, &key).await?; + let s = String::from_utf8(bytes)?; + ids[party as usize] = s + .trim() + .parse() + .map_err(|e| eyre!("Failed to parse max-id from party {}: {}", party, e))?; + } + Ok(ids) +} + +// ---- Manifest ---- + +pub async fn upload_manifest( + s3: &S3Client, + bucket: &str, + epoch: u32, + manifest: &Manifest, +) -> Result<()> { + let key = format!("{}/manifest.json", epoch_party_prefix(epoch, 0)); + let body = serde_json::to_vec(manifest)?; + upload_marker(s3, bucket, &key, body).await +} + +pub async fn download_manifest( + s3: &S3Client, + bucket: &str, + epoch: u32, + poll_interval: Duration, +) -> Result { + let key = format!("{}/manifest.json", epoch_party_prefix(epoch, 0)); + poll_until_marker_exists(s3, bucket, &key, poll_interval).await?; + let bytes = download_marker(s3, bucket, &key).await?; + let manifest: Manifest = serde_json::from_slice(&bytes)?; + Ok(manifest) +} + +pub async fn manifest_exists(s3: &S3Client, bucket: &str, epoch: u32) -> Result { + let key = format!("{}/manifest.json", epoch_party_prefix(epoch, 0)); + marker_exists(s3, bucket, &key).await +} + +// ---- Chunk staged markers ---- + +pub async fn upload_chunk_staged( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + chunk_id: u32, +) -> Result<()> { + let key = format!( + "{}/chunk-{}/staged", + 
epoch_party_prefix(epoch, party), + chunk_id + ); + upload_marker(s3, bucket, &key, b"ok".to_vec()).await +} + +pub async fn poll_chunk_staged_all( + s3: &S3Client, + bucket: &str, + epoch: u32, + chunk_id: u32, + poll_interval: Duration, +) -> Result<()> { + let suffix = format!("chunk-{}/staged", chunk_id); + poll_until_all_parties_marker(s3, bucket, epoch, &suffix, poll_interval).await +} + +// ---- Epoch completion ---- + +pub async fn upload_epoch_complete( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, +) -> Result<()> { + let key = format!("{}/complete", epoch_party_prefix(epoch, party)); + upload_marker(s3, bucket, &key, b"done".to_vec()).await +} + +pub async fn all_parties_complete(s3: &S3Client, bucket: &str, epoch: u32) -> Result { + for party in 0..NUM_PARTIES { + let key = format!("{}/complete", epoch_party_prefix(epoch, party)); + if !marker_exists(s3, bucket, &key).await? { + return Ok(false); + } + } + Ok(true) +} + +pub async fn poll_epoch_complete_all( + s3: &S3Client, + bucket: &str, + epoch: u32, + poll_interval: Duration, +) -> Result<()> { + poll_until_all_parties_marker(s3, bucket, epoch, "complete", poll_interval).await +} diff --git a/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs b/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs new file mode 100644 index 0000000000..1b92b96571 --- /dev/null +++ b/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs @@ -0,0 +1,312 @@ +#![cfg(feature = "db_dependent")] + +mod test_utils; + +use eyre::Result; +use std::sync::Mutex; +use test_utils::*; + +const STACK_SIZE: usize = 16 * 1024 * 1024; + +/// Tests share 3 Postgres instances and a global advisory lock constant, so +/// they must run sequentially. This mutex enforces that even without +/// `--test-threads=1`. 
+static SERIAL: Mutex<()> = Mutex::new(()); + +fn run_async(f: impl std::future::Future> + Send + 'static) { + let _guard = SERIAL.lock().unwrap_or_else(|e| e.into_inner()); + let result = std::thread::Builder::new() + .stack_size(STACK_SIZE) + .name("e2e".into()) + .spawn(move || { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .thread_stack_size(STACK_SIZE) + .enable_all() + .build() + .unwrap() + .block_on(f) + }) + .unwrap() + .join() + .unwrap(); + result.unwrap(); +} + +// ============================================================================ +// Phase 1: Clean epoch -- run one full epoch, verify crypto correctness +// ============================================================================ + +#[test] +fn phase1_clean_epoch() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 1] Clean epoch..."); + + let (h, t) = env.spawn_all(); + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness).await?; + assert!(ep >= 1, "Expected rerand_epoch >= 1, got {}", ep); + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + println!("[phase 1] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 2: Kill-and-resume -- kill mid-epoch, restart, verify recovery +// ============================================================================ + +#[test] +fn phase2_kill_and_resume() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 2] Kill-and-resume..."); + + // Run epoch 0, let 2 chunks stage, then kill + let (h, t) = env.spawn_all(); + wait_chunks_staged(&env.harness, 0, 2).await?; + println!("[phase 2] killing after 2 chunks staged"); + stop_all(t, h).await; + + // Restart -- should resume from where it left off + 
println!("[phase 2] restarting..."); + let (h, t) = env.spawn_all(); + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness).await?; + assert!(ep >= 1); + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + println!("[phase 2] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 3: Concurrent modifications -- bump version_id mid-epoch, verify +// optimistic lock skips those rows +// ============================================================================ + +#[test] +fn phase3_concurrent_modifications() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + let modified_ids: Vec = vec![5, 10, 15]; + println!("[phase 3] Concurrent modifications..."); + + let (h, t) = env.spawn_all(); + wait_chunks_staged(&env.harness, 0, 1).await?; + + // Bump version_id on a few rows (simulates a reauth) + for &id in &modified_ids { + for party in &env.harness.parties { + sqlx::query("UPDATE irises SET left_code = left_code WHERE id = $1") + .bind(id) + .execute(&party.store.pool) + .await?; + } + } + println!("[phase 3] bumped version_id on {:?}", modified_ids); + + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness).await?; + assert!(ep >= 1); + verify_fingerprints(&env.harness, &env.fingerprints, &modified_ids).await?; + println!("[phase 3] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 4: Server restart during rerand -- simulate main server startup while +// rerand is running, verify advisory lock serializes access +// ============================================================================ + +#[test] +fn phase4_server_restart_during_rerand() { + run_async(async { + 
let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 4] Server restart during rerand..."); + + let (h, t) = env.spawn_all(); + wait_chunks_staged(&env.harness, 0, 1).await?; + + for p in 0..NUM_PARTIES { + let r = simulate_server_startup(&env.harness, p).await; + println!("[phase 4] party {} server startup: {:?}", p, r.is_ok()); + } + + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness).await?; + assert!(ep >= 1); + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + println!("[phase 4] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 5: Staggered restart -- kill one party mid-epoch, restart it, verify +// it catches up and the epoch completes +// ============================================================================ + +#[test] +fn phase5_staggered_restart() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 5] Staggered restart..."); + + let (h, t) = env.spawn_all(); + wait_chunks_staged(&env.harness, 0, 2).await?; + + // Kill party 0 + println!("[phase 5] killing party 0 after 2 chunks"); + t[0].cancel(); + h[0].abort(); + + // Immediately restart party 0 + println!("[phase 5] restarting party 0"); + let (h0, t0) = env.spawn_rerand(0); + + wait_epoch_done(&env.harness, 0).await?; + + t0.cancel(); + h0.abort(); + let _ = h0.await; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness).await?; + assert!(ep >= 1); + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + println!("[phase 5] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 6: Multiple Epochs -- let the system run continuously across multiple +// 
epochs, verify seamless transition and correct rerandomization +// ============================================================================ + +#[test] +fn phase6_multiple_epochs() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 6] Multiple epochs..."); + + let (h, t) = env.spawn_all(); + + // Wait for epoch 0 to finish + wait_epoch_done(&env.harness, 0).await?; + println!("[phase 6] epoch 0 completed"); + + // The continuous rerand servers should automatically move to epoch 1 + wait_epoch_done(&env.harness, 1).await?; + println!("[phase 6] epoch 1 completed"); + + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness).await?; + assert!(ep >= 2, "Expected rerand_epoch >= 2, got {}", ep); + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + println!("[phase 6] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 7: Epoch boundary desync -- simulate epoch mismatch +// ============================================================================ + +#[test] +fn phase7_epoch_boundary_desync() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 7] Epoch boundary desync..."); + + // Setup the exact boundary desync state in DB manually to test catch-up logic + // P1 is on Epoch 0 (has max epoch 0) + // P0 and P2 are on Epoch 1 (have max epoch 1) + for p in 0..NUM_PARTIES { + let pool = &env.harness.parties[p].store.pool; + // Everyone completes Epoch 0 + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (0, 0, TRUE, TRUE, TRUE)") + .execute(pool).await.unwrap(); + } + + // P0 and P2 move to Epoch 1 + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (1, 0, TRUE, 
TRUE, FALSE)") + .execute(&env.harness.parties[0].store.pool).await.unwrap(); + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (1, 0, TRUE, TRUE, FALSE)") + .execute(&env.harness.parties[2].store.pool).await.unwrap(); + + // Now simulate P1 main server startup (P1 is behind on Epoch 0) + // Should catch up using safe_up_to = i32::MAX + let r1 = simulate_server_startup(&env.harness, 1).await; + assert!(r1.is_ok(), "P1 startup failed during epoch mismatch"); + + // Now simulate P0 main server startup (P0 is ahead on Epoch 1) + // Should catch up using safe_up_to = -1 (nobody confirmed Epoch 1 yet since P1 hasn't started it) + let r0 = simulate_server_startup(&env.harness, 0).await; + assert!(r0.is_ok(), "P0 startup failed during epoch mismatch"); + + println!("[phase 7] PASSED"); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 8: Disallow loading mismatched peers +// ============================================================================ + +#[test] +fn phase8_reject_desync() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 8] Reject desync..."); + + // Setup the exact boundary desync state in DB manually + // P1 is on Epoch 0 (has max epoch 0) + // P0 and P2 are on Epoch 2 (have max epoch 2) + // If a peer is *more than 1 epoch ahead*, we should panic/reject + for p in 0..NUM_PARTIES { + let pool = &env.harness.parties[p].store.pool; + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (0, 0, TRUE, TRUE, TRUE)") + .execute(pool).await.unwrap(); + } + + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (2, 0, TRUE, TRUE, FALSE)") + .execute(&env.harness.parties[0].store.pool).await.unwrap(); + sqlx::query("INSERT INTO 
rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (2, 0, TRUE, TRUE, FALSE)") + .execute(&env.harness.parties[2].store.pool).await.unwrap(); + + let r1 = simulate_server_startup(&env.harness, 1).await; + assert!(r1.is_err(), "P1 startup should have failed due to large epoch gap"); + + println!("[phase 8] PASSED"); + + env.teardown().await + }); +} diff --git a/iris-mpc-upgrade/tests/test_utils.rs b/iris-mpc-upgrade/tests/test_utils.rs new file mode 100644 index 0000000000..e23cd3869b --- /dev/null +++ b/iris-mpc-upgrade/tests/test_utils.rs @@ -0,0 +1,371 @@ +#![allow(dead_code)] + +use eyre::Result; +use iris_mpc_common::{ + config::CommonConfig, + galois_engine::degree4::FullGaloisRingIrisCodeShare, + helpers::sync::{SyncResult, SyncState}, + iris_db::iris::IrisCode, + postgres::{AccessMode, PostgresClient}, +}; +use iris_mpc_store::rerand::{self as rerand_store}; +use iris_mpc_store::{Store, StoredIrisRef}; +use iris_mpc_upgrade::config::RerandomizeContinuousConfig; +use iris_mpc_upgrade::continuous_rerand::run_continuous_rerand; +use iris_mpc_upgrade::rerandomization::reconstruct_shares; +use rand::{rngs::StdRng, SeedableRng}; +use std::collections::HashMap; +use std::time::Duration; +use tokio_util::sync::CancellationToken; + +pub const NUM_PARTIES: usize = 3; +pub const DB_SIZE: usize = 50; +pub const CHUNK_SIZE: u64 = 25; + +fn db_urls() -> Vec { + (0..3) + .map(|i| format!("postgres://postgres:postgres@localhost:{}", 6200 + i)) + .collect() +} + +pub struct PartyDb { + pub store: Store, + pub schema_name: String, +} + +pub struct TestHarness { + pub parties: Vec, +} + +impl TestHarness { + pub async fn new(db_urls: &[&str], schema_prefix: &str) -> Result { + let mut parties = Vec::new(); + for (i, url) in db_urls.iter().enumerate() { + let schema = format!("{}_{}", schema_prefix, i); + let pg = PostgresClient::new(url, &schema, AccessMode::ReadWrite).await?; + let store = Store::new(&pg).await?; + 
rerand_store::ensure_staging_schema( + &store.pool, + &rerand_store::staging_schema_name(&schema), + ) + .await?; + parties.push(PartyDb { + store, + schema_name: schema, + }); + } + Ok(Self { parties }) + } + + pub fn store(&self, party: usize) -> &Store { + &self.parties[party].store + } +} + +/// Full test environment: harness + AWS clients + unique prefix + unique S3 bucket. +pub struct TestEnv { + pub harness: TestHarness, + pub s3: aws_sdk_s3::Client, + pub sm: aws_sdk_secretsmanager::Client, + pub prefix: String, + pub bucket: String, + pub fingerprints: PlaintextFingerprints, +} + +impl TestEnv { + pub async fn setup() -> Result { + let id = rand::random::(); + let prefix = format!("SMPC_e2e_{}", id); + let bucket = format!("rerand-e2e-{}", id); + let urls = db_urls(); + let url_refs: Vec<&str> = urls.iter().map(|s| s.as_str()).collect(); + let harness = TestHarness::new(&url_refs, &prefix).await?; + + let sdk = aws_config::from_env().load().await; + let s3 = aws_sdk_s3::Client::new(&sdk); + let sm = aws_sdk_secretsmanager::Client::new(&sdk); + + s3.create_bucket().bucket(&bucket).send().await + .map_err(|e| eyre::eyre!("Failed to create bucket {}: {}", bucket, e))?; + + println!(" [setup] Seeding {} irises (prefix={}, bucket={})", DB_SIZE, prefix, bucket); + seed_three_party_db(&harness, DB_SIZE).await?; + let fingerprints = snapshot_all_fingerprints(&harness).await?; + + Ok(Self { harness, s3, sm, prefix, bucket, fingerprints }) + } + + pub async fn teardown(&self) -> Result<()> { + cleanup(&self.harness).await?; + // Delete all objects in the bucket then delete the bucket + let mut token = None; + loop { + let mut req = self.s3.list_objects_v2().bucket(&self.bucket); + if let Some(t) = &token { req = req.continuation_token(t); } + let resp = req.send().await?; + for obj in resp.contents() { + if let Some(key) = obj.key() { + self.s3.delete_object().bucket(&self.bucket).key(key).send().await?; + } + } + if resp.is_truncated() == Some(true) { + token = 
resp.next_continuation_token().map(|s| s.to_string()); + } else { break; } + } + let _ = self.s3.delete_bucket().bucket(&self.bucket).send().await; + Ok(()) + } + + pub fn make_config(&self, party_id: u8) -> RerandomizeContinuousConfig { + RerandomizeContinuousConfig { + party_id, + db_url: format!( + "postgres://postgres:postgres@localhost:{}", + 6200 + party_id as u16 + ), + env: "testing".to_string(), + s3_bucket: self.bucket.clone(), + schema_name: format!("{}_{}", self.prefix, party_id), + chunk_size: CHUNK_SIZE, + chunk_delay_secs: 0, + safety_buffer_ids: 0, + s3_poll_interval_ms: 200, + healthcheck_port: 3020 + party_id as usize, + } + } + + pub fn spawn_rerand(&self, party_id: u8) -> (tokio::task::JoinHandle>, CancellationToken) { + let config = self.make_config(party_id); + let s3 = self.s3.clone(); + let sm = self.sm.clone(); + let store = self.harness.store(party_id as usize).clone(); + let token = CancellationToken::new(); + let tc = token.clone(); + let h = tokio::spawn(async move { + run_continuous_rerand(&config, &s3, &sm, &store, Some(&tc)).await + }); + (h, token) + } + + pub fn spawn_all(&self) -> (Vec>>, Vec) { + let mut handles = Vec::new(); + let mut tokens = Vec::new(); + for p in 0u8..3 { + let (h, t) = self.spawn_rerand(p); + handles.push(h); + tokens.push(t); + } + (handles, tokens) + } +} + +pub async fn stop_all( + tokens: Vec, + handles: Vec>>, +) { + for t in &tokens { t.cancel(); } + for h in &handles { h.abort(); } + for h in handles { let _ = h.await; } +} + +// ---- DB seeding ---- + +pub async fn seed_three_party_db(harness: &TestHarness, count: usize) -> Result<()> { + let mut rng = StdRng::seed_from_u64(42); + for chunk_start in (1..=count).step_by(100) { + let chunk_end = std::cmp::min(chunk_start + 100, count + 1); + + struct S { id: i64, lc: Vec, lm: Vec, rc: Vec, rm: Vec } + + let mut party_data: Vec> = (0..NUM_PARTIES).map(|_| Vec::new()).collect(); + for serial_id in chunk_start..chunk_end { + let il = 
IrisCode::random_rng(&mut rng); + let ir = IrisCode::random_rng(&mut rng); + let [l0, l1, l2] = FullGaloisRingIrisCodeShare::encode_iris_code(&il, &mut rng); + let [r0, r1, r2] = FullGaloisRingIrisCodeShare::encode_iris_code(&ir, &mut rng); + for (pi, (left, right)) in [(l0, r0), (l1, r1), (l2, r2)].into_iter().enumerate() { + party_data[pi].push(S { + id: serial_id as i64, + lc: left.code.coefs.to_vec(), lm: left.mask.coefs.to_vec(), + rc: right.code.coefs.to_vec(), rm: right.mask.coefs.to_vec(), + }); + } + } + for (pi, shares) in party_data.iter().enumerate() { + let refs: Vec = shares.iter().map(|s| StoredIrisRef { + id: s.id, left_code: &s.lc, left_mask: &s.lm, + right_code: &s.rc, right_mask: &s.rm, + }).collect(); + let store = harness.store(pi); + let mut tx = store.tx().await?; + store.insert_irises_overriding(&mut tx, &refs).await?; + tx.commit().await?; + } + } + Ok(()) +} + +// ---- Fingerprint verification ---- + +/// blake3 hash of the concatenated reconstructed plaintext (left_code ++ left_mask +/// ++ right_code ++ right_mask) for every iris ID. +pub type PlaintextFingerprints = HashMap; + +/// Compute a fingerprint for every iris in the DB by reconstructing shares +/// from all 3 parties. 
+pub async fn snapshot_all_fingerprints(harness: &TestHarness) -> Result<PlaintextFingerprints> {
+    let ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM irises ORDER BY id")
+        .fetch_all(&harness.store(0).pool)
+        .await?;
+
+    let mut fps = PlaintextFingerprints::new();
+    for (id,) in ids {
+        let mut shares = Vec::new();
+        for party in 0..NUM_PARTIES {
+            shares.push(harness.store(party).get_iris_data_by_id(id).await?);
+        }
+        let mut hasher = blake3::Hasher::new();
+        let fields: Vec<[&[u16]; 3]> = vec![
+            [shares[0].left_code(), shares[1].left_code(), shares[2].left_code()],
+            [shares[0].left_mask(), shares[1].left_mask(), shares[2].left_mask()],
+            [shares[0].right_code(), shares[1].right_code(), shares[2].right_code()],
+            [shares[0].right_mask(), shares[1].right_mask(), shares[2].right_mask()],
+        ];
+        for [s0, s1, s2] in &fields {
+            let recon = reconstruct_shares(s0, s1, s2);
+            hasher.update(bytemuck::cast_slice::<u16, u8>(&recon));
+        }
+        fps.insert(id, *hasher.finalize().as_bytes());
+    }
+    Ok(fps)
+}
+
+/// Verify that current shares reconstruct to the same plaintexts as the
+/// snapshot. `skip_ids` are excluded (modified during test).
+pub async fn verify_fingerprints( + harness: &TestHarness, + expected: &PlaintextFingerprints, + skip_ids: &[i64], +) -> Result<()> { + let current = snapshot_all_fingerprints(harness).await?; + let mut checked = 0; + for (id, exp) in expected { + if skip_ids.contains(id) { + continue; + } + let cur = current + .get(id) + .unwrap_or_else(|| panic!("ID {} missing from current DB", id)); + assert_eq!(exp, cur, "Plaintext fingerprint mismatch for id {}", id); + checked += 1; + } + println!(" verified {}/{} iris fingerprints", checked, expected.len()); + Ok(()) +} + +// ---- Polling helpers ---- + +pub async fn wait_epoch_done(harness: &TestHarness, epoch: i32) -> Result<()> { + let deadline = tokio::time::Instant::now() + Duration::from_secs(120); + let start = std::time::Instant::now(); + let mut last_print = start; + loop { + if tokio::time::Instant::now() > deadline { + eyre::bail!("Timeout waiting for epoch {}", epoch); + } + let mut done = true; + let mut applied = [0usize; 3]; + for (i, party) in harness.parties.iter().enumerate() { + let rows: Vec<(bool,)> = sqlx::query_as( + "SELECT live_applied FROM rerand_progress WHERE epoch = $1", + ).bind(epoch).fetch_all(&party.store.pool).await?; + applied[i] = rows.iter().filter(|(a,)| *a).count(); + if rows.is_empty() || !rows.iter().all(|(a,)| *a) { done = false; } + } + if done { + println!(" epoch {} done in {:.1}s", epoch, start.elapsed().as_secs_f64()); + return Ok(()); + } + if last_print.elapsed() > Duration::from_secs(5) { + println!(" waiting epoch {}: applied {:?} ({:.0}s)", epoch, applied, start.elapsed().as_secs_f64()); + last_print = std::time::Instant::now(); + } + tokio::time::sleep(Duration::from_millis(500)).await; + } +} + +pub async fn wait_chunks_staged(harness: &TestHarness, epoch: i32, n: i32) -> Result<()> { + let deadline = tokio::time::Instant::now() + Duration::from_secs(60); + let start = std::time::Instant::now(); + loop { + if tokio::time::Instant::now() > deadline { + 
eyre::bail!("Timeout waiting for {} chunks staged in epoch {}", n, epoch); + } + let mut max_count = 0i64; + for party in &harness.parties { + let (count,): (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM rerand_progress WHERE epoch = $1 AND staging_written = TRUE", + ).bind(epoch).fetch_one(&party.store.pool).await?; + max_count = max_count.max(count); + } + if max_count >= n as i64 { + println!(" {} chunks staged for epoch {} in {:.1}s", max_count, epoch, start.elapsed().as_secs_f64()); + return Ok(()); + } + tokio::time::sleep(Duration::from_millis(200)).await; + } +} + +// ---- Server simulation ---- + +pub async fn simulate_server_startup(harness: &TestHarness, party: usize) -> Result<()> { + let sync_result = build_test_sync_result(harness, party).await?; + let pool = &harness.parties[party].store.pool; + let schema = &harness.parties[party].schema_name; + let lock_conn = rerand_store::rerand_catchup_and_lock(pool, schema, &sync_result).await?; + let _count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM irises").fetch_one(pool).await?; + rerand_store::release_rerand_lock(lock_conn).await?; + Ok(()) +} + +async fn build_test_sync_result(harness: &TestHarness, party: usize) -> Result { + let mut all_states = Vec::new(); + for p in &harness.parties { + let rerand_state = rerand_store::build_rerand_sync_state(&p.store.pool).await.ok(); + all_states.push(SyncState { + db_len: p.store.count_irises().await? 
as u64, + modifications: vec![], + next_sns_sequence_num: None, + common_config: CommonConfig::default(), + rerand_state, + }); + } + let my_state = all_states[party].clone(); + Ok(SyncResult { my_state, all_states }) +} + +pub async fn assert_consistent_rerand_epoch(harness: &TestHarness) -> Result { + let mut all: Vec> = Vec::new(); + for party in &harness.parties { + all.push(sqlx::query_as("SELECT id, rerand_epoch FROM irises ORDER BY id") + .fetch_all(&party.store.pool).await?); + } + assert_eq!(all[0].len(), all[1].len()); + assert_eq!(all[1].len(), all[2].len()); + for i in 0..all[0].len() { + assert_eq!(all[0][i].1, all[1][i].1, "epoch mismatch id {} p0 vs p1", all[0][i].0); + assert_eq!(all[0][i].1, all[2][i].1, "epoch mismatch id {} p0 vs p2", all[0][i].0); + } + Ok(all[0].first().map(|(_, e)| *e).unwrap_or(0)) +} + +async fn cleanup(harness: &TestHarness) -> Result<()> { + for party in &harness.parties { + let staging = rerand_store::staging_schema_name(&party.schema_name); + let _ = sqlx::query(&format!(r#"DROP SCHEMA IF EXISTS "{}" CASCADE"#, staging)) + .execute(&party.store.pool).await; + let _ = sqlx::query(&format!(r#"DROP SCHEMA IF EXISTS "{}" CASCADE"#, party.schema_name)) + .execute(&party.store.pool).await; + } + Ok(()) +} diff --git a/iris-mpc/src/server/mod.rs b/iris-mpc/src/server/mod.rs index b33e7637f8..1a28435eec 100644 --- a/iris-mpc/src/server/mod.rs +++ b/iris-mpc/src/server/mod.rs @@ -37,6 +37,7 @@ use iris_mpc_cpu::execution::hawk_main::{ use iris_mpc_cpu::hawkers::aby3::aby3_store::Aby3Store; use iris_mpc_cpu::hnsw::graph::graph_store::GraphPg; use iris_mpc_store::loader::load_iris_db; +use iris_mpc_store::rerand::{self as rerand_store}; use iris_mpc_store::Store; use pprof::protos::Message; use pprof::ProfilerGuardBuilder; @@ -138,6 +139,13 @@ pub async fn server_main(config: Config) -> Result<()> { sync_sqs_queues(&config, &sync_result, &aws_clients).await?; + let rerand_lock_conn = rerand_store::rerand_catchup_and_lock( + 
&iris_store.pool, + &iris_store.schema_name, + &sync_result, + ) + .await?; + if shutdown_handler.is_shutting_down() { tracing::warn!("Shutting down has been triggered"); return Ok(()); @@ -166,6 +174,8 @@ pub async fn server_main(config: Config) -> Result<()> { ) .await?; + rerand_store::release_rerand_lock(rerand_lock_conn).await?; + background_tasks.check_tasks(); let tx_results = start_results_thread( @@ -387,11 +397,14 @@ async fn build_sync_state( tracing::info!("Database store length is: {}", db_len); + let rerand_state = rerand_store::build_rerand_sync_state(&store.pool).await.ok(); + Ok(SyncState { db_len, modifications, next_sns_sequence_num, common_config, + rerand_state, }) } diff --git a/migrations/20260226000001_add_rerand_epoch.down.sql b/migrations/20260226000001_add_rerand_epoch.down.sql new file mode 100644 index 0000000000..97fde78bcd --- /dev/null +++ b/migrations/20260226000001_add_rerand_epoch.down.sql @@ -0,0 +1,14 @@ +ALTER TABLE irises DROP COLUMN IF EXISTS rerand_epoch; + +CREATE OR REPLACE FUNCTION increment_version_id() +RETURNS TRIGGER AS $$ +BEGIN + IF (OLD.left_code IS DISTINCT FROM NEW.left_code OR + OLD.left_mask IS DISTINCT FROM NEW.left_mask OR + OLD.right_code IS DISTINCT FROM NEW.right_code OR + OLD.right_mask IS DISTINCT FROM NEW.right_mask) THEN + NEW.version_id = COALESCE(OLD.version_id, 0) + 1; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; diff --git a/migrations/20260226000001_add_rerand_epoch.up.sql b/migrations/20260226000001_add_rerand_epoch.up.sql new file mode 100644 index 0000000000..ac3822e7e3 --- /dev/null +++ b/migrations/20260226000001_add_rerand_epoch.up.sql @@ -0,0 +1,15 @@ +ALTER TABLE irises ADD COLUMN IF NOT EXISTS rerand_epoch INTEGER NOT NULL DEFAULT 0; + +CREATE OR REPLACE FUNCTION increment_version_id() +RETURNS TRIGGER AS $$ +BEGIN + IF (OLD.left_code IS DISTINCT FROM NEW.left_code OR + OLD.left_mask IS DISTINCT FROM NEW.left_mask OR + OLD.right_code IS DISTINCT FROM NEW.right_code OR + 
OLD.right_mask IS DISTINCT FROM NEW.right_mask) + AND NEW.rerand_epoch IS NOT DISTINCT FROM OLD.rerand_epoch THEN + NEW.version_id = COALESCE(OLD.version_id, 0) + 1; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; diff --git a/migrations/20260226000002_create_rerand_progress.down.sql b/migrations/20260226000002_create_rerand_progress.down.sql new file mode 100644 index 0000000000..791f86c6c2 --- /dev/null +++ b/migrations/20260226000002_create_rerand_progress.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS rerand_progress; diff --git a/migrations/20260226000002_create_rerand_progress.up.sql b/migrations/20260226000002_create_rerand_progress.up.sql new file mode 100644 index 0000000000..214c4da9c9 --- /dev/null +++ b/migrations/20260226000002_create_rerand_progress.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS rerand_progress ( + epoch INTEGER NOT NULL, + chunk_id INTEGER NOT NULL, + staging_written BOOLEAN NOT NULL DEFAULT FALSE, + all_confirmed BOOLEAN NOT NULL DEFAULT FALSE, + live_applied BOOLEAN NOT NULL DEFAULT FALSE, + PRIMARY KEY (epoch, chunk_id) +); From 03557fd4944a700de4182a366a7d089b5bd0c438 Mon Sep 17 00:00:00 2001 From: Philipp Sippl Date: Fri, 27 Feb 2026 11:05:25 +0100 Subject: [PATCH 02/76] simplfy and fixes --- docs/specs/rerandomization.md | 26 ++- iris-mpc-store/src/rerand.rs | 196 ++++++++++-------- iris-mpc-upgrade/src/continuous_rerand.rs | 18 +- .../tests/continuous_rerand_e2e.rs | 16 ++ 4 files changed, 151 insertions(+), 105 deletions(-) diff --git a/docs/specs/rerandomization.md b/docs/specs/rerandomization.md index d541aeef29..2ab365614a 100644 --- a/docs/specs/rerandomization.md +++ b/docs/specs/rerandomization.md @@ -175,15 +175,25 @@ At startup, before `load_iris_db`: 1. **Existing**: modification sync (`sync_modifications`) — all parties catch up on modifications, producing identical `version_id` values 2. 
**New**: rerand sync — parties exchange a compact rerand watermark during the existing startup sync (`SyncState` exchange): - - Each party computes `(epoch, max_confirmed_chunk)` from its local `rerand_progress` table: the active epoch E and the highest `chunk_id` where `all_confirmed = TRUE`. Since chunks are processed in strictly increasing order, all chunks `0..max_confirmed_chunk` are implicitly confirmed. + - Each party computes `(epoch, max_confirmed_chunk)` from its local `rerand_progress` table: the active epoch E and the highest `chunk_id` where `all_confirmed = TRUE`. - Each party sends this single `(epoch, max_confirmed_chunk)` pair as part of `SyncState`. - - Each party computes `safe_up_to = max(max_confirmed_chunk_party_0, max_confirmed_chunk_party_1, max_confirmed_chunk_party_2)` for the agreed epoch E, then locally applies all chunks `0..safe_up_to` where `live_applied = FALSE`. - - This is safe because `all_confirmed = TRUE` at any party means that party observed all three S3 `staged` markers, which means all three parties successfully committed the chunk to their staging schemas. A slower party may not have polled S3 yet, but its staging data is already there. Using `max` ensures all parties converge to the same applied set, preventing cross-party desync where one party loads rerandomized shares and another loads stale shares. - - Edge case: if no chunks have been confirmed yet (fresh epoch or very start), `max_confirmed_chunk` is -1 / None. `safe_up_to` becomes -1 / None and the catch-up step is skipped entirely. -3. **New (DB-only catch-up)**: acquire `pg_advisory_lock(RERAND_APPLY_LOCK)` on a dedicated connection. Then for every chunk K in `0..safe_up_to` where locally `live_applied = FALSE` (in increasing order): run the same apply transaction as Step 1.11. **Keep the lock held** through step 4. + - Each party checks whether any peer is exactly 1 confirmed chunk ahead (within the same epoch, or has moved to the next epoch). 
If so, it applies that single chunk (`my_max_confirmed + 1`) from staging to the live DB. + - **Why at most 1 chunk**: the rerand loop has a strict per-chunk synchronization barrier — a node cannot stage chunk K+1 until all three parties have confirmed chunk K via S3 markers. Therefore it is impossible for any peer to be more than 1 confirmed chunk ahead. The implementation enforces this with a fatal bail if the gap exceeds 1 (indicates DB corruption). + - **Why `max` across peers**: `all_confirmed = TRUE` at any party means that party observed all three S3 `staged` markers, which means all three parties successfully committed the chunk to their staging schemas. A slower party may not have polled S3 yet, but its staging data is already there. + - Edge case: if all parties report the same `max_confirmed_chunk`, there is nothing to catch up and the step is skipped. +3. **New (DB-only catch-up)**: acquire `pg_advisory_lock(RERAND_APPLY_LOCK)` on a dedicated connection. If step 2 identified a chunk to apply, run the same apply transaction as Step 1.11. **Keep the lock held** through step 4. 4. **Existing**: `load_iris_db` — loads from live DB into GPU memory. The advisory lock is still held, so the rerand server cannot apply new chunks while the DB is being read into memory. 5. Release the advisory lock: `SELECT pg_advisory_unlock(RERAND_APPLY_LOCK)` on the dedicated connection, then drop the connection. +### Epoch and chunk desync safety checks + +The startup sync validates two invariants derived from the protocol's synchronization barriers: + +- **Epoch gap ≤ 1**: epochs transition via a 3-party S3 barrier (`complete` markers), so no peer can be more than 1 epoch ahead. A gap > 1 is fatal. +- **Chunk gap ≤ 1** (within the same epoch): the per-chunk S3 barrier (`staged` markers) prevents any peer from confirming more than 1 chunk ahead. A gap > 1 is fatal. + +If either check fails, the main server refuses to start. 
This catches DB corruption, manual interference, or bugs in the rerand server early, before any data is loaded into memory. + ### Advisory lock: startup vs rerand server concurrency Both the rerand server (Step 1.11) and the main server startup (Steps 2.3–2.4) acquire `pg_advisory_lock(RERAND_APPLY_LOCK)` before applying chunks. This ensures: @@ -200,8 +210,10 @@ sqlx::query("SELECT pg_advisory_lock($1)") .bind(RERAND_APPLY_LOCK) .execute(&mut *lock_conn).await?; -apply_catchup_chunks(&pool).await?; // uses pool -load_iris_db(&pool).await?; // uses pool +if let Some((epoch, chunk_id)) = catchup_chunk { + apply_staging_chunk(&pool, epoch, chunk_id).await?; +} +load_iris_db(&pool).await?; sqlx::query("SELECT pg_advisory_unlock($1)") .bind(RERAND_APPLY_LOCK) diff --git a/iris-mpc-store/src/rerand.rs b/iris-mpc-store/src/rerand.rs index 557304c057..d03f231cfe 100644 --- a/iris-mpc-store/src/rerand.rs +++ b/iris-mpc-store/src/rerand.rs @@ -1,6 +1,6 @@ use eyre::Result; use iris_mpc_common::helpers::sync::{RerandSyncState, SyncResult}; -use sqlx::{pool::PoolConnection, PgPool, Postgres}; +use sqlx::{PgPool}; pub const RERAND_APPLY_LOCK: i64 = 0x5245_5241_4E44; @@ -222,27 +222,6 @@ pub async fn get_current_epoch(pool: &PgPool) -> Result> { Ok(row.0) } -/// Returns chunk_ids for a given epoch where live_applied = FALSE and -/// chunk_id <= up_to_chunk, ordered ascending. 
-pub async fn get_unapplied_chunks( - pool: &PgPool, - epoch: i32, - up_to_chunk: i32, -) -> Result> { - let rows: Vec<(i32,)> = sqlx::query_as( - r#" - SELECT chunk_id FROM rerand_progress - WHERE epoch = $1 AND chunk_id <= $2 AND live_applied = FALSE - ORDER BY chunk_id ASC - "#, - ) - .bind(epoch) - .bind(up_to_chunk) - .fetch_all(pool) - .await?; - Ok(rows.into_iter().map(|(id,)| id).collect()) -} - // --------------------------------------------------------------------------- // Shared startup helpers (used by both HNSW and GPU servers) // --------------------------------------------------------------------------- @@ -259,76 +238,85 @@ pub async fn build_rerand_sync_state(pool: &PgPool) -> Result { }) } -/// Compute the safe-to-apply watermark from all parties' rerand sync states. -/// Returns `Some((epoch, max_chunk_id))` if there are chunks to catch up, +/// Compute the single chunk (if any) that needs to be applied during startup catch-up. +/// +/// Because the rerand loop has a strict per-chunk synchronization barrier (all 3 parties +/// must confirm chunk K before any party can stage chunk K+1), peers can be at most +/// 1 confirmed chunk ahead. Therefore, catch-up is always 0 or 1 chunks. +/// +/// Returns `Some((epoch, chunk_id))` if there is exactly one chunk to catch up, /// `None` otherwise. 
-pub fn compute_rerand_safe_up_to(sync_result: &SyncResult) -> Result> { +pub fn compute_rerand_catchup_chunk(sync_result: &SyncResult) -> Result> { let my_state = match sync_result.my_state.rerand_state.as_ref() { Some(s) => s, None => return Ok(None), }; let my_epoch = my_state.epoch; + let my_chunk = my_state.max_confirmed_chunk; - let rerand_states: Vec<&RerandSyncState> = sync_result - .all_states - .iter() - .filter_map(|s| s.rerand_state.as_ref()) - .collect(); + let mut any_peer_ahead = false; - if rerand_states.is_empty() { - return Ok(None); - } - - let mut safe_up_to = -1; - for s in rerand_states { - let diff = s.epoch - my_epoch; - match diff { + for s in sync_result.all_states.iter().filter_map(|s| s.rerand_state.as_ref()) { + let epoch_diff = s.epoch - my_epoch; + match epoch_diff { 0 => { - safe_up_to = safe_up_to.max(s.max_confirmed_chunk); + let chunk_diff = s.max_confirmed_chunk - my_chunk; + if chunk_diff > 1 { + eyre::bail!( + "Fatal chunk desync: peer confirmed chunk {} but local is at {} \ + (max possible difference is 1)", + s.max_confirmed_chunk, + my_chunk + ); + } + if chunk_diff == 1 { + any_peer_ahead = true; + } } 1 => { - safe_up_to = i32::MAX; - } - -1 => { - // They are behind, they contribute -1 + any_peer_ahead = true; } + -1 => {} _ => { - eyre::bail!("Fatal epoch desync: local epoch is {}, but peer is on epoch {}", my_epoch, s.epoch); + eyre::bail!( + "Fatal epoch desync: local epoch is {}, but peer is on epoch {}", + my_epoch, + s.epoch + ); } } } - if safe_up_to < 0 { + if !any_peer_ahead { return Ok(None); } - Ok(Some((my_epoch, safe_up_to))) + let catchup_chunk = my_chunk + 1; + Ok(Some((my_epoch, catchup_chunk))) } /// Perform rerand catch-up and acquire the advisory lock. /// -/// 1. Computes the safe-to-apply watermark from `sync_result`. -/// 2. If there are unapplied chunks, acquires `pg_advisory_lock(RERAND_APPLY_LOCK)` -/// on a dedicated connection, then applies all unapplied chunks. -/// 3. 
Returns the lock-holding connection (if the lock was acquired). -/// -/// The caller **must** keep the returned connection alive until `load_iris_db` -/// finishes, then call [`release_rerand_lock`] to release it. +/// 1. Determines whether this node is 1 chunk behind a peer. +/// 2. If so, acquires `pg_advisory_lock(RERAND_APPLY_LOCK)` on a dedicated +/// connection and applies the single missing chunk. +/// 3. Returns the lock-holding connection (caller keeps it alive through +/// `load_iris_db`, then calls [`release_rerand_lock`]). pub async fn rerand_catchup_and_lock( pool: &PgPool, schema_name: &str, sync_result: &SyncResult, -) -> Result>> { - let safe_up_to = match compute_rerand_safe_up_to(sync_result)? { +) -> Result>> { + let (epoch, chunk_id) = match compute_rerand_catchup_chunk(sync_result)? { Some(v) => v, None => return Ok(None), }; let staging_schema = staging_schema_name(schema_name); tracing::info!( - "Rerand catch-up: applying chunks up to {} for epoch {}", - safe_up_to.1, - safe_up_to.0 + "Rerand catch-up: applying epoch {} chunk {}", + epoch, + chunk_id, ); let mut conn = pool.acquire().await?; @@ -337,24 +325,29 @@ pub async fn rerand_catchup_and_lock( .execute(&mut *conn) .await?; - let unapplied = get_unapplied_chunks(pool, safe_up_to.0, safe_up_to.1).await?; - for chunk_id in unapplied { - let rows = - apply_staging_chunk(pool, &staging_schema, safe_up_to.0, chunk_id).await?; - tracing::info!( - "Rerand catch-up: applied epoch {} chunk {} ({} rows)", - safe_up_to.0, - chunk_id, - rows - ); - } + let rows = match apply_staging_chunk(pool, &staging_schema, epoch, chunk_id).await { + Ok(r) => r, + Err(e) => { + let _ = sqlx::query("SELECT pg_advisory_unlock($1)") + .bind(RERAND_APPLY_LOCK) + .execute(&mut *conn) + .await; + return Err(e); + } + }; + tracing::info!( + "Rerand catch-up: applied epoch {} chunk {} ({} rows)", + epoch, + chunk_id, + rows + ); Ok(Some(conn)) } /// Release the advisory lock acquired by [`rerand_catchup_and_lock`]. 
pub async fn release_rerand_lock( - lock_conn: Option>, + lock_conn: Option>, ) -> Result<()> { if let Some(mut conn) = lock_conn { sqlx::query("SELECT pg_advisory_unlock($1)") @@ -387,55 +380,80 @@ mod tests { } #[test] - fn test_compute_rerand_safe_up_to_same_epoch() { - let p0 = dummy_sync_state(1, 5); + fn test_catchup_peer_one_chunk_ahead() { + let p0 = dummy_sync_state(1, 4); let p1 = dummy_sync_state(1, 4); - let p2 = dummy_sync_state(1, 6); + let p2 = dummy_sync_state(1, 5); + let sync_result = SyncResult { + my_state: p0.clone(), + all_states: vec![p0, p1, p2], + }; + assert_eq!( + compute_rerand_catchup_chunk(&sync_result).unwrap(), + Some((1, 5)) + ); + } + + #[test] + fn test_catchup_all_same() { + let p0 = dummy_sync_state(1, 5); + let p1 = dummy_sync_state(1, 5); + let p2 = dummy_sync_state(1, 5); let sync_result = SyncResult { my_state: p0.clone(), all_states: vec![p0, p1, p2], }; - assert_eq!(compute_rerand_safe_up_to(&sync_result).unwrap(), Some((1, 6))); + assert_eq!(compute_rerand_catchup_chunk(&sync_result).unwrap(), None); } #[test] - fn test_compute_rerand_safe_up_to_peer_ahead() { - // I am on epoch 0, but peer is on epoch 1. - // This implies the peer has confirmed all my chunks for epoch 0. + fn test_catchup_peer_epoch_ahead() { let p0 = dummy_sync_state(0, 5); - let p1 = dummy_sync_state(1, 0); // ahead + let p1 = dummy_sync_state(1, 0); let p2 = dummy_sync_state(0, 5); let sync_result = SyncResult { my_state: p0.clone(), all_states: vec![p0, p1, p2], }; - assert_eq!(compute_rerand_safe_up_to(&sync_result).unwrap(), Some((0, i32::MAX))); + assert_eq!( + compute_rerand_catchup_chunk(&sync_result).unwrap(), + Some((0, 6)) + ); } #[test] - fn test_compute_rerand_safe_up_to_peer_behind() { - // I am on epoch 1, but peer is on epoch 0. - // This implies the peer has not confirmed any chunks for epoch 1. 
+ fn test_catchup_peer_epoch_behind() { let p0 = dummy_sync_state(1, 2); - let p1 = dummy_sync_state(0, 10); // behind + let p1 = dummy_sync_state(0, 10); let p2 = dummy_sync_state(1, 2); let sync_result = SyncResult { my_state: p0.clone(), all_states: vec![p0, p1, p2], }; - assert_eq!(compute_rerand_safe_up_to(&sync_result).unwrap(), Some((1, 2))); + assert_eq!(compute_rerand_catchup_chunk(&sync_result).unwrap(), None); } - + + #[test] + fn test_catchup_fatal_chunk_desync() { + let p0 = dummy_sync_state(1, 2); + let p1 = dummy_sync_state(1, 4); + let p2 = dummy_sync_state(1, 2); + let sync_result = SyncResult { + my_state: p0.clone(), + all_states: vec![p0, p1, p2], + }; + assert!(compute_rerand_catchup_chunk(&sync_result).is_err()); + } + #[test] - fn test_compute_rerand_safe_up_to_fatal_desync() { - // I am on epoch 1, but peer is on epoch 3 (difference > 1). + fn test_catchup_fatal_epoch_desync() { let p0 = dummy_sync_state(1, 2); - let p1 = dummy_sync_state(3, 10); // way ahead + let p1 = dummy_sync_state(3, 10); let p2 = dummy_sync_state(1, 2); let sync_result = SyncResult { my_state: p0.clone(), all_states: vec![p0, p1, p2], }; - assert!(compute_rerand_safe_up_to(&sync_result).is_err()); + assert!(compute_rerand_catchup_chunk(&sync_result).is_err()); } } diff --git a/iris-mpc-upgrade/src/continuous_rerand.rs b/iris-mpc-upgrade/src/continuous_rerand.rs index c948b5cd86..abaf9afa73 100644 --- a/iris-mpc-upgrade/src/continuous_rerand.rs +++ b/iris-mpc-upgrade/src/continuous_rerand.rs @@ -149,15 +149,7 @@ pub async fn run_continuous_rerand( .execute(&mut *lock_conn) .await?; - let rows = - apply_staging_chunk(pool, &staging_schema, active_epoch as i32, chunk_id as i32) - .await?; - tracing::info!( - "Epoch {} chunk {}: applied to live DB ({} rows updated)", - active_epoch, - chunk_id, - rows - ); + let apply_res = apply_staging_chunk(pool, &staging_schema, active_epoch as i32, chunk_id as i32).await; sqlx::query("SELECT pg_advisory_unlock($1)") 
.bind(RERAND_APPLY_LOCK) @@ -165,6 +157,14 @@ pub async fn run_continuous_rerand( .await?; drop(lock_conn); + let rows = apply_res?; + tracing::info!( + "Epoch {} chunk {}: applied to live DB ({} rows updated)", + active_epoch, + chunk_id, + rows + ); + chunk_id += 1; if chunk_delay > Duration::ZERO { diff --git a/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs b/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs index 1b92b96571..4d5f1f9239 100644 --- a/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs +++ b/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs @@ -305,6 +305,22 @@ fn phase8_reject_desync() { let r1 = simulate_server_startup(&env.harness, 1).await; assert!(r1.is_err(), "P1 startup should have failed due to large epoch gap"); + // Now test the new chunk gap logic + // P1 has chunk 0 confirmed, P0 has chunk 2 confirmed (gap > 1) in the same epoch + for p in 0..NUM_PARTIES { + let pool = &env.harness.parties[p].store.pool; + sqlx::query("DELETE FROM rerand_progress").execute(pool).await.unwrap(); + } + + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (3, 0, TRUE, TRUE, TRUE)") + .execute(&env.harness.parties[1].store.pool).await.unwrap(); + + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (3, 2, TRUE, TRUE, FALSE)") + .execute(&env.harness.parties[0].store.pool).await.unwrap(); + + let r1_chunk_desync = simulate_server_startup(&env.harness, 1).await; + assert!(r1_chunk_desync.is_err(), "P1 startup should have failed due to large chunk gap"); + println!("[phase 8] PASSED"); env.teardown().await From 262af1a889bd9da267e0c63e55b35e1f8f1d44e8 Mon Sep 17 00:00:00 2001 From: Philipp Sippl Date: Fri, 27 Feb 2026 11:11:06 +0100 Subject: [PATCH 03/76] clippy and semgrep --- .../bin/iris-mpc-upgrade/rerandomize_db.rs | 17 +- iris-mpc-bins/bin/iris-mpc/server.rs | 13 +- iris-mpc-store/src/rerand.rs | 35 +++- 
iris-mpc-upgrade/src/continuous_rerand.rs | 17 +- iris-mpc-upgrade/src/epoch.rs | 20 +- iris-mpc-upgrade/src/rerandomization.rs | 18 +- iris-mpc-upgrade/src/s3_coordination.rs | 7 +- .../tests/continuous_rerand_e2e.rs | 23 +- iris-mpc-upgrade/tests/test_utils.rs | 196 ++++++++++++++---- iris-mpc/src/server/mod.rs | 4 +- 10 files changed, 243 insertions(+), 107 deletions(-) diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs b/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs index 5cbe5c51e8..da522dfa56 100644 --- a/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs @@ -533,12 +533,15 @@ async fn rerandomize_check_main(config: ReRandomizeCheckConfig) -> Result<()> { } async fn rerandomize_continuous_main(config: RerandomizeContinuousConfig) -> Result<()> { - tracing::info!("Starting continuous rerandomization for party {}", config.party_id); + tracing::info!( + "Starting continuous rerandomization for party {}", + config.party_id + ); let mut background_tasks = TaskMonitor::new(); let healthcheck_port = config.healthcheck_port; - let _health_check_abort = background_tasks - .spawn(async move { spawn_healthcheck_server(healthcheck_port).await }); + let _health_check_abort = + background_tasks.spawn(async move { spawn_healthcheck_server(healthcheck_port).await }); background_tasks.check_tasks(); let sdk_config = aws_config::from_env().load().await; @@ -547,12 +550,8 @@ async fn rerandomize_continuous_main(config: RerandomizeContinuousConfig) -> Res let s3_client = S3Client::from_conf(s3_config.build()); let sm_client = SecretsManagerClient::from_conf(sm_config.build()); - let postgres_client = PostgresClient::new( - &config.db_url, - &config.schema_name, - AccessMode::ReadWrite, - ) - .await?; + let postgres_client = + PostgresClient::new(&config.db_url, &config.schema_name, AccessMode::ReadWrite).await?; let store = Store::new(&postgres_client).await?; 
continuous_rerand::run_continuous_rerand(&config, &s3_client, &sm_client, &store, None).await?; diff --git a/iris-mpc-bins/bin/iris-mpc/server.rs b/iris-mpc-bins/bin/iris-mpc/server.rs index 82d49a76e7..6d68b5e0b4 100644 --- a/iris-mpc-bins/bin/iris-mpc/server.rs +++ b/iris-mpc-bins/bin/iris-mpc/server.rs @@ -982,7 +982,9 @@ async fn server_main(config: Config) -> Result<()> { let is_ready_flag = Arc::new(AtomicBool::new(false)); let is_ready_flag_cloned = Arc::clone(&is_ready_flag); - let rerand_state = rerand_store::build_rerand_sync_state(&store.pool).await.ok(); + let rerand_state = rerand_store::build_rerand_sync_state(&store.pool) + .await + .ok(); let my_state = SyncState { db_len: store_len as u64, modifications: store.last_modifications(max_modification_lookback).await?, @@ -1318,12 +1320,9 @@ async fn server_main(config: Config) -> Result<()> { } } - let rerand_lock_conn = rerand_store::rerand_catchup_and_lock( - &store.pool, - &store.schema_name, - &sync_result, - ) - .await?; + let rerand_lock_conn = + rerand_store::rerand_catchup_and_lock(&store.pool, &store.schema_name, &sync_result) + .await?; if download_shutdown_handler.is_shutting_down() { tracing::warn!("Shutting down has been triggered"); diff --git a/iris-mpc-store/src/rerand.rs b/iris-mpc-store/src/rerand.rs index d03f231cfe..a52372d268 100644 --- a/iris-mpc-store/src/rerand.rs +++ b/iris-mpc-store/src/rerand.rs @@ -1,6 +1,6 @@ use eyre::Result; use iris_mpc_common::helpers::sync::{RerandSyncState, SyncResult}; -use sqlx::{PgPool}; +use sqlx::PgPool; pub const RERAND_APPLY_LOCK: i64 = 0x5245_5241_4E44; @@ -29,7 +29,21 @@ pub fn staging_schema_name(live_schema: &str) -> String { format!("{}_rerand_staging", live_schema) } +fn validate_identifier(name: &str) -> Result<()> { + if name.is_empty() { + eyre::bail!("SQL identifier must not be empty"); + } + if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { + eyre::bail!( + "SQL identifier contains invalid characters (only ASCII 
alphanumeric and _ allowed): {:?}", + name + ); + } + Ok(()) +} + pub async fn ensure_staging_schema(pool: &PgPool, staging_schema: &str) -> Result<()> { + validate_identifier(staging_schema)?; let create_schema = format!(r#"CREATE SCHEMA IF NOT EXISTS "{}""#, staging_schema); sqlx::query(&create_schema).execute(pool).await?; @@ -62,6 +76,7 @@ pub async fn insert_staging_irises( if entries.is_empty() { return Ok(()); } + validate_identifier(staging_schema)?; let table = format!("\"{}\".irises", staging_schema); let header = format!( @@ -99,6 +114,7 @@ pub async fn apply_staging_chunk( epoch: i32, chunk_id: i32, ) -> Result { + validate_identifier(staging_schema)?; let mut tx = pool.begin().await?; let update_sql = format!( @@ -215,10 +231,9 @@ pub async fn get_max_confirmed_chunk(pool: &PgPool, epoch: i32) -> Result