diff --git a/.github/workflows/build-and-push-rerandomization-protocol.yaml b/.github/workflows/build-and-push-rerandomization-protocol.yaml index 3dbb60eec2..74dddf2830 100644 --- a/.github/workflows/build-and-push-rerandomization-protocol.yaml +++ b/.github/workflows/build-and-push-rerandomization-protocol.yaml @@ -5,6 +5,7 @@ on: branches: - main - "fix/add-cacerts-to-rerandom-binary" + - "ps/cont-rerand" paths: - Dockerfile.shares-re-randomization - iris-mpc-upgrade/** diff --git a/.github/workflows/continuous-rerand-e2e-tests.yaml b/.github/workflows/continuous-rerand-e2e-tests.yaml new file mode 100644 index 0000000000..eeca3808a8 --- /dev/null +++ b/.github/workflows/continuous-rerand-e2e-tests.yaml @@ -0,0 +1,70 @@ +name: Continuous Rerand E2E Tests + +on: + pull_request: + +concurrency: + group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" + cancel-in-progress: true + +jobs: + rerand-e2e: + timeout-minutes: 30 + runs-on: + labels: ubuntu-22.04-16core + permissions: + contents: read + + steps: + - name: Checkout + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 + + - name: Get all test, doc and src files that have changed + id: changed-files-yaml + uses: tj-actions/changed-files@24d32ffd492484c1d75e0c0b894501ddb9d30d62 + with: + files_yaml: | + src: + - Dockerfile* + - Cargo.lock + - Cargo.toml + - deny.toml + - rust-toolchain.toml + - iris-*/** + - iris-mpc-upgrade/** + - iris-mpc-store/** + - iris-mpc-common/** + - docs/specs/rerandomization.md + - migrations/** + - scripts/** + - iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh + - iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.rand.yaml + - .github/workflows/continuous-rerand-e2e-tests.yaml + + - name: Cache Rust build + if: steps.changed-files-yaml.outputs.src_any_changed == 'true' + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb + id: cache-rust + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: rust-build-${{ runner.os }}-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + rust-build-${{ runner.os }}- + + - name: Install protobuf compiler + if: steps.changed-files-yaml.outputs.src_any_changed == 'true' + run: | + if command -v protoc > /dev/null; then + echo "protoc already installed: $(command -v protoc)" + else + sudo apt-get update + sudo apt-get install -y protobuf-compiler + fi + + - name: Run rerandomization e2e tests + if: steps.changed-files-yaml.outputs.src_any_changed == 'true' + run: | + bash iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh diff --git a/.github/workflows/temp-branch-build-and-push.yaml b/.github/workflows/temp-branch-build-and-push.yaml index 2714d2f15b..94dc40afa5 100644 --- a/.github/workflows/temp-branch-build-and-push.yaml +++ b/.github/workflows/temp-branch-build-and-push.yaml @@ -5,6 +5,7 @@ on: branches: - "dev" - "pop-3544-gpu-shutdown-guardrail" + - "ps/cont-rerand" concurrency: group: "${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}" diff --git a/Cargo.lock b/Cargo.lock index 6b45cbfd91..1a7ab7117e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2879,6 +2879,7 @@ dependencies = [ "aws-sdk-secretsmanager", "aws-sdk-sns", "aws-sdk-sqs", + "axum 0.7.7", "base64", "bincode", "bytemuck", @@ -2919,6 +2920,7 @@ dependencies = [ "axum 0.7.7", "base64", "blake3", + "bytemuck", "chrono", "clap", "clap_builder", @@ -3132,9 +3134,13 @@ dependencies = [ "iris-mpc-common", "itertools 0.13.0", "rand 0.8.5", + "reqwest", + "serde_json", "sqlx", "tokio", + "tokio-util", "tracing", + "uuid", ] [[package]] @@ -3145,12 +3151,18 @@ dependencies = [ "ark-ec", "ark-ff", "ark-serialize", + "aws-config", + "aws-sdk-s3", + "aws-sdk-secretsmanager", "axum 0.7.7", + "base64", "blake3", "bytemuck", "clap", "criterion", + "dotenvy", "eyre", + "futures", "iris-mpc-common", "iris-mpc-store", "itertools 0.13.0", @@ -3160,9 +3172,12 @@ dependencies = [ "rayon", "serde", "serde-big-array", + "serde_json", "sha2", + "sqlx", "thiserror 1.0.65", "tokio", + "tokio-util", "tonic", "tonic-build", "tracing", @@ -4471,7 +4486,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls 0.23.35", - "socket2 0.6.0", + "socket2 0.5.7", "thiserror 2.0.16", "tokio", "tracing", @@ -4508,7 +4523,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.6.0", + "socket2 0.5.7", "tracing", "windows-sys 0.59.0", ] @@ -6539,7 +6554,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b5187f7eab..a4378323f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ thiserror = "1" tokio = { version = "=1.49", features = ["full", "rt-multi-thread"] } tokio-util = "0.7.15" toml = { version = "0.8.23", features = ["preserve_order"] } -uuid = { version = "1", features = ["v4"] } +uuid = { version = "1", features = ["v4", "v7"] } iris-mpc-cpu = { path = "./iris-mpc-cpu" } ampc-anon-stats = { git = "https://github.com/worldcoin/ampc-common.git", rev = "edc8550e918dbb111c758a5883d971d7be10bc1f" } ampc-actor-utils = { git = "https://github.com/worldcoin/ampc-common.git", rev = "edc8550e918dbb111c758a5883d971d7be10bc1f" } diff --git a/deploy/stage/common-values-iris-mpc-continuous-rerandomization.yaml b/deploy/stage/common-values-iris-mpc-continuous-rerandomization.yaml new file mode 100644 index 0000000000..aba66dea81 --- /dev/null +++ b/deploy/stage/common-values-iris-mpc-continuous-rerandomization.yaml @@ -0,0 +1,67 @@ +image: "ghcr.io/worldcoin/rerandomization-protocol:4161e46ca376d391c52139bd1b1d56ca420c1072" +replicaCount: 0 + +environment: stage + +command: ["/bin/rerandomize-db"] +args: + - "rerandomize-continuous" + +strategy: + type: Recreate + +serviceAccount: + create: true + name: "iris-mpc-continuous-rerandomization" + +datadog: + enabled: true + +ports: + - containerPort: 3000 + name: health + protocol: TCP + +startupProbe: + httpGet: + path: /health + port: health + +livenessProbe: + httpGet: + path: /health + port: health + +readinessProbe: + periodSeconds: 20 + failureThreshold: 4 + httpGet: + path: /health + port: health + +podSecurityContext: + runAsUser: 65534 + runAsGroup: 65534 + +imagePullSecrets: + - name: github-secret + +resources: + limits: + cpu: 3.5 + memory: 12Gi + requests: + cpu: 3.5 + memory: 12Gi + +nodeSelector: + kubernetes.io/arch: amd64 + workload: "continuous_rerandomization" + +tolerations: + - key: "dedicated" + operator: "Equal" + value: "continuousDbRerandomization" + effect: "NoSchedule" + +concurrencyPolicy: Replace diff --git a/deploy/stage/common-values-iris-mpc.yaml b/deploy/stage/common-values-iris-mpc.yaml index caa368801a..3c661cbf99 100644 --- a/deploy/stage/common-values-iris-mpc.yaml +++ b/deploy/stage/common-values-iris-mpc.yaml @@ -1,7 +1,7 @@ image: "ghcr.io/worldcoin/iris-mpc:v0.31.5@sha256:af92dd27cabe80eb3a01fcec21960cb79d12a44e01f17459b83cc923c339f4d4" +replicaCount: 1 environment: stage -replicaCount: 1 strategy: type: Recreate diff --git a/deploy/stage/smpcv2-0-stage/values-iris-mpc-continuous-rerandomization.yaml b/deploy/stage/smpcv2-0-stage/values-iris-mpc-continuous-rerandomization.yaml new file mode 100644 index 0000000000..fb6c453912 --- /dev/null +++ b/deploy/stage/smpcv2-0-stage/values-iris-mpc-continuous-rerandomization.yaml @@ -0,0 +1,26 @@ +env: + - name: SMPC__SERVICE__SERVICE_NAME + value: iris-mpc-continuous-rerandomization-0 + - name: AWS_REGION + value: eu-north-1 + - name: PARTY_ID + value: "0" + - name: DB_URL + valueFrom: + secretKeyRef: + key: DATABASE_AURORA_URL + name: application + - name: SCHEMA_NAME + value: SMPC_stage_0 + - name: ENVIRONMENT + value: stage + - name: RERAND_S3_BUCKET + value: wf-smpcv2-stage-continuous-rerandomization + - name: CHUNK_SIZE + value: "2000" + - name: CHUNK_DELAY_SECS + value: "1" + - name: SAFETY_BUFFER_IDS + value: "0" + - name: S3_POLL_INTERVAL_MS + value: "2000" diff --git a/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml index c489b76623..41504b70f4 100644 --- a/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml @@ -96,7 +96,7 @@ env: value: "wf-smpcv2-stage-sns-requests" - name: SMPC__ENABLE_S3_IMPORTER - value: "true" + value: "false" - name: SMPC__DB_CHUNKS_BUCKET_NAME value: "iris-mpc-db-exporter-store-node-0-stage--eun1-az3--x-s3" diff --git a/deploy/stage/smpcv2-1-stage/values-iris-mpc-continuous-rerandomization.yaml b/deploy/stage/smpcv2-1-stage/values-iris-mpc-continuous-rerandomization.yaml new file mode 100644 index 0000000000..cc10d62daf --- /dev/null +++ b/deploy/stage/smpcv2-1-stage/values-iris-mpc-continuous-rerandomization.yaml @@ -0,0 +1,26 @@ +env: + - name: SMPC__SERVICE__SERVICE_NAME + value: iris-mpc-continuous-rerandomization-1 + - name: AWS_REGION + value: eu-north-1 + - name: PARTY_ID + value: "1" + - name: DB_URL + valueFrom: + secretKeyRef: + key: DATABASE_AURORA_URL + name: application + - name: SCHEMA_NAME + value: SMPC_stage_1 + - name: ENVIRONMENT + value: stage + - name: RERAND_S3_BUCKET + value: wf-smpcv2-stage-continuous-rerandomization + - name: CHUNK_SIZE + value: "2000" + - name: CHUNK_DELAY_SECS + value: "1" + - name: SAFETY_BUFFER_IDS + value: "0" + - name: S3_POLL_INTERVAL_MS + value: "2000" diff --git a/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml index ada9e48614..b50da0249c 100644 --- a/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml @@ -96,7 +96,7 @@ env: value: "wf-smpcv2-stage-sns-requests" - name: SMPC__ENABLE_S3_IMPORTER - value: "true" + value: "false" - name: SMPC__DB_CHUNKS_BUCKET_NAME value: "iris-mpc-db-exporter-store-node-1-stage--eun1-az3--x-s3" diff --git a/deploy/stage/smpcv2-2-stage/values-iris-mpc-continuous-rerandomization.yaml b/deploy/stage/smpcv2-2-stage/values-iris-mpc-continuous-rerandomization.yaml new file mode 100644 index 0000000000..facd60e02c --- /dev/null +++ b/deploy/stage/smpcv2-2-stage/values-iris-mpc-continuous-rerandomization.yaml @@ -0,0 +1,26 @@ +env: + - name: SMPC__SERVICE__SERVICE_NAME + value: iris-mpc-continuous-rerandomization-2 + - name: AWS_REGION + value: eu-north-1 + - name: PARTY_ID + value: "2" + - name: DB_URL + valueFrom: + secretKeyRef: + key: DATABASE_AURORA_URL + name: application + - name: SCHEMA_NAME + value: SMPC_stage_2 + - name: ENVIRONMENT + value: stage + - name: RERAND_S3_BUCKET + value: wf-smpcv2-stage-continuous-rerandomization + - name: CHUNK_SIZE + value: "2000" + - name: CHUNK_DELAY_SECS + value: "1" + - name: SAFETY_BUFFER_IDS + value: "0" + - name: S3_POLL_INTERVAL_MS + value: "2000" diff --git a/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml b/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml index 4811a3baaf..153f9b65d8 100644 --- a/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml +++ b/deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml @@ -96,7 +96,7 @@ env: value: "wf-smpcv2-stage-sns-requests" - name: SMPC__ENABLE_S3_IMPORTER - value: "true" + value: "false" - name: SMPC__DB_CHUNKS_BUCKET_NAME value: "iris-mpc-db-exporter-store-node-2-stage--eun1-az3--x-s3" diff --git a/docs/specs/rerandomization.md b/docs/specs/rerandomization.md new file mode 100644 index 0000000000..135721db99 --- /dev/null +++ b/docs/specs/rerandomization.md @@ -0,0 +1,585 @@ +# Continuous Rerandomization Plan + +## Overview + +Replaces the existing, one-off rerandomization protocol by a continuous, online process that rerandomizes shares while the system is running. No downtime or restart required. + +Key design decision: in-memory shares are less likely to be exfiltrated, so only the DB (at-rest persistence) is rerandomized. The actor is completely unmodified. The rerand server handles everything, writing to a staging schema and then copying to live once all parties confirm. + +## Critical assumption: reliable modification delivery + +The correctness of this protocol depends on **every modification (reauth, deletion, reset) eventually arriving at every party via SQS**. This is a pre-existing system invariant — without it, the MPC shares diverge regardless of rerandomization. The prior system already depended on this assumption; this design makes the dependency explicit and continues to enforce the same safety boundary. Rerandomization does not weaken this guarantee, but it does create a new transient inconsistency window (see [Post-staging modifications](#post-staging-modifications-transient-inconsistency)) that relies on modification delivery to self-correct. + +The protocol enforces this in two active mechanisms, with one residual coverage gap: + +1. **SQS delete after persist** — the SQS message is only deleted *after* the modification row is durably written to the DB. If the process crashes between receiving and persisting, SQS redelivers the message. This eliminates the window where a message could be lost between delete and persist. This behavior is implemented in this branch; it is safer than main’s previous delete-before-process behavior. + +2. **Startup reconciliation recovers missing modifications** — `sync_modifications` compares modification state across all three parties. In this branch, `compare_modifications` was strengthened to emit missing completed rows, and `sync_modifications` now stages them with `upsert_recovered_modification` so they are fully replayed locally. This closes local startup drift paths that were only partially handled before and is linked with lock ordering around rerand apply/state freeze. It still fails closed on lookback overrun. + +3. **Residual gap**: `sync_modifications` is a startup procedure, not a continuous background loop. A running node that permanently lost a modification (and never restarts) will stay inconsistent for the affected row until the next epoch re-randomizes it. Periodic rolling restarts or a future continuous reconciliation loop would close this gap entirely. + +## Architecture + +1. **Rerand Server** (separate process, one per party) — rerandomizes shares, writes to staging, coordinates with peers via S3 markers, copies confirmed chunks to live DB. +2. **Main Server** (existing, minimal changes) — acquires `RERAND_APPLY_LOCK` at startup to freeze applies during `load_iris_db`. Acquires `RERAND_MODIFY_LOCK` during modification writes to serialize with rerand applies. + +The GPU actor, batch processing, and result processor are completely untouched. + +## Seed & Randomness + +One epoch is active at a time. At the start of each epoch: + +1. Each rerand server generates a fresh BLS12-381 keypair +2. Private key is saved to Secrets Manager at `{env}/iris-mpc-db-rerandomization/epoch-{E}/private-key-party-{P}` +3. Public key is uploaded to S3 at `s3://bucket/rerand/epoch-{E}/party-{P}/public-key` +4. Each rerand server downloads the other two parties' public keys from S3 (polling until all present) +5. Each derives the same 32-byte `shared_secret` via the BLS12-381 pairing + +Only the rerand server needs access to the key. The main server never touches it. + +### Keygen is idempotent on restart + +When starting an epoch, the rerand server: + +1. Best-effort cleanup: attempts to delete the previous epoch's key from Secrets Manager (covers crash during epoch transition where deletion was skipped) +2. Checks if an epoch-scoped private key already exists in Secrets Manager at `{env}/iris-mpc-db-rerandomization/epoch-{E}/private-key-party-{P}` +3. If yes: loads it, derives the public key, and uploads the public key to S3 if not already present (covers crash-after-SM-write-before-S3-upload) +4. If no: generates a new keypair, saves the private key to Secrets Manager first, then uploads the public key to S3 + +Secrets Manager is checked first because the private key is written to SM before the public key is uploaded to S3. If we crash between the two writes, on restart we find the key in SM and re-upload to S3. + +### Epoch transition + +One epoch at a time, no overlap: + +1. All three rerand servers finish processing all chunks for epoch E +2. Each server uploads a completion marker: `s3://bucket/rerand/epoch-{E}/party-{P}/complete` +3. Each server polls until all three completion markers exist +4. Keys for epoch E are deleted from Secrets Manager — old secret is destroyed, old shares (overwritten in live DB) are unrecoverable +5. Epoch E+1 begins: create/publish `manifest.json`, keygen, derive new `shared_secret`, start processing + +Old S3 markers under `epoch-{E}/` are left in place (no active cleanup). Use S3 lifecycle policies to reap old epoch prefixes after a retention period. + +On restart mid-epoch: private key is still in SM, public keys and markers are still in S3, `rerand_progress` table tells you the current epoch and which chunk to resume from. Re-derive `shared_secret`, continue. + +## S3 Coordination Bus + +All cross-party coordination uses S3 markers in a shared bucket. Each party writes to its own prefixed paths. Marker layout: + +``` +s3://bucket/rerand/epoch-{E}/party-{P}/public-key # public key for DH +s3://bucket/rerand/epoch-{E}/party-{P}/max-id # party P watermark for manifest (MAX(id)) +s3://bucket/rerand/epoch-{E}/party-{P}/manifest.json # epoch chunking manifest (party 0 writes, others read) +s3://bucket/rerand/epoch-{E}/party-{P}/chunk-{K}/staged # chunk K staging committed +s3://bucket/rerand/epoch-{E}/party-{P}/chunk-{K}/version-hash # 32-byte blake3 hash of version map (fast-path comparison) +s3://bucket/rerand/epoch-{E}/party-{P}/chunk-{K}/version-map # chunk K [(id, version_id)] pairs (downloaded only on hash mismatch) +s3://bucket/rerand/epoch-{E}/party-{P}/complete # epoch E fully done +``` + +Coordination is polling-based: a rerand server checks for peer markers by listing the S3 prefix. A few seconds of polling latency is fine for background work. All polling loops have a 30-minute timeout to surface permanently stuck peers. + +Authentication: the shared bucket uses IAM prefix policies to scope write access per party. Each party can only write to `s3://bucket/rerand/epoch-*/party-{P}/*`. All parties can read/list the full `s3://bucket/rerand/epoch-{E}/` prefix to observe peer markers. The manifest is written by the designated writer (party 0) under its own prefix (`party-0/manifest.json`) and is read-only for others. + +## Schema Changes + +### New column on `irises` + +```sql +ALTER TABLE irises ADD COLUMN rerand_epoch INTEGER NOT NULL DEFAULT 0; +``` + +### Modified `increment_version_id` trigger + +```sql +CREATE OR REPLACE FUNCTION increment_version_id() +RETURNS TRIGGER AS $$ +BEGIN + IF (OLD.left_code IS DISTINCT FROM NEW.left_code OR + OLD.left_mask IS DISTINCT FROM NEW.left_mask OR + OLD.right_code IS DISTINCT FROM NEW.right_code OR + OLD.right_mask IS DISTINCT FROM NEW.right_mask) + AND NEW.rerand_epoch IS NOT DISTINCT FROM OLD.rerand_epoch THEN + NEW.version_id = COALESCE(OLD.version_id, 0) + 1; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; +``` + +When `rerand_epoch` changes (rerandomization), share data changes but `version_id` stays the same. When `rerand_epoch` stays the same (user-facing modification), `version_id` bumps as before. + +### Staging schema + +Each party has a staging schema (`{live_schema}_rerand_staging`), created automatically by a migration that derives the name from `current_schema()`: + +```sql +CREATE TABLE irises ( + epoch INTEGER NOT NULL, + id BIGINT NOT NULL, + chunk_id INTEGER NOT NULL, + left_code BYTEA, + left_mask BYTEA, + right_code BYTEA, + right_mask BYTEA, + original_version_id SMALLINT, + rerand_epoch INTEGER, + PRIMARY KEY (epoch, id) +); +CREATE INDEX idx_staging_irises_epoch_chunk ON irises (epoch, chunk_id); +``` + +### Coordination table + +A `rerand_progress` table in each party's DB: + +```sql +CREATE TABLE rerand_progress ( + epoch INTEGER NOT NULL, + chunk_id INTEGER NOT NULL, + staging_written BOOLEAN NOT NULL DEFAULT FALSE, + all_confirmed BOOLEAN NOT NULL DEFAULT FALSE, + live_applied BOOLEAN NOT NULL DEFAULT FALSE, + PRIMARY KEY (epoch, chunk_id) +); +``` + +Chunk ranges are derived from the manifest (`chunk_size`, `max_id_inclusive`) and `chunk_id`, so they are not stored here. + +Lifecycle: `staging_written` → `all_confirmed` → `live_applied`. + +### Control table (freeze protocol) + +A `rerand_control` table with a single row, used for the coordinated freeze between the main server and the rerand worker: + +```sql +CREATE TABLE rerand_control ( + id INTEGER PRIMARY KEY DEFAULT 1 CHECK (id = 1), + freeze_requested BOOLEAN NOT NULL DEFAULT FALSE, + freeze_generation TEXT, + frozen_generation TEXT +); +INSERT INTO rerand_control (id) VALUES (1) ON CONFLICT DO NOTHING; +``` + +- `freeze_requested`: set to `TRUE` by the main server during startup; the rerand worker checks this between chunks. +- `freeze_generation`: a unique UUID per freeze request (fencing token); prevents stale acknowledgements from prior startups. +- `frozen_generation`: written by the rerand worker to acknowledge the freeze; main server polls until this matches `freeze_generation`. + +## Flow + +### Step 1: Rerand Server (per party, separate process) + +Runs continuously. **Between every chunk boundary** (and between epochs), the worker checks the `rerand_control` table for a freeze request from the main server. If `freeze_requested = TRUE`, it writes `frozen_generation = ` to acknowledge the freeze, then blocks until `freeze_requested = FALSE`. This guarantees the worker is quiesced and not holding any locks during DB load. + +1. Determine the active epoch E (uses local `rerand_progress` as start hint, then scans S3 for the highest epoch with a manifest but without all three `complete` markers). +2. Derive `shared_secret` for epoch E (keygen or resume — see above) +3. Pick next chunk range `[start, end)` for chunk K from the manifest +4. **Stage**: delete any partial staging data for this chunk (crash recovery clean slate), read entries from live schema recording each entry's `version_id`, rerandomize shares using `BLAKE3(shared_secret || iris_id)` XOF, write to staging schema with `epoch = E`, `original_version_id`, `chunk_id = K`, and `rerand_epoch = E + 1` +5. Set `staging_written = TRUE` in local `rerand_progress` for `(epoch = E, chunk_id = K)` +6. Upload version map `[(id, original_version_id)]` and its blake3 hash for the chunk to S3 +7. Upload S3 staged marker: `s3://bucket/rerand/epoch-{E}/party-{P}/chunk-{K}/staged` +8. Poll S3 until all 3 party staged markers exist for chunk K +9. Set `all_confirmed = TRUE` in local `rerand_progress` for `(epoch = E, chunk_id = K)` +10. **Apply**: + a. Download all 3 parties' version-map hashes. If all match (fast path), the staging-divergent set is empty. If any differ, download full maps and compute cross-party disagreements: IDs where any party captured a different `original_version_id`. This is purely S3 reads — no DB lock is held. + b. **Apply transaction**: open a single transaction, acquire locks, delete staging-divergent rows, apply with `version_id` CAS, clean up: + ```sql + BEGIN; + SELECT pg_advisory_xact_lock(RERAND_MODIFY_LOCK); + SELECT pg_advisory_xact_lock(RERAND_APPLY_LOCK); + + -- Delete staging-divergent IDs (cross-party disagreements) + DELETE FROM staging.irises WHERE epoch = E AND chunk_id = K AND id = ANY(staging_divergent); + + -- Apply with version_id CAS — silently skips post-staging modifications + UPDATE irises SET + left_code = staging.left_code, + left_mask = staging.left_mask, + right_code = staging.right_code, + right_mask = staging.right_mask, + rerand_epoch = staging.rerand_epoch + FROM staging_schema.irises AS staging + WHERE irises.id = staging.id + AND staging.epoch = E + AND staging.chunk_id = K + AND irises.version_id = staging.original_version_id; + + DELETE FROM staging_schema.irises WHERE epoch = E AND chunk_id = K; + UPDATE rerand_progress SET live_applied = TRUE WHERE epoch = E AND chunk_id = K; + COMMIT; -- Both locks released here + ``` +11. Proceed to next chunk (or start epoch transition if all chunks done). + +**Key property: no S3 I/O while holding DB locks.** The version-map comparison (step 10a) completes before the transaction opens. Lock hold time is bounded by DB I/O only. + +**Crash recovery for staging**: if the process crashes mid-staging, `staging_written` is still `FALSE`. On restart, the code re-enters the staging block and deletes any partial rows before re-reading. This ensures all staging rows come from one read pass (no mixed-snapshot version_ids). Inserts use `ON CONFLICT (epoch, id) DO NOTHING` as a safety net. + +**Crash recovery for S3 upload**: the S3 staged marker upload is outside the `if !staging_written` block. If the process crashes after `set_staging_written` but before the S3 upload, the marker is re-uploaded on restart (idempotent PUT). + +**Crash recovery for apply**: if the process crashes during the apply transaction, the transaction rolls back (releasing both locks). On restart, `live_applied` is still `FALSE`, so the apply is retried. The `version_id` CAS re-evaluates against current live values — safe and idempotent. + +### Step 2: Main Server Startup + +At startup, before `load_iris_db`: + +1. **Existing**: modification sync (`sync_modifications`) — all parties catch up on modifications, producing identical `version_id` values. This transaction acquires `pg_advisory_xact_lock(RERAND_MODIFY_LOCK)` to serialize with rerand applies. +2. **New — Coordinated freeze with watermark convergence**: + a. Write `freeze_requested = TRUE, freeze_generation = ` to `rerand_control`. This signals the rerand worker to pause at the next chunk boundary. + b. Poll `rerand_control` until `frozen_generation = ` (the worker has acknowledged the freeze and is not holding any locks or applying any chunks). + c. Fetch live applied watermarks from all peers via their `/rerand-watermark` endpoint (always queries the DB — not a stale snapshot). + d. Compare watermarks. Three cases: + - **All equal** → proceed to DB load. + - **Local is behind max(peers)** → release the local freeze so the worker can catch up (apply the pending chunk), sleep briefly, then re-freeze and re-check from step (a). + - **Local is at or ahead of max(peers)** → stay frozen and re-poll peers after a short sleep. The behind parties' startups will release their own freezes, letting their workers catch up. + e. This loop converges by repeatedly releasing/re-freezing until all parties report matching `(epoch, max_applied_chunk)` watermarks. Timeout after 2 minutes if convergence doesn't happen (indicates a stuck worker). Only behind parties release their freeze, while at-max parties stay frozen and wait for peers. +3. **New**: acquire `RERAND_APPLY_LOCK` on a dedicated connection (belt-and-suspenders with the freeze). +4. **Existing**: `load_iris_db` — loads from live DB into GPU/HNSW memory. Both the freeze and the advisory lock are held, so the rerand server cannot apply new chunks. +5. Release `RERAND_APPLY_LOCK`. +6. Clear `freeze_requested = FALSE` in `rerand_control`. The rerand worker resumes. + +**Rollout note**: if the `rerand_control` table doesn't exist yet (pre-migration), the freeze is skipped and startup proceeds without the freeze handshake. + +**Fail-closed invariant**: modification drift that exceeds the configured lookback window causes a hard panic (not a best-effort continue). This prevents startup with incomplete reconciliation. + +### Advisory locks + +Two advisory lock keys are used: + +- **`RERAND_APPLY_LOCK`** — serializes chunk applies with `load_iris_db`. Used as `pg_advisory_xact_lock` inside the apply transaction (auto-released on commit/rollback), and as session-level `pg_advisory_lock` during startup to hold through `load_iris_db`. +- **`RERAND_MODIFY_LOCK`** — serializes modification writes with the rerand apply. The rerand server acquires it (`pg_advisory_xact_lock`) at the start of the apply transaction. The main server acquires it (`pg_advisory_xact_lock`) inside its modification transaction to prevent writes during the apply window. + +**Why `pg_advisory_xact_lock` for applies and modifications**: session-level locks are tied to a connection. If a process is killed while holding a session-level lock on a pooled connection, the connection may be returned to the pool with the lock still held, blocking future acquirers indefinitely. Transaction-level locks avoid this: when the connection is dropped, the transaction rolls back and the lock is released automatically. + +## Conflict Resolution: Rerandomization vs Modifications + +### The problem + +Modifications (reauthentications, deletions) propagate asynchronously to each party via independent SQS queues. During continuous rerandomization, a modification can land on some parties but not others between the time different parties stage a chunk. Without protection, this causes cross-party share divergence: different parties apply the rerand to different underlying shares, breaking the MPC invariant that all 3 parties' shares reconstruct to the same plaintext. + +### Two-layer protection + +The protocol uses two layers to handle modification races: + +#### Layer 1: Cross-party version-map exchange (staging-time disagreements) + +After staging, each party uploads its `[(id, original_version_id)]` map and a blake3 hash to S3. After the staged barrier, each party downloads the 3 hashes (96 bytes). If all match, the maps are identical — no disagreements (fast path, ~100% of the time). If any hash differs, full maps are downloaded and diffed to produce the exact set of IDs where parties captured different `original_version_id` values. + +These IDs are deleted from staging before apply. **All parties compute the same staging-divergent set** (the version maps are deterministic and downloaded after the barrier), so all parties skip the same rows. This prevents the dangerous case where parties apply rerandomization on top of different base data. + +#### Layer 2: `version_id` CAS (post-staging modifications) + +Modifications that land between staging and apply are caught by the `WHERE irises.version_id = staging.original_version_id` clause in the UPDATE. Rows where `version_id` changed are silently skipped. + +**This layer does NOT guarantee cross-party consistency on its own.** Different parties may have different live `version_id` values for the same row (because the modification hasn't propagated to all parties yet), so different parties may apply rerand to different subsets of rows. + +### Post-staging modifications: transient inconsistency + +When a modification lands on party B between staging and apply, but hasn't yet reached parties A and C: + +- A and C's CAS succeeds → row is rerandomized +- B's CAS fails → row keeps the modification's shares +- The 3 parties' shares for that row are temporarily inconsistent + +**This self-corrects when the modification propagates to A and C** (via SQS). The modification is a full-row overwrite that replaces the rerandomized shares with the modification's shares, restoring consistency. The row loses its rerandomization for this epoch; the next epoch picks it up. + +**Window**: bounded by SQS delivery time (typically seconds). During this window, the DB shares are inconsistent. The in-memory shares (used for query processing) are unaffected — the actor loaded from DB at startup before the rerand applied. + +**Restart risk**: if a party restarts during this window, `sync_modifications` at startup replays pending modifications from peers, closing the gap before `load_iris_db` runs. + +**Permanent failure**: if SQS permanently drops a modification, the row stays inconsistent until the next epoch. This is a pre-existing system risk — without reliable SQS delivery, the MPC protocol is already broken regardless of rerandomization. + +### Why `rerand_epoch` and the trigger are kept + +Without the trigger change, the rerand apply would bump `version_id` (because share data changes). This is not a safety issue — the CAS works correctly either way — but it inflates `version_id` by 1 per epoch per row. Since `version_id` is `SMALLINT` (max 32767), this limits the total number of rerandomizations + modifications before overflow. The trigger keeps `version_id` as a pure user-modification counter, preserving the full range for actual reauthentications. + +## Chunking + +Chunk boundaries must be identical across parties for chunk K to be meaningful. Define them via an epoch manifest object in S3: + +- `s3://bucket/rerand/epoch-{E}/party-0/manifest.json`: `{ epoch: E, chunk_size: N, max_id_inclusive: M }` +- Party 0 writes the manifest once at epoch start under its own prefix (IAM-compliant); other parties poll until it exists and treat it as immutable. +- **Watermark sync**: before the manifest is written, each party P uploads its local watermark `max_id_party_P = SELECT MAX(id) FROM irises` to `s3://bucket/rerand/epoch-{E}/party-{P}/max-id`. +- The manifest writer waits until all three `max-id` markers exist, then sets `max_id_inclusive` as: + - `M = min(max_id_party_0, max_id_party_1, max_id_party_2) - safety_buffer_ids` + - `safety_buffer_ids` is configurable (default 0 or one chunk) to avoid rerandomizing the "tip" where replication/ingest lag could differ across parties. +- New inserts with `id > M` are left for a future epoch. +- Chunk K corresponds to `[start, end)` where `start = 1 + K * N` and `end = min(start + N, M + 1)`. + +A configurable delay (`--chunk-delay`, default e.g. 5s) is inserted between chunks to avoid sustained DB load. The rerand server should not stress the live DB with continuous writes — the delay spreads the I/O over time. The delay, chunk size, and number of parallel DB connections should all be configurable via CLI flags or environment variables. + +## Sequence Diagrams + +### Chunk lifecycle (happy path) + +All three rerand workers process each chunk in lockstep via S3 barriers. No DB locks are held during S3 coordination. + +```mermaid +sequenceDiagram + participant P0 as Rerand Worker 0 + participant P1 as Rerand Worker 1 + participant P2 as Rerand Worker 2 + participant S3 as S3 Bucket + participant DB0 as DB (Party 0) + + Note over P0,P2: Chunk K begins + + par Stage (each party reads live DB, writes staging) + P0->>DB0: Read irises [start..end], record version_ids + P0->>DB0: Write to staging schema + end + + P0->>S3: Upload version-map hash + version-map + P1->>S3: Upload version-map hash + version-map + P2->>S3: Upload version-map hash + version-map + + P0->>S3: Upload staged marker + P1->>S3: Upload staged marker + P2->>S3: Upload staged marker + + Note over P0,P2: S3 barrier — all parties poll until 3 staged markers exist + + P0->>S3: Download 3 version-map hashes + alt All hashes match (fast path) + Note over P0: staging_divergent = empty + else + Note over P0: Hash mismatch (slow path) + P0->>S3: Download 3 full version maps + Note over P0: staging_divergent = differing IDs + end + + Note over P0: Apply transaction (no S3 I/O from here) + + P0->>DB0: BEGIN + P0->>DB0: pg_advisory_xact_lock(MODIFY_LOCK) + P0->>DB0: pg_advisory_xact_lock(APPLY_LOCK) + P0->>DB0: DELETE staging_divergent from staging + P0->>DB0: UPDATE irises FROM staging WHERE version_id CAS + P0->>DB0: DELETE staging rows, mark live_applied + P0->>DB0: COMMIT (locks released) +``` + +### Startup with coordinated freeze + +The main server freezes the rerand worker, verifies watermark equality across all three parties, then loads the DB snapshot. + +```mermaid +sequenceDiagram + participant MS as Main Server + participant RW as Rerand Worker + participant DB as Postgres + participant Peer1 as Peer Server 1 + participant Peer2 as Peer Server 2 + + Note over MS: Startup begins (after sync_modifications) + + MS->>DB: SET freeze_requested=TRUE, freeze_generation=G1 + + RW->>DB: (between chunks) Read rerand_control + Note over RW: Sees freeze_requested=TRUE, generation=G1 + RW->>DB: SET frozen_generation=G1 + Note over RW: Blocks in poll loop + + MS->>DB: Poll until frozen_generation=G1 + Note over MS: Worker is quiesced + + MS->>DB: Read local applied watermark + MS->>Peer1: GET /rerand-watermark + Peer1-->>MS: {epoch: 3, max_applied_chunk: 7} + MS->>Peer2: GET /rerand-watermark + Peer2-->>MS: {epoch: 3, max_applied_chunk: 7} + + loop Convergence + alt All watermarks equal + MS->>DB: pg_advisory_lock(APPLY_LOCK) + MS->>DB: load_iris_db (full DB snapshot into memory) + MS->>DB: pg_advisory_unlock(APPLY_LOCK) + MS->>DB: SET freeze_requested=FALSE + MS->>RW: Poll sees freeze_requested=FALSE + RW->>RW: Resume chunk processing (startup continues) + else + alt Local behind max + MS->>MS: Local behind max + MS->>DB: SET freeze_requested=FALSE + MS->>RW: Resume to catch up + MS->>MS: sleep + re-freeze with new request + else + Note over MS: Local at max, peers behind + MS->>MS: sleep briefly + end + end + end +``` + +### Freeze generation handoff (crash recovery) + +If the main server crashes while the worker is frozen, the new server instance writes a new generation. The worker detects the change and re-acknowledges. + +```mermaid +sequenceDiagram + participant MS1 as Main Server (attempt 1) + participant MS2 as Main Server (attempt 2) + participant RW as Rerand Worker + participant DB as Postgres + + MS1->>DB: SET freeze_requested=TRUE, freeze_generation=G1 + RW->>DB: SET frozen_generation=G1 + Note over RW: Blocked in freeze loop + + MS1->>MS1: CRASH (freeze_requested still TRUE) + + Note over MS2: Restart + + MS2->>DB: SET freeze_requested=TRUE, freeze_generation=G2 + + RW->>DB: Poll: reads freeze_generation=G2 (≠ G1) + RW->>DB: SET frozen_generation=G2 + Note over RW: Still blocked, now acked for G2 + + MS2->>DB: Poll until frozen_generation=G2 + Note over MS2: Proceeds with watermark check + load + + MS2->>DB: SET freeze_requested=FALSE + Note over RW: Resumes +``` + +### Modification conflict resolution + +Shows how the version-map exchange (Layer 1) and version_id CAS (Layer 2) handle a modification that arrives asymmetrically. + +```mermaid +sequenceDiagram + participant PA as Party A + participant PB as Party B + participant PC as Party C + participant SQS as SQS + + Note over PA,PC: Chunk K staging begins + + SQS->>PB: Modification M (row 42) + Note over PB: version_id for row 42 bumps to V+1 + + PA->>PA: Stage row 42 with version_id=V + PB->>PB: Stage row 42 with version_id=V+1 + PC->>PC: Stage row 42 with version_id=V + + Note over PA,PC: Version-map exchange (Layer 1) + PA->>PA: version_map_hash differs from PB + Note over PA,PC: row 42 added to staging_divergent + + Note over PA,PC: Apply — row 42 deleted from staging on all parties + Note over PA,PC: Row 42 is NOT rerandomized (safe) + + SQS->>PA: Modification M arrives (later) + SQS->>PC: Modification M arrives (later) + Note over PA,PC: All parties now have M applied — consistent + Note over PA,PC: Row 42 will be rerandomized in next epoch +``` + +### Startup watermark convergence (freeze race) + +During a rolling deploy, all three main servers freeze their local rerand workers. Because the freeze is per-party (not a global barrier), workers may pause at different chunk boundaries. The convergence protocol handles this: only the behind party releases its freeze, leading parties stay frozen. This guarantees convergence without the leading parties advancing further. + +```mermaid +sequenceDiagram + participant MSA as Main Server A + participant MSB as Main Server B + participant MSC as Main Server C + participant WA as Worker A + participant WB as Worker B + participant WC as Worker C + + Note over WA,WC: Workers processing in lockstep via S3 barrier + + Note over MSA,MSC: Deploy — all 3 main servers restart together + + MSA->>WA: freeze_requested=TRUE + MSB->>WB: freeze_requested=TRUE + MSC->>WC: freeze_requested=TRUE + + Note over WA: Finishes chunk 8 apply, THEN sees freeze + WA->>WA: Paused at watermark (E, 8) + + Note over WB: Sees freeze BEFORE chunk 8 apply + WB->>WB: Paused at watermark (E, 7) + + Note over WC: Finishes chunk 8 apply, THEN sees freeze + WC->>WC: Paused at watermark (E, 8) + + MSA->>MSA: Local=(E,8), max=(E,8), Peer B=(E,7) + Note over MSA: Local at max → stay frozen, re-poll + + MSB->>MSB: Local=(E,7), max=(E,8) + Note over MSB: Local behind → release freeze + + MSC->>MSC: Local=(E,8), max=(E,8), Peer B=(E,7) + Note over MSC: Local at max → stay frozen, re-poll + + Note over WB: Worker B resumes (only B unfreezes) + Note over WA,WC: Workers A & C stay frozen — cannot advance + + WB->>WB: Apply chunk 8 (already confirmed via S3) + + MSB->>WB: Re-freeze (new generation) + WB->>WB: Paused at watermark (E, 8) + + MSA->>MSA: Re-poll peers: B=(E,8), C=(E,8) + Note over MSA: All watermarks = (E, 8) ✓ + + MSB->>MSB: Re-check: local=(E,8), A=(E,8), C=(E,8) + Note over MSB: All watermarks = (E, 8) ✓ + + MSC->>MSC: Re-poll peers: A=(E,8), B=(E,8) + Note over MSC: All watermarks = (E, 8) ✓ + + Note over MSA,MSC: All parties proceed with DB load +``` + +### Post-staging modification: transient DB inconsistency (same as before) + +When a modification arrives at one party between staging and apply, but hasn't yet propagated to the others, the version_id CAS causes asymmetric application. The DB shares are temporarily inconsistent but self-correct when the modification propagates. In-memory shares (used for live queries) are unaffected. + +```mermaid +sequenceDiagram + participant PA as Party A (DB) + participant PB as Party B (DB) + participant PC as Party C (DB) + participant SQS as SQS + + Note over PA,PC: All parties staged row 42 with version_id=V + + Note over PA,PC: S3 barrier passed, version maps match (row 42 NOT in staging_divergent) + + SQS->>PB: Modification M arrives at Party B only + Note over PB: Row 42 version_id bumps V→V+1 + + Note over PA,PC: Apply transaction (under RERAND_MODIFY_LOCK) + + PA->>PA: UPDATE WHERE version_id=V → CAS succeeds ✓ + Note over PA: Row 42 rerandomized + + PB->>PB: UPDATE WHERE version_id=V → CAS fails (V≠V+1) ✗ + Note over PB: Row 42 keeps modification shares + + PC->>PC: UPDATE WHERE version_id=V → CAS succeeds ✓ + Note over PC: Row 42 rerandomized + + rect rgb(255, 230, 230) + Note over PA,PC: ⚠ TRANSIENT INCONSISTENCY WINDOW + Note over PA: rerandomized shares + Note over PB: modification shares + Note over PC: rerandomized shares + Note over PA,PC: Shamir reconstruction would be WRONG for row 42 + Note over PA,PC: But in-memory shares (serving queries) are unaffected + end + + SQS->>PA: Modification M propagates to A + Note over PA: Row 42 overwritten with modification shares + + SQS->>PC: Modification M propagates to C + Note over PC: Row 42 overwritten with modification shares + + rect rgb(230, 255, 230) + Note over PA,PC: ✓ CONSISTENT — all parties have modification shares + Note over PA,PC: Row 42 will be rerandomized in next epoch + end +``` diff --git a/iris-mpc-bins/Cargo.toml b/iris-mpc-bins/Cargo.toml index bdf398f471..76d3005ffe 100644 --- a/iris-mpc-bins/Cargo.toml +++ b/iris-mpc-bins/Cargo.toml @@ -78,6 +78,7 @@ iris-mpc-store = { path = "../iris-mpc-store" } iris-mpc-upgrade-hawk = { path = "../iris-mpc-upgrade-hawk" } iris-mpc-utils = { path = "../iris-mpc-utils" } blake3 = "1.8.2" +bytemuck.workspace = true aws-smithy-types = "1.2.9" clap_builder = "4.5.51" @@ -207,6 +208,10 @@ path = "bin/iris-mpc-upgrade/reshare-client.rs" name = "rerandomize-db" path = "bin/iris-mpc-upgrade/rerandomize_db.rs" +[[bin]] +name = "verify-shares" +path = "bin/iris-mpc-upgrade/verify_shares.rs" + # --------------------- # binaries for iris-mpc-upgrade-hawk diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/continuous-rerand-local.sh b/iris-mpc-bins/bin/iris-mpc-upgrade/continuous-rerand-local.sh new file mode 100755 index 0000000000..aeb3a0aa28 --- /dev/null +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/continuous-rerand-local.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +rm -rf "*.log" + +docker-compose -f "$SCRIPT_DIR/docker-compose.rand.yaml" down --remove-orphans -v +docker-compose -f "$SCRIPT_DIR/docker-compose.rand.yaml" up -d + +sleep 10 + +aws_local() { + AWS_ACCESS_KEY_ID=test AWS_SECRET_ACCESS_KEY=test AWS_DEFAULT_REGION=us-east-1 \ + aws --endpoint-url=http://${LOCALSTACK_HOST:-localhost}:4566 "$@" +} + +# Create S3 bucket for rerand coordination markers +BUCKET_NAME=wf-smpcv2-rerand-testing +aws_local s3api create-bucket --bucket $BUCKET_NAME --region us-east-1 + +# Build binaries +cargo build -p iris-mpc-bins --release --bin seed-v2-dbs --bin rerandomize-db + +TARGET_DIR=$(cargo metadata --format-version 1 | jq ".target_directory" -r) + +# Set AWS env vars for localstack +export AWS_ACCESS_KEY_ID=test +export AWS_SECRET_ACCESS_KEY=test +export AWS_DEFAULT_REGION=us-east-1 +export AWS_ENDPOINT_URL="http://127.0.0.1:4566" + +export ENVIRONMENT="testing" + +# Seed DBs with initial data (using first 3 new-db containers as live DBs) +echo "=== Seeding DBs ===" +$TARGET_DIR/release/seed-v2-dbs \ + --db-url-party-0 postgres://postgres:postgres@localhost:6200 \ + --db-url-party-1 postgres://postgres:postgres@localhost:6201 \ + --db-url-party-2 postgres://postgres:postgres@localhost:6202 \ + --schema-name-party-0 SMPC_testing_0 \ + --schema-name-party-1 SMPC_testing_1 \ + --schema-name-party-2 SMPC_testing_2 \ + --fill-to 1000 \ + --batch-size 100 +echo "Seeding complete" + +# Run continuous rerandomization for all 3 parties in parallel +echo "=== Starting continuous rerandomization ===" +COMMON_ARGS="--chunk-size 200 --chunk-delay-secs 1 --s3-poll-interval-ms 2000 --safety-buffer-ids 0" + +$TARGET_DIR/release/rerandomize-db rerandomize-continuous \ + --party-id 0 \ + --db-url postgres://postgres:postgres@localhost:6200 \ + --schema-name SMPC_testing_0 \ + --s3-bucket $BUCKET_NAME \ + --healthcheck-port 3010 \ + $COMMON_ARGS & +PID_0=$! + +$TARGET_DIR/release/rerandomize-db rerandomize-continuous \ + --party-id 1 \ + --db-url postgres://postgres:postgres@localhost:6201 \ + --schema-name SMPC_testing_1 \ + --s3-bucket $BUCKET_NAME \ + --healthcheck-port 3011 \ + $COMMON_ARGS & +PID_1=$! + +$TARGET_DIR/release/rerandomize-db rerandomize-continuous \ + --party-id 2 \ + --db-url postgres://postgres:postgres@localhost:6202 \ + --schema-name SMPC_testing_2 \ + --s3-bucket $BUCKET_NAME \ + --healthcheck-port 3012 \ + $COMMON_ARGS & +PID_2=$! + +echo "Rerand servers started: PIDs $PID_0, $PID_1, $PID_2" +echo "Waiting for one epoch to complete (watching for completion markers in S3)..." + +# Poll until epoch 0 completion markers exist for all parties +MAX_WAIT=300 +ELAPSED=0 +while [ $ELAPSED -lt $MAX_WAIT ]; do + COMPLETE=true + for P in 0 1 2; do + KEY="rerand/epoch-0/party-${P}/complete" + if ! aws_local s3api head-object --bucket $BUCKET_NAME --key "$KEY" >/dev/null 2>&1; then + COMPLETE=false + break + fi + done + if [ "$COMPLETE" = true ]; then + echo "=== Epoch 0 completed! ===" + break + fi + sleep 5 + ELAPSED=$((ELAPSED + 5)) + echo "Waiting... ($ELAPSED s)" +done + +if [ $ELAPSED -ge $MAX_WAIT ]; then + echo "ERROR: Epoch 0 did not complete within ${MAX_WAIT}s" +fi + +# Stop the rerand servers +kill $PID_0 $PID_1 $PID_2 2>/dev/null || true +wait $PID_0 $PID_1 $PID_2 2>/dev/null || true + +echo "=== Continuous rerandomization test finished ===" diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.rand.yaml b/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.rand.yaml index d87dde3dc3..6cfce451fa 100644 --- a/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.rand.yaml +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.rand.yaml @@ -1,48 +1,48 @@ services: new-db-1: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6200:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-2: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6201:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-3: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6202:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-4: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6203:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-5: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6204:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-6: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6205:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" localstack: - image: localstack/localstack + image: public.ecr.aws/localstack/localstack:4.14 ports: - "127.0.0.1:4566:4566" - "127.0.0.1:4571:4571" diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.yaml b/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.yaml index a4418a5340..d5f0020c33 100644 --- a/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.yaml +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/docker-compose.yaml @@ -1,55 +1,55 @@ services: old-db-shares-1: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6100:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" old-db-shares-2: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6101:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" old-db-masks-1: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6111:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-1: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6200:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-2: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6201:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-3: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6202:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" new-db-4: - image: postgres:16 + image: public.ecr.aws/docker/library/postgres:16 ports: - "6203:5432" environment: POSTGRES_USER: "postgres" POSTGRES_PASSWORD: "postgres" localstack: - image: localstack/localstack + image: public.ecr.aws/localstack/localstack:4.14 ports: - "127.0.0.1:4566:4566" - "127.0.0.1:4571:4571" @@ -64,7 +64,7 @@ services: ports: - "7000:7000" nginx: - image: nginx:1.27.1 + image: public.ecr.aws/nginx/nginx:1.27.1 depends_on: - reshare-server-2 ports: diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs b/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs index 58910d38b3..27c650b07a 100644 --- a/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/rerandomize_db.rs @@ -11,19 +11,17 @@ use base64::Engine; use clap::Parser; use eyre::Result; use futures::TryStreamExt; -use iris_mpc_common::galois; -use iris_mpc_common::galois::degree4::basis::Monomial; -use iris_mpc_common::galois::degree4::GaloisRingElement; use iris_mpc_common::galois_engine::degree4::{ GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare, }; -use iris_mpc_common::id::PartyID; use iris_mpc_common::postgres::{AccessMode, PostgresClient}; use iris_mpc_store::{DbStoredIris, Store, StoredIrisRef}; use iris_mpc_upgrade::config::{ KeyGenConfig, ReRandomizeCheckConfig, ReRandomizeConfig, ReRandomizeDbSubCommand, + RerandomizeContinuousConfig, }; -use iris_mpc_upgrade::rerandomization::randomize_iris; +use iris_mpc_upgrade::continuous_rerand; +use iris_mpc_upgrade::rerandomization::{randomize_iris, reconstruct_shares}; use iris_mpc_upgrade::tripartite_dh; use iris_mpc_upgrade::{ config::ReRandomizeDbConfig, @@ -42,6 +40,9 @@ async fn main() -> Result<()> { ReRandomizeDbSubCommand::RerandomizeDb(config) => rerandomize_db_main(config).await, ReRandomizeDbSubCommand::KeyGen(config) => keygen_main(config).await, ReRandomizeDbSubCommand::RerandomizeCheck(config) => rerandomize_check_main(config).await, + ReRandomizeDbSubCommand::RerandomizeContinuous(config) => { + rerandomize_continuous_main(config).await + } } } @@ -531,6 +532,88 @@ async fn rerandomize_check_main(config: ReRandomizeCheckConfig) -> Result<()> { Ok(()) } +async fn rerandomize_continuous_main(config: RerandomizeContinuousConfig) -> Result<()> { + tracing::info!( + "Starting continuous rerandomization for party {}", + config.party_id + ); + + let mut background_tasks = TaskMonitor::new(); + let healthcheck_port = config.healthcheck_port; + let _health_check_abort = + background_tasks.spawn(async move { spawn_healthcheck_server(healthcheck_port).await }); + background_tasks.check_tasks(); + + let cancel = tokio_util::sync::CancellationToken::new(); + let cancel_for_signal = cancel.clone(); + tokio::spawn(async move { + #[cfg(unix)] + { + use tokio::signal::unix::{signal, SignalKind}; + let mut sigterm = + signal(SignalKind::terminate()).expect("failed to install SIGTERM handler"); + tokio::select! { + _ = tokio::signal::ctrl_c() => {} + _ = sigterm.recv() => {} + } + } + #[cfg(not(unix))] + { + tokio::signal::ctrl_c() + .await + .expect("failed to install CTRL+C handler"); + } + tracing::info!("Received shutdown signal, requesting graceful rerand shutdown…"); + cancel_for_signal.cancel(); + }); + + let sdk_config = aws_config::from_env().load().await; + tracing::info!( + region = ?sdk_config.region(), + "AWS SDK config loaded" + ); + let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config); + let sm_config = aws_sdk_secretsmanager::config::Builder::from(&sdk_config); + let s3_client = S3Client::from_conf(s3_config.build()); + let sm_client = SecretsManagerClient::from_conf(sm_config.build()); + + tracing::info!( + s3_bucket = %config.s3_bucket, + party_id = config.party_id, + environment = %config.env, + "Continuous rerand starting with config" + ); + + let postgres_client = + PostgresClient::new(&config.db_url, &config.schema_name, AccessMode::ReadWrite).await?; + let store = Store::new(&postgres_client).await?; + + // Publish a DB-side heartbeat so the main server can detect that a rerand + // worker is actively running. The main server refuses to start with + // `rerand_enabled=false` while this heartbeat is fresh, catching the + // dangerous "worker deployed but server says disabled" misconfig. + let heartbeat_pool = store.pool.clone(); + let heartbeat_cancel = cancel.clone(); + let _heartbeat_abort = background_tasks.spawn(async move { + iris_mpc_store::rerand::run_worker_heartbeat_loop(&heartbeat_pool, heartbeat_cancel).await; + Ok(()) + }); + background_tasks.check_tasks(); + + continuous_rerand::run_continuous_rerand( + &config, + &s3_client, + &sm_client, + &store, + Some(&cancel), + ) + .await?; + + tracing::info!("Continuous rerand shut down gracefully"); + background_tasks.abort_and_wait_for_finish().await; + Ok(()) +} + async fn download_public_key(config: &ReRandomizeConfig, party_id: u8) -> Result { if config.env == "testing" { let bucket = config.public_key_bucket_name.as_ref().ok_or_else(|| { @@ -570,70 +653,6 @@ async fn build_read_only_store(db_url: &str, schema_name: &str) -> Result Store::new(&postgres_client).await } -fn reconstruct_shares(share0: &[u16], share1: &[u16], share2: &[u16]) -> Vec { - let lag_01 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID0, - PartyID::ID1, - ); - let lag_10 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID1, - PartyID::ID0, - ); - let lag_02 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID0, - PartyID::ID2, - ); - let lag_20 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID2, - PartyID::ID0, - ); - let lag_12 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID1, - PartyID::ID2, - ); - let lag_21 = galois::degree4::ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero( - PartyID::ID2, - PartyID::ID1, - ); - - assert!(share0.len() == share1.len() && share1.len() == share2.len()); - - let recon01 = share0 - .chunks_exact(4) - .zip_eq(share1.chunks_exact(4)) - .flat_map(|(a, b)| { - let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); - let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); - let c = a * lag_01 + b * lag_10; - c.coefs - }) - .collect_vec(); - let recon12 = share1 - .chunks_exact(4) - .zip_eq(share2.chunks_exact(4)) - .flat_map(|(a, b)| { - let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); - let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); - let c = a * lag_12 + b * lag_21; - c.coefs - }) - .collect_vec(); - let recon02 = share0 - .chunks_exact(4) - .zip_eq(share2.chunks_exact(4)) - .flat_map(|(a, b)| { - let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); - let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); - let c = a * lag_02 + b * lag_20; - c.coefs - }) - .collect_vec(); - - assert_eq!(recon01, recon12); - assert_eq!(recon01, recon02); - recon01 -} - async fn download_public_key_from_localstack(bucket: &str, party_id: u8) -> Result { let key = format!("{}-{}", PUBLIC_KEY_S3_KEY_NAME_PREFIX, party_id); let request_url = format!("http://localhost:4566/{}/{}", bucket, key); diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh b/iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh new file mode 100755 index 0000000000..e85a7cbd07 --- /dev/null +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/run-rerand-e2e-tests.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# +# Run the continuous rerandomization e2e chaos tests. +# Starts Postgres + localstack via docker-compose, runs the Rust tests, then +# tears everything down. +# +# Usage: +# ./run-rerand-e2e-tests.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" +COMPOSE_FILE="$SCRIPT_DIR/docker-compose.rand.yaml" + +if command -v docker-compose >/dev/null 2>&1; then + COMPOSE=(docker-compose) +elif docker compose version >/dev/null 2>&1; then + COMPOSE=(docker compose) +else + echo "Neither docker-compose nor docker compose is available." + exit 1 +fi + +if ! command -v protoc >/dev/null 2>&1; then + echo "protoc not found. Install protobuf compiler (protobuf-compiler) before running these tests." + echo "In GitHub Actions this workflow installs it automatically." + exit 1 +fi + +cleanup() { + echo "=== Tearing down containers ===" + "${COMPOSE[@]}" -f "$COMPOSE_FILE" down --remove-orphans -v 2>/dev/null || true +} +trap cleanup EXIT + +echo "=== Starting Postgres + localstack ===" +"${COMPOSE[@]}" -f "$COMPOSE_FILE" down --remove-orphans -v 2>/dev/null || true +"${COMPOSE[@]}" -f "$COMPOSE_FILE" up -d + +echo "Waiting for services to be ready..." +for i in $(seq 1 30); do + if docker exec iris-mpc-upgrade-new-db-1-1 pg_isready -U postgres -q 2>/dev/null; then + break + fi + sleep 1 +done +docker exec iris-mpc-upgrade-new-db-1-1 pg_isready -U postgres || { echo "Postgres not ready"; exit 1; } + +for i in $(seq 1 30); do + STATUS=$(docker inspect --format='{{.State.Health.Status}}' iris-mpc-upgrade-localstack-1 2>/dev/null || echo "unknown") + if [ "$STATUS" = "healthy" ]; then + break + fi + sleep 1 +done +echo "Infrastructure ready." + +echo "=== Running e2e chaos tests ===" +cd "$REPO_ROOT" +AWS_ACCESS_KEY_ID=test \ +AWS_SECRET_ACCESS_KEY=test \ +AWS_DEFAULT_REGION=us-east-1 \ +AWS_ENDPOINT_URL=http://127.0.0.1:4566 \ +ENVIRONMENT=testing \ + cargo test -p iris-mpc-upgrade --test continuous_rerand_e2e --features db_dependent -- --include-ignored --nocapture + +echo "=== All tests passed ===" diff --git a/iris-mpc-bins/bin/iris-mpc-upgrade/verify_shares.rs b/iris-mpc-bins/bin/iris-mpc-upgrade/verify_shares.rs new file mode 100644 index 0000000000..384b6d8186 --- /dev/null +++ b/iris-mpc-bins/bin/iris-mpc-upgrade/verify_shares.rs @@ -0,0 +1,236 @@ +// WARNING: This tool reconstructs plaintext iris codes from secret shares. +// It is intended strictly for local development and staging environments with synthetic test data. + +use std::io::Write; +use std::path::PathBuf; + +use clap::Parser; +use eyre::{ensure, Result}; +use iris_mpc_common::postgres::{AccessMode, PostgresClient}; +use iris_mpc_store::Store; +use iris_mpc_upgrade::rerandomization::{try_reconstruct_shares, ReconstructionMismatch}; + +type ShareComponent<'a> = (&'a str, &'a [u16], &'a [u16], &'a [u16]); + +#[derive(Parser)] +#[command( + name = "verify-shares", + about = "Connect to all 3 party databases, reconstruct every iris entry from \ + all party-pair combinations, and produce per-row + overall hashes.\n\n\ + WARNING: This tool reconstructs plaintext iris codes from secret shares. \ + It is intended strictly for local development and staging environments \ + with synthetic test data." +)] +struct Args { + #[arg(long, env = "PARTY0_DB_URL")] + party0_db_url: String, + + #[arg(long, env = "PARTY1_DB_URL")] + party1_db_url: String, + + #[arg(long, env = "PARTY2_DB_URL")] + party2_db_url: String, + + /// Schema name shared by all parties. Overridden per-party by + /// --party{0,1,2}-schema if provided. + #[arg(long, env = "SCHEMA")] + schema: String, + + #[arg(long, env = "PARTY0_SCHEMA")] + party0_schema: Option, + + #[arg(long, env = "PARTY1_SCHEMA")] + party1_schema: Option, + + #[arg(long, env = "PARTY2_SCHEMA")] + party2_schema: Option, + + /// Output file for the per-row hash list (one hex hash per line). + #[arg(long, default_value = "iris_hashes.txt")] + output: PathBuf, + + /// Output file for detailed verification failures. + #[arg(long, default_value = "verification-output.txt")] + failures_output: PathBuf, +} + +async fn connect(url: &str, schema: &str) -> Result { + let client = PostgresClient::new(url, schema, AccessMode::ReadOnly).await?; + Store::new(&client).await +} + +fn log_mismatch( + out: &mut impl Write, + id: i64, + component: &str, + mismatch: &ReconstructionMismatch, + v0: i16, + v1: i16, + v2: i16, +) -> std::io::Result<()> { + // recon(0,1) vs recon(1,2) vs recon(0,2). + // If two pair-reconstructions agree, the party NOT in both agreeing pairs + // is the one with the bad share. + let divergent_party = match (mismatch.pairs_01_vs_12, mismatch.pairs_01_vs_02) { + // recon(0,1) != recon(1,2), but recon(0,1) == recon(0,2) + // agreeing pairs share parties 0; party 2 is the outlier + (true, false) => "party2 (recon(0,1)==recon(0,2), recon(1,2) differs)", + // recon(0,1) == recon(1,2), but recon(0,1) != recon(0,2) + // agreeing pairs share party 1; party 0 is the outlier + (false, true) => "party0 (recon(0,1)==recon(1,2), recon(0,2) differs)", + // all three disagree — cannot isolate a single bad party + (true, true) => "unknown (all three pair reconstructions differ)", + (false, false) => unreachable!(), + }; + + let msg = format!( + "id={id} component={component} version_ids=[{v0},{v1},{v2}] suspect={divergent_party}" + ); + tracing::error!("{}", msg); + writeln!(out, "{}", msg) +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + let args = Args::parse(); + + tracing::warn!("*** This tool reconstructs plaintext iris codes from secret shares. ***"); + tracing::warn!("*** Only use with local/staging environments and synthetic test data. ***"); + + tracing::info!("Connecting to party databases…"); + let s0 = args.party0_schema.as_deref().unwrap_or(&args.schema); + let s1 = args.party1_schema.as_deref().unwrap_or(&args.schema); + let s2 = args.party2_schema.as_deref().unwrap_or(&args.schema); + + let stores = tokio::try_join!( + connect(&args.party0_db_url, s0), + connect(&args.party1_db_url, s1), + connect(&args.party2_db_url, s2), + )?; + let stores = [stores.0, stores.1, stores.2]; + + let counts: [usize; 3] = [ + stores[0].count_irises().await?, + stores[1].count_irises().await?, + stores[2].count_irises().await?, + ]; + tracing::info!( + "Row counts: party0={}, party1={}, party2={}", + counts[0], + counts[1], + counts[2] + ); + ensure!( + counts[0] == counts[1] && counts[1] == counts[2], + "Row counts differ across parties: {:?}", + counts + ); + let total = counts[0]; + if total == 0 { + tracing::warn!("Databases are empty, nothing to verify"); + return Ok(()); + } + + let mut overall_hasher = blake3::Hasher::new(); + let mut out = std::io::BufWriter::new(std::fs::File::create(&args.output)?); + let mut failures_out = std::io::BufWriter::new(std::fs::File::create(&args.failures_output)?); + + let mut verified = 0u64; + let mut failed = 0u64; + let log_interval = (total / 100).max(1); + + for id in 1..=(total as i64) { + let rows = tokio::try_join!( + stores[0].get_iris_data_by_id(id), + stores[1].get_iris_data_by_id(id), + stores[2].get_iris_data_by_id(id), + )?; + let (r0, r1, r2) = rows; + + let components: [ShareComponent; 4] = [ + ("left_code", r0.left_code(), r1.left_code(), r2.left_code()), + ("left_mask", r0.left_mask(), r1.left_mask(), r2.left_mask()), + ( + "right_code", + r0.right_code(), + r1.right_code(), + r2.right_code(), + ), + ( + "right_mask", + r0.right_mask(), + r1.right_mask(), + r2.right_mask(), + ), + ]; + + let mut row_ok = true; + let mut reconstructed: Vec> = Vec::with_capacity(4); + + for (name, s0, s1, s2) in &components { + match try_reconstruct_shares(s0, s1, s2) { + Ok(plain) => reconstructed.push(plain), + Err(mismatch) => { + row_ok = false; + log_mismatch( + &mut failures_out, + id, + name, + &mismatch, + r0.version_id(), + r1.version_id(), + r2.version_id(), + )?; + } + } + } + + if row_ok { + let mut row_hasher = blake3::Hasher::new(); + for plain in &reconstructed { + row_hasher.update(bytemuck::cast_slice::(plain)); + } + let row_hash = row_hasher.finalize(); + writeln!(out, "{}:{}", id, row_hash.to_hex())?; + overall_hasher.update(row_hash.as_bytes()); + } else { + failed += 1; + } + + verified += 1; + if verified as usize % log_interval == 0 { + tracing::info!( + "Progress {}/{} ({} failures so far)", + verified, + total, + failed + ); + } + } + + out.flush()?; + failures_out.flush()?; + let overall_hash = overall_hasher.finalize(); + + if failed > 0 { + tracing::error!( + "Verification completed with {} inconsistent rows out of {} (details in {})", + failed, + total, + args.failures_output.display() + ); + eyre::bail!( + "{} rows have inconsistent shares across parties. See {}", + failed, + args.failures_output.display() + ); + } + + tracing::info!("Verified all {} entries", total); + tracing::info!("Overall hash: {}", overall_hash.to_hex()); + tracing::info!("Per-row hashes written to {}", args.output.display()); + + println!("{}", overall_hash.to_hex()); + Ok(()) +} diff --git a/iris-mpc-bins/bin/iris-mpc/server.rs b/iris-mpc-bins/bin/iris-mpc/server.rs index 264cd7de20..cacf7e997c 100644 --- a/iris-mpc-bins/bin/iris-mpc/server.rs +++ b/iris-mpc-bins/bin/iris-mpc/server.rs @@ -9,7 +9,8 @@ use ampc_server_utils::batch_sync::{ use ampc_server_utils::{ delete_messages_until_sequence_num, get_next_sns_seq_num, get_others_sync_state, init_heartbeat_task, set_node_ready, shutdown_handler::ShutdownHandler, - start_coordination_server, wait_for_others_ready, wait_for_others_unready, TaskMonitor, + start_coordination_server_with_extra_routes, wait_for_others_ready, wait_for_others_unready, + TaskMonitor, }; use aws_sdk_s3::Client as S3Client; use aws_sdk_secretsmanager::Client as SecretsManagerClient; @@ -63,6 +64,7 @@ use iris_mpc_common::{ }; use iris_mpc_gpu::server::ServerActor; use iris_mpc_store::loader::load_iris_db; +use iris_mpc_store::rerand as rerand_store; use iris_mpc_store::{ fetch_and_parse_chunks, last_snapshot_timestamp, DbStoredIris, ObjectStore, S3Store, S3StoredIris, Store, StoredIrisRef, @@ -258,11 +260,13 @@ async fn server_main(config: Config) -> Result<()> { // -------------------------------------------------------------------------- tracing::info!("⚓️ ANCHOR: Starting Healthcheck, Readiness and Sync server"); + let rerand_state = rerand_store::build_rerand_sync_state(&store.pool).await?; let my_state = SyncState { db_len: store_len as u64, modifications: store.last_modifications(max_modification_lookback).await?, next_sns_sequence_num: next_sns_seq_number_future.await?, common_config: CommonConfig::from(config.clone()), + rerand_state, }; tracing::info!("Sync state: {:?}", my_state); @@ -270,12 +274,44 @@ async fn server_main(config: Config) -> Result<()> { let server_coord_config = config.server_coordination.clone().unwrap_or_else(|| { panic!("Server coordination config must be provided for healthcheck server"); }); - let (is_ready_flag, verified_peers, uuid) = start_coordination_server( + let rerand_watermark_route = { + let pool = store.pool.clone(); + axum::Router::new().route( + "/rerand-watermark", + axum::routing::get(move || { + let pool = pool.clone(); + async move { + let wm = rerand_store::get_applied_watermark_from_pool(&pool).await; + match wm { + Ok(Some((epoch, chunk))) => ( + axum::http::StatusCode::OK, + serde_json::to_string(&serde_json::json!({ + "epoch": epoch, + "max_applied_chunk": chunk, + })) + .unwrap(), + ), + Ok(None) => (axum::http::StatusCode::OK, "null".to_string()), + Err(e) => { + tracing::warn!("rerand-watermark query failed: {:?}", e); + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("DB error: {}", e), + ) + } + } + } + }), + ) + }; + + let (is_ready_flag, verified_peers, uuid) = start_coordination_server_with_extra_routes( &server_coord_config, &mut background_tasks, &shutdown_handler, &my_state, Some(batch_sync_shared_state.clone()), + Some(rerand_watermark_route), ) .await; @@ -334,7 +370,7 @@ async fn server_main(config: Config) -> Result<()> { None, &aws_clients, &shares_encryption_key_pair, - sync_result, + sync_result.clone(), ) .await?; } @@ -353,99 +389,164 @@ async fn server_main(config: Config) -> Result<()> { } } - if download_shutdown_handler.is_shutting_down() { - tracing::warn!("Shutting down has been triggered"); - return Ok(()); + // --- Coordinated rerand freeze with watermark convergence --- + if config.rerand_enabled { + eyre::ensure!( + server_coord_config.node_hostnames.len() == server_coord_config.healthcheck_ports.len(), + "node_hostnames ({}) and healthcheck_ports ({}) must have the same length", + server_coord_config.node_hostnames.len(), + server_coord_config.healthcheck_ports.len(), + ); + let peer_addrs: Vec<(&str, usize)> = server_coord_config + .node_hostnames + .iter() + .zip(server_coord_config.healthcheck_ports.iter()) + .enumerate() + .filter(|(i, _)| *i != config.party_id) + .map(|(_, (h, p))| -> eyre::Result<_> { Ok((h.as_str(), p.parse::()?)) }) + .collect::>>()?; + rerand_store::freeze_and_verify_watermarks(&store.pool, &peer_addrs).await?; + } else if rerand_store::is_worker_alive(&store.pool).await? { + // Worker heartbeat is fresh but this server is configured with rerand + // off. Starting up now would skip freeze/watermark coordination and + // risk loading a cross-party-inconsistent DB snapshot. Fail closed. + eyre::bail!( + "rerand_enabled=false in config but the rerand worker is alive \ + (heartbeat within the last {:?}). Either set SMPC__RERAND_ENABLED=true, \ + or stop the rerand worker on all parties before restarting this server.", + rerand_store::WORKER_HEARTBEAT_STALE_AFTER, + ); + } else { + tracing::info!( + "rerand_enabled=false and no fresh worker heartbeat — skipping rerand coordination" + ); } + // Worker is now frozen with verified equal watermarks. + // Everything from here until freeze release must be wrapped so that + // errors always release the freeze. + let freeze_pool = store.pool.clone(); - // refetch store_len in case we rolled back - let store_len = store.count_irises().await?; - tracing::info!("Database store length after sync: {}", store_len); - - let runtime_handle = tokio::runtime::Handle::current(); - let anon_stats_writer = if let Some(url) = config.get_anon_stats_db_url() { - let schema = config.get_anon_stats_db_schema(); - let anon_client = - AnonStatsPgClient::new(&url, &schema, AnonStatsAccessMode::ReadWrite).await?; - let anon_store = AnonStatsStore::new(&anon_client).await?; - Some((anon_store, runtime_handle.clone())) - } else { - tracing::warn!("No database URL configured for anon stats; skipping DB persistence"); - None - }; - let anon_stats_writer_for_actor = anon_stats_writer.clone(); - - let (tx, rx) = oneshot::channel(); - let config_clone = config.clone(); - background_tasks.spawn_blocking(move || { - let config = config_clone; - // -------------------------------------------------------------------------- - // ANCHOR: Load the database - // -------------------------------------------------------------------------- - tracing::info!("⚓️ ANCHOR: Starting server actor"); - match ServerActor::new( - config.party_id, - chacha_seeds, - 8, - config.max_db_size, - config.max_batch_size, - config.match_distances_buffer_size, - config.match_distances_buffer_size_extra_percent, - config.return_partial_results, - config.disable_persistence, - config.enable_debug_timing, - config.full_scan_side, - config.full_scan_side_switching_enabled, - anon_stats_writer_for_actor, - ) { - Ok((mut actor, handle)) => { - tracing::info!("⚓️ ANCHOR: Load the database"); - let res = if config.fake_db_size > 0 { - // TODO: does this even still work, since we do not page-lock the memory here? - actor.fake_db(config.fake_db_size); - Ok(()) - } else { - tracing::info!( - "Initialize iris db: Loading from DB (parallelism: {})", - parallelism - ); - let download_shutdown_handler = Arc::clone(&download_shutdown_handler); - - tokio::runtime::Handle::current().block_on(async { - load_iris_db( - &mut actor, - &store, - store_len, - parallelism, - None, - &config, - download_shutdown_handler, - ) - .await - }) - }; + let frozen_result = async { + let rerand_lock_conn = rerand_store::acquire_apply_lock(&store.pool).await?; + + if download_shutdown_handler.is_shutting_down() { + rerand_store::release_apply_lock(rerand_lock_conn).await?; + return Ok::<_, eyre::Report>(None); + } - match res { - Ok(_) => { - tx.send(Ok((handle, store))).unwrap(); + let startup_result = async { + let store_len = store.count_irises().await?; + tracing::info!("Database store length after sync: {}", store_len); + + let runtime_handle = tokio::runtime::Handle::current(); + let anon_stats_writer = if let Some(url) = config.get_anon_stats_db_url() { + let schema = config.get_anon_stats_db_schema(); + let anon_client = + AnonStatsPgClient::new(&url, &schema, AnonStatsAccessMode::ReadWrite).await?; + let anon_store = AnonStatsStore::new(&anon_client).await?; + Some((anon_store, runtime_handle.clone())) + } else { + tracing::warn!( + "No database URL configured for anon stats; skipping DB persistence" + ); + None + }; + let anon_stats_writer_for_actor = anon_stats_writer.clone(); + + let (tx, rx) = oneshot::channel(); + let config_clone = config.clone(); + background_tasks.spawn_blocking(move || { + let config = config_clone; + tracing::info!("⚓️ ANCHOR: Starting server actor"); + match ServerActor::new( + config.party_id, + chacha_seeds, + 8, + config.max_db_size, + config.max_batch_size, + config.match_distances_buffer_size, + config.match_distances_buffer_size_extra_percent, + config.return_partial_results, + config.disable_persistence, + config.enable_debug_timing, + config.full_scan_side, + config.full_scan_side_switching_enabled, + anon_stats_writer_for_actor, + ) { + Ok((mut actor, handle)) => { + tracing::info!("⚓️ ANCHOR: Load the database"); + let res = if config.fake_db_size > 0 { + actor.fake_db(config.fake_db_size); + Ok(()) + } else { + tracing::info!( + "Initialize iris db: Loading from DB (parallelism: {})", + parallelism + ); + let download_shutdown_handler = Arc::clone(&download_shutdown_handler); + + tokio::runtime::Handle::current().block_on(async { + load_iris_db( + &mut actor, + &store, + store_len, + parallelism, + None, + &config, + download_shutdown_handler, + ) + .await + }) + }; + + match res { + Ok(_) => { + tx.send(Ok((handle, store))).unwrap(); + } + Err(e) => { + tx.send(Err(e)).unwrap(); + return Ok(()); + } + } + + actor.run(); // forever } Err(e) => { tx.send(Err(e)).unwrap(); return Ok(()); } - } + }; + Ok(()) + }); - actor.run(); // forever - } - Err(e) => { - tx.send(Err(e)).unwrap(); - return Ok(()); - } - }; - Ok(()) - }); + let startup_result = rx.await; + let (handle, store) = startup_result??; + Ok::<_, eyre::Report>((handle, store)) + } + .await; - let (mut handle, store) = rx.await??; + rerand_store::release_apply_lock(rerand_lock_conn).await?; + Ok(Some(startup_result)) + } + .await; + + // Always attempt freeze release, but never let its failure undo a + // successful startup. `release_rerand_freeze` already retries internally, + // and a subsequent startup will re-issue a freeze with a new generation + // that the worker re-acknowledges (see `check_and_handle_freeze` generation + // change handling). + if let Err(e) = rerand_store::release_rerand_freeze(&freeze_pool).await { + tracing::error!( + "Failed to release rerand freeze after startup: {:?}. \ + Worker will re-acknowledge on next startup freeze.", + e + ); + } + + let (mut handle, store) = match frozen_result? { + None => return Ok(()), + Some(r) => r?, + }; background_tasks.check_tasks(); @@ -744,6 +845,10 @@ async fn server_main(config: Config) -> Result<()> { let mut tx = store_bg.tx().await?; + if !config_bg.disable_persistence { + iris_mpc_store::rerand::acquire_modify_lock(&mut tx).await?; + } + store_bg .update_modifications(&mut tx, &modifications.values().collect::>()) .await?; diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index 63c4816cdb..48b17685c9 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -255,6 +255,9 @@ pub struct Config { #[serde(default)] pub enable_modifications_replay: bool, + #[serde(default)] + pub rerand_enabled: bool, + #[serde(default = "default_pprof_s3_bucket")] pub pprof_s3_bucket: String, @@ -664,6 +667,7 @@ pub struct CommonConfig { max_modifications_lookback: usize, enable_modifications_sync: bool, enable_modifications_replay: bool, + rerand_enabled: bool, sqs_sync_long_poll_seconds: i32, schema_name: String, hnsw_schema_name_suffix: String, @@ -750,6 +754,7 @@ impl From for CommonConfig { max_modifications_lookback, enable_modifications_sync, enable_modifications_replay, + rerand_enabled, sqs_sync_long_poll_seconds, schema_name, hnsw_schema_name_suffix, @@ -819,6 +824,7 @@ impl From for CommonConfig { max_modifications_lookback, enable_modifications_sync, enable_modifications_replay, + rerand_enabled, sqs_sync_long_poll_seconds, schema_name, hnsw_schema_name_suffix, diff --git a/iris-mpc-common/src/helpers/smpc_request.rs b/iris-mpc-common/src/helpers/smpc_request.rs index 21fc4e1cd6..794b05262b 100644 --- a/iris-mpc-common/src/helpers/smpc_request.rs +++ b/iris-mpc-common/src/helpers/smpc_request.rs @@ -188,6 +188,9 @@ pub enum ReceiveRequestError { BatchPollingTimeout(i32), #[error("Failed to parse shares: {0}")] FailedToProcessIrisShares(Report), + + #[error("Failed to mark request as deleted: {0}")] + FailedToMarkRequestAsDeleted(Report), } impl From> for ReceiveRequestError { diff --git a/iris-mpc-common/src/helpers/sync.rs b/iris-mpc-common/src/helpers/sync.rs index 0f8ab9923e..d9f1ce5179 100644 --- a/iris-mpc-common/src/helpers/sync.rs +++ b/iris-mpc-common/src/helpers/sync.rs @@ -10,6 +10,16 @@ pub struct SyncState { pub modifications: Vec, pub next_sns_sequence_num: Option, pub common_config: CommonConfig, + #[serde(default)] + pub rerand_state: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RerandSyncState { + pub epoch: i32, + /// Highest chunk_id where `live_applied = TRUE`, or `None` if no chunks + /// have been applied yet. + pub max_applied_chunk: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -239,11 +249,13 @@ impl SyncResult { let max_id = completed_max_mod_ids.iter().flatten().copied().max(); if let (Some(min_id), Some(max_id)) = (min_id, max_id) { let mod_id_diff = max_id.saturating_sub(min_id) as usize; - if mod_id_diff > self.my_state.common_config.get_max_modifications_lookback() { + let lookback = self.my_state.common_config.get_max_modifications_lookback(); + if mod_id_diff > lookback { panic!( - "Modification ID difference across nodes is too large: {:?}. Min: {:?}, Max: {:?}. \ - Can not safely handle this case, consider bumping lookback. Crashing!", - completed_max_mod_ids, min_id, max_id + "Modification ID difference across nodes ({}) exceeds lookback ({}): {:?}. \ + Min: {:?}, Max: {:?}. Cannot safely reconcile. \ + Bump max_modifications_lookback or investigate drift.", + mod_id_diff, lookback, completed_max_mod_ids, min_id, max_id ); } } @@ -281,16 +293,17 @@ impl SyncResult { .expect("At least one completed modification"); match local_copy { None => { - // If an item is completed for a party, it should at least exist in the - // local state because it should have been added during receive_batch. - // This can only happen when other party misses an in_progress mod. - // Local party will fetch until modification id X while the other party will - // fetch until mod id X-1. In this case, local party won't find X-1. - // We log and skip updating to avoid rolling back to an older share in local. - tracing::info!( - "Skip missing completed modification: {:?}", - first_completed + // The local node never received this modification (e.g., SQS + // message was lost). Roll it forward from a peer's completed + // copy so the local DB converges with the other parties. + let mut roll_forward = first_completed.clone(); + roll_forward.status = ModificationStatus::Completed.to_string(); + roll_forward.persisted = any_persisted; + tracing::warn!( + "Recovering missing completed modification from peer: {:?}", + roll_forward ); + to_update.push(roll_forward); } Some(local_m) => { if local_m.status != ModificationStatus::Completed.to_string() @@ -405,6 +418,7 @@ mod tests { modifications, next_sns_sequence_num: None, common_config: CommonConfig::from(config), + rerand_state: None, } } @@ -765,10 +779,13 @@ mod tests { // Compare modifications across nodes. let (to_update, to_delete) = sync_result.compare_modifications(); - assert_eq!(to_update.len(), 0, "Expected no modification to update"); + assert_eq!( + to_update.len(), + 1, + "Expected mod1 to be recovered from peer" + ); + assert_eq!(to_update[0].id, mod1_other.id); assert_eq!(to_delete.len(), 1, "Expected one modification to delete"); - - // Expectation: Local party should delete mod3. assert_eq!(to_delete[0], mod3_local); } @@ -848,18 +865,21 @@ mod tests { modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(200), common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 30, modifications: vec![], next_sns_sequence_num: Some(150), common_config: CommonConfig::default(), + rerand_state: None, }, ]; @@ -872,6 +892,7 @@ mod tests { modifications: vec![], next_sns_sequence_num: None, common_config: CommonConfig::default(), + rerand_state: None, }; let all_states = vec![ state_with_none_sequence_num.clone(), @@ -894,18 +915,21 @@ mod tests { modifications: vec![], next_sns_sequence_num: None, // NodeX - advanced but empty queue common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(123), // Other nodes still have messages common_config: CommonConfig::default(), + rerand_state: None, }, SyncState { db_len: 30, modifications: vec![], next_sns_sequence_num: Some(123), common_config: CommonConfig::default(), + rerand_state: None, }, ]; @@ -1027,18 +1051,21 @@ mod tests { modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::from(config1), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::from(config2), + rerand_state: None, }, SyncState { db_len: 20, modifications: vec![], next_sns_sequence_num: Some(100), common_config: CommonConfig::from(config3), + rerand_state: None, }, ]; @@ -1047,13 +1074,8 @@ mod tests { } #[test] - #[should_panic(expected = "Modification ID difference across nodes is too large")] + #[should_panic(expected = "Modification ID difference across nodes")] fn test_compare_modifications_large_id_difference_panic() { - // Create a scenario where nodes have completed modifications with IDs - // that differ by more than the max_modifications_lookback limit. - // Test lookback is (100 + 64) * 2 = 328, so we'll create a difference of 350. - - // Node 1: has completed modification with ID 1 let mod1_node1 = create_modification( 1, Some(100), @@ -1065,7 +1087,6 @@ mod tests { ); let my_state = create_sync_state_with_lookback(vec![mod1_node1], 10); - // Node 2: has completed modification with ID 15 (difference = 14 > 10) let mod15_node2 = create_modification( 15, Some(1500), @@ -1077,7 +1098,6 @@ mod tests { ); let other_state1 = create_sync_state_with_lookback(vec![mod15_node2], 10); - // Node 3: has completed modification with ID 20 (even larger) let mod20_node3 = create_modification( 20, Some(2000), @@ -1096,7 +1116,6 @@ mod tests { all_states, }; - // This should panic because max_id (20) - min_id (1) = 19 > 10 (test lookback) sync_result.compare_modifications(); } } diff --git a/iris-mpc-store/Cargo.toml b/iris-mpc-store/Cargo.toml index 97bb12933f..a4d3f1dab9 100644 --- a/iris-mpc-store/Cargo.toml +++ b/iris-mpc-store/Cargo.toml @@ -20,6 +20,10 @@ eyre.workspace = true itertools.workspace = true tracing.workspace = true tokio.workspace = true +tokio-util.workspace = true +uuid.workspace = true +reqwest.workspace = true +serde_json.workspace = true rand.workspace = true ampc-server-utils.workspace = true diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs index 3b0715447b..97962c7a4f 100644 --- a/iris-mpc-store/src/lib.rs +++ b/iris-mpc-store/src/lib.rs @@ -1,4 +1,5 @@ pub mod loader; +pub mod rerand; mod s3_importer; use bytemuck::cast_slice; @@ -669,6 +670,39 @@ WHERE id = $1; Ok(()) } + /// Insert a modification recovered from a peer. Uses the peer's `id` to + /// keep modification IDs consistent across parties. If the `id` already + /// exists, updates the row to match the peer's state. + pub async fn upsert_recovered_modification( + &self, + tx: &mut Transaction<'_, Postgres>, + m: &Modification, + ) -> Result<()> { + sqlx::query( + r#" + INSERT INTO modifications (id, serial_id, request_type, s3_url, status, persisted, result_message_body, graph_mutation) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (id) DO UPDATE SET + status = EXCLUDED.status, + persisted = EXCLUDED.persisted, + result_message_body = EXCLUDED.result_message_body, + serial_id = EXCLUDED.serial_id, + graph_mutation = EXCLUDED.graph_mutation + "#, + ) + .bind(m.id) + .bind(m.serial_id) + .bind(&m.request_type) + .bind(&m.s3_url) + .bind(&m.status) + .bind(m.persisted) + .bind(&m.result_message_body) + .bind(&m.graph_mutation) + .execute(tx.deref_mut()) + .await?; + Ok(()) + } + /// Delete modifications based on their id. pub async fn delete_modifications( &self, diff --git a/iris-mpc-store/src/rerand.rs b/iris-mpc-store/src/rerand.rs new file mode 100644 index 0000000000..00dcc77f83 --- /dev/null +++ b/iris-mpc-store/src/rerand.rs @@ -0,0 +1,909 @@ +use std::time::Duration; + +use eyre::Result; +use iris_mpc_common::helpers::sync::RerandSyncState; +use sqlx::PgPool; + +pub const RERAND_APPLY_LOCK: i64 = 0x5245_5241_4E44; +pub const RERAND_MODIFY_LOCK: i64 = 0x5245_4D4F_4446; + +/// Acquire `RERAND_MODIFY_LOCK` as a transaction-level advisory lock. +/// Auto-released on commit/rollback. +pub async fn acquire_modify_lock(tx: &mut sqlx::Transaction<'_, sqlx::Postgres>) -> Result<()> { + sqlx::query("SELECT pg_advisory_xact_lock($1)") + .bind(RERAND_MODIFY_LOCK) + .execute(&mut **tx) + .await?; + Ok(()) +} + +pub struct StagingIrisEntry { + pub epoch: i32, + pub id: i64, + pub chunk_id: i32, + pub left_code: Vec, + pub left_mask: Vec, + pub right_code: Vec, + pub right_mask: Vec, + pub original_version_id: i16, + pub rerand_epoch: i32, +} + +#[derive(sqlx::FromRow, Debug, Clone)] +pub struct RerandProgress { + pub epoch: i32, + pub chunk_id: i32, + pub staging_written: bool, + pub all_confirmed: bool, + pub live_applied: bool, +} + +pub fn staging_schema_name(live_schema: &str) -> String { + format!("{}_rerand_staging", live_schema) +} + +fn validate_identifier(name: &str) -> Result<()> { + if name.is_empty() { + eyre::bail!("SQL identifier must not be empty"); + } + if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { + eyre::bail!( + "SQL identifier contains invalid characters (only ASCII alphanumeric and _ allowed): {:?}", + name + ); + } + Ok(()) +} + +/// Delete any partial staging data for a chunk before (re-)staging. +/// Ensures all rows come from one read pass, preventing mixed-snapshot +/// version_ids after a crash-and-retry. +pub async fn delete_staging_chunk( + pool: &PgPool, + staging_schema: &str, + epoch: i32, + chunk_id: i32, +) -> Result { + validate_identifier(staging_schema)?; + let sql = format!( + r#"DELETE FROM "{}".irises WHERE epoch = $1 AND chunk_id = $2"#, + staging_schema, + ); + let result = sqlx::query(&sql) + .bind(epoch) + .bind(chunk_id) + .execute(pool) + .await?; + Ok(result.rows_affected()) +} + +/// Return the (id, original_version_id) pairs from staging for a chunk. +pub async fn get_staging_version_map( + pool: &PgPool, + staging_schema: &str, + epoch: i32, + chunk_id: i32, +) -> Result> { + validate_identifier(staging_schema)?; + let sql = format!( + r#"SELECT id, original_version_id FROM "{}".irises WHERE epoch = $1 AND chunk_id = $2 ORDER BY id"#, + staging_schema, + ); + let rows: Vec<(i64, i16)> = sqlx::query_as(&sql) + .bind(epoch) + .bind(chunk_id) + .fetch_all(pool) + .await?; + Ok(rows) +} + +async fn delete_staging_ids_tx( + tx: &mut sqlx::Transaction<'_, sqlx::Postgres>, + staging_schema: &str, + epoch: i32, + chunk_id: i32, + ids: &[i64], +) -> Result { + if ids.is_empty() { + return Ok(0); + } + validate_identifier(staging_schema)?; + let sql = format!( + r#"DELETE FROM "{}".irises WHERE epoch = $1 AND chunk_id = $2 AND id = ANY($3)"#, + staging_schema, + ); + let result = sqlx::query(&sql) + .bind(epoch) + .bind(chunk_id) + .bind(ids) + .execute(&mut **tx) + .await?; + Ok(result.rows_affected()) +} + +pub async fn insert_staging_irises( + pool: &PgPool, + staging_schema: &str, + entries: &[StagingIrisEntry], +) -> Result<()> { + if entries.is_empty() { + return Ok(()); + } + validate_identifier(staging_schema)?; + + let table = format!("\"{}\".irises", staging_schema); + let header = format!( + "INSERT INTO {} (epoch, id, chunk_id, left_code, left_mask, right_code, right_mask, original_version_id, rerand_epoch)", + table + ); + + let mut qb = sqlx::QueryBuilder::new(header); + qb.push_values(entries, |mut b, e| { + b.push_bind(e.epoch); + b.push_bind(e.id); + b.push_bind(e.chunk_id); + b.push_bind(&e.left_code); + b.push_bind(&e.left_mask); + b.push_bind(&e.right_code); + b.push_bind(&e.right_mask); + b.push_bind(e.original_version_id); + b.push_bind(e.rerand_epoch); + }); + + qb.push(" ON CONFLICT (epoch, id) DO NOTHING"); + qb.build().execute(pool).await?; + Ok(()) +} + +/// Apply a confirmed staging chunk to the live `irises` table. +/// +/// Opens a single transaction that: +/// 1. Acquires `RERAND_MODIFY_LOCK` (blocks modification writes) +/// 2. Acquires `RERAND_APPLY_LOCK` (blocks startup DB load) +/// 3. Deletes `staging_divergent` IDs from staging (cross-party disagreements) +/// 4. Applies remaining staging rows via `version_id` CAS +/// 5. Cleans up staging and marks progress +/// +/// The `version_id` CAS (`WHERE irises.version_id = staging.original_version_id`) +/// silently skips any rows that were modified between staging and apply. This is +/// safe: the modification will propagate to all parties and overwrite whatever +/// was there, restoring consistency. See the spec's "Conflict Resolution" section. +pub async fn apply_confirmed_chunk( + pool: &PgPool, + staging_schema: &str, + epoch: i32, + chunk_id: i32, + staging_divergent: &[i64], +) -> Result { + validate_identifier(staging_schema)?; + let mut tx = pool.begin().await?; + + acquire_modify_lock(&mut tx).await?; + sqlx::query("SELECT pg_advisory_xact_lock($1)") + .bind(RERAND_APPLY_LOCK) + .execute(&mut *tx) + .await?; + + if !staging_divergent.is_empty() { + let deleted = + delete_staging_ids_tx(&mut tx, staging_schema, epoch, chunk_id, staging_divergent) + .await?; + tracing::info!( + "Rerand apply: removed {} staging-divergent rows (epoch={}, chunk={})", + deleted, + epoch, + chunk_id, + ); + } + + let update_sql = format!( + r#" + UPDATE irises SET + left_code = staging.left_code, + left_mask = staging.left_mask, + right_code = staging.right_code, + right_mask = staging.right_mask, + rerand_epoch = staging.rerand_epoch + FROM "{}".irises AS staging + WHERE irises.id = staging.id + AND staging.epoch = $1 + AND staging.chunk_id = $2 + AND irises.version_id = staging.original_version_id + "#, + staging_schema, + ); + let result = sqlx::query(&update_sql) + .bind(epoch) + .bind(chunk_id) + .execute(&mut *tx) + .await?; + let rows_updated = result.rows_affected(); + + let delete_sql = format!( + r#"DELETE FROM "{}".irises WHERE epoch = $1 AND chunk_id = $2"#, + staging_schema, + ); + sqlx::query(&delete_sql) + .bind(epoch) + .bind(chunk_id) + .execute(&mut *tx) + .await?; + + sqlx::query( + "UPDATE rerand_progress SET live_applied = TRUE WHERE epoch = $1 AND chunk_id = $2", + ) + .bind(epoch) + .bind(chunk_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(rows_updated) +} + +pub async fn upsert_rerand_progress(pool: &PgPool, epoch: i32, chunk_id: i32) -> Result<()> { + sqlx::query( + r#" + INSERT INTO rerand_progress (epoch, chunk_id) + VALUES ($1, $2) + ON CONFLICT (epoch, chunk_id) DO NOTHING + "#, + ) + .bind(epoch) + .bind(chunk_id) + .execute(pool) + .await?; + Ok(()) +} + +pub async fn set_staging_written(pool: &PgPool, epoch: i32, chunk_id: i32) -> Result<()> { + sqlx::query( + "UPDATE rerand_progress SET staging_written = TRUE WHERE epoch = $1 AND chunk_id = $2", + ) + .bind(epoch) + .bind(chunk_id) + .execute(pool) + .await?; + Ok(()) +} + +pub async fn set_all_confirmed(pool: &PgPool, epoch: i32, chunk_id: i32) -> Result<()> { + sqlx::query( + "UPDATE rerand_progress SET all_confirmed = TRUE WHERE epoch = $1 AND chunk_id = $2", + ) + .bind(epoch) + .bind(chunk_id) + .execute(pool) + .await?; + Ok(()) +} + +pub async fn get_rerand_progress( + pool: &PgPool, + epoch: i32, + chunk_id: i32, +) -> Result> { + let row = sqlx::query_as::<_, RerandProgress>( + "SELECT epoch, chunk_id, staging_written, all_confirmed, live_applied FROM rerand_progress WHERE epoch = $1 AND chunk_id = $2", + ) + .bind(epoch) + .bind(chunk_id) + .fetch_optional(pool) + .await?; + Ok(row) +} + +/// Returns the highest `chunk_id` where `live_applied = TRUE` for a given +/// epoch, or `None` if no chunks have been applied in that epoch yet. +pub async fn get_max_applied_chunk_for_epoch(pool: &PgPool, epoch: i32) -> Result> { + let row: (Option,) = sqlx::query_as( + "SELECT MAX(chunk_id) FROM rerand_progress WHERE epoch = $1 AND live_applied = TRUE", + ) + .bind(epoch) + .fetch_one(pool) + .await?; + Ok(row.0) +} + +/// Delete all staging rows for epochs older than `current_epoch`. +pub async fn delete_staging_for_old_epochs( + pool: &PgPool, + staging_schema: &str, + current_epoch: i32, +) -> Result { + validate_identifier(staging_schema)?; + let sql = format!( + r#"DELETE FROM "{}".irises WHERE epoch < $1"#, + staging_schema + ); + let result = sqlx::query(&sql).bind(current_epoch).execute(pool).await?; + Ok(result.rows_affected()) +} + +/// Delete rerand progress rows for epochs strictly older than the one +/// immediately preceding `current_epoch`. +/// +/// We intentionally keep the rows from the immediately prior epoch so that +/// `get_applied_watermark_from_pool` does not transiently return `None` +/// between the end of epoch `E` and the first applied chunk of epoch `E+1`. +/// A transient `None` here used to cause the cross-party startup watermark +/// check (`freeze_and_verify_inner`) to spuriously classify this party as +/// behind any peer still reporting `Some((E, last_chunk))`, producing a +/// release / re-freeze oscillation during rolling deploys at epoch bumps. +pub async fn delete_rerand_progress_for_old_epochs( + pool: &PgPool, + current_epoch: i32, +) -> Result { + let result = sqlx::query("DELETE FROM rerand_progress WHERE epoch < $1 - 1") + .bind(current_epoch) + .execute(pool) + .await?; + Ok(result.rows_affected()) +} + +/// Returns the highest epoch that has any rerand_progress rows. +pub async fn get_current_epoch(pool: &PgPool) -> Result> { + let row: (Option,) = sqlx::query_as("SELECT MAX(epoch) FROM rerand_progress") + .fetch_one(pool) + .await?; + Ok(row.0) +} + +// --------------------------------------------------------------------------- +// Shared startup helpers (used by both HNSW and GPU servers) +// --------------------------------------------------------------------------- + +/// Build the rerand sync state from the local `rerand_progress` table. +/// +/// Returns `Ok(None)` when the `rerand_progress` table does not exist yet +/// (rolling deploy before migration). Returns `Err` for real DB failures. +pub async fn build_rerand_sync_state(pool: &PgPool) -> Result> { + let epoch = match get_current_epoch(pool).await { + Ok(e) => e.unwrap_or(0), + Err(e) => { + if is_undefined_table(&e) { + return Ok(None); + } + return Err(e); + } + }; + let max_applied = get_max_applied_chunk_for_epoch(pool, epoch).await?; + Ok(Some(RerandSyncState { + epoch, + max_applied_chunk: max_applied, + })) +} + +fn is_undefined_table(err: &eyre::Report) -> bool { + if let Some(db_err) = err.root_cause().downcast_ref::() { + return is_undefined_table_sqlx(db_err); + } + false +} + +fn is_undefined_table_sqlx(err: &sqlx::Error) -> bool { + if let sqlx::Error::Database(pg) = err { + return pg.code().as_deref() == Some("42P01"); + } + false +} + +// --------------------------------------------------------------------------- +// Worker heartbeat: lets the main server detect that a rerand worker is +// actively running, independently of any config flag. Paired with the +// server's `rerand_enabled` config flag to catch misconfigs at startup. +// --------------------------------------------------------------------------- + +/// How often the worker writes its heartbeat while alive. +pub const WORKER_HEARTBEAT_WRITE_INTERVAL: Duration = Duration::from_secs(10); + +/// How long after the last heartbeat we still consider the worker alive. +/// Must be comfortably larger than `WORKER_HEARTBEAT_WRITE_INTERVAL` to avoid +/// false "dead" verdicts during transient DB lag or worker restarts. +pub const WORKER_HEARTBEAT_STALE_AFTER: Duration = Duration::from_secs(60); + +fn is_pre_heartbeat_schema(err: &sqlx::Error) -> bool { + // 42P01 = undefined_table (rerand_control doesn't exist yet). + // 42703 = undefined_column (heartbeat column not yet migrated). + if !rerand_control_exists(err) { + return true; + } + if let sqlx::Error::Database(pg) = err { + return pg.code().as_deref() == Some("42703"); + } + false +} + +/// Write `NOW()` into `rerand_control.worker_last_heartbeat`. +/// +/// Silently succeeds when the table or column does not exist yet +/// (pre-migration); the worker can still run before the server deploys the +/// heartbeat migration. +pub async fn write_worker_heartbeat(pool: &PgPool) -> Result<()> { + match sqlx::query("UPDATE rerand_control SET worker_last_heartbeat = NOW() WHERE id = 1") + .execute(pool) + .await + { + Ok(_) => Ok(()), + Err(e) if is_pre_heartbeat_schema(&e) => Ok(()), + Err(e) => Err(e.into()), + } +} + +/// Drive the worker heartbeat in a loop until cancelled. +/// +/// Writes an immediate heartbeat on entry, then writes one every +/// `WORKER_HEARTBEAT_WRITE_INTERVAL` until `cancel` is triggered. Errors are +/// logged but do not terminate the loop — a transient DB outage should not +/// take the worker down. +pub async fn run_worker_heartbeat_loop(pool: &PgPool, cancel: tokio_util::sync::CancellationToken) { + loop { + if let Err(e) = write_worker_heartbeat(pool).await { + tracing::warn!("Failed to write rerand worker heartbeat: {:?}", e); + } + tokio::select! { + _ = cancel.cancelled() => return, + _ = tokio::time::sleep(WORKER_HEARTBEAT_WRITE_INTERVAL) => {} + } + } +} + +/// Returns `true` iff `rerand_control.worker_last_heartbeat` is set and +/// younger than `WORKER_HEARTBEAT_STALE_AFTER`. +/// +/// Returns `false` when the table or column does not exist (pre-migration), +/// when no heartbeat has ever been written, or when the last heartbeat is +/// older than the staleness threshold. +pub async fn is_worker_alive(pool: &PgPool) -> Result { + let stale_secs = WORKER_HEARTBEAT_STALE_AFTER.as_secs() as i64; + let row = sqlx::query_as::<_, (Option,)>( + "SELECT worker_last_heartbeat > NOW() - make_interval(secs => $1) \ + FROM rerand_control WHERE id = 1", + ) + .bind(stale_secs) + .fetch_optional(pool) + .await; + + match row { + Ok(Some((Some(true),))) => Ok(true), + Ok(_) => Ok(false), + Err(e) if is_pre_heartbeat_schema(&e) => Ok(false), + Err(e) => Err(e.into()), + } +} + +// --------------------------------------------------------------------------- +// Freeze protocol: coordinated pause of the rerand worker during startup +// --------------------------------------------------------------------------- + +const FREEZE_TIMEOUT: Duration = Duration::from_secs(120); +const FREEZE_POLL: Duration = Duration::from_secs(2); + +fn rerand_control_exists(err: &sqlx::Error) -> bool { + !is_undefined_table_sqlx(err) +} + +/// Strict-less-than comparator for applied watermarks that documents the +/// intended semantics at the call site and avoids accidentally relying on +/// `Option`'s derived ordering (which treats `None < Some(_)`). +/// +/// Semantics: +/// - `None` means "never applied anything" — this is the legitimate day-0 +/// state before any epoch has completed its first chunk on any party. +/// - When combined with +/// [`delete_rerand_progress_for_old_epochs`]'s retain-prior-epoch policy, +/// `None` should only arise on genuinely fresh deployments, not as a +/// transient epoch-boundary artifact. +fn watermark_lt(a: Option<(i32, i32)>, b: Option<(i32, i32)>) -> bool { + match (a, b) { + (None, Some(_)) => true, + (Some(x), Some(y)) => x < y, + _ => false, + } +} + +/// Request the rerand worker to freeze. Writes a unique `freeze_generation` +/// to `rerand_control`. Returns the generation token. +pub async fn request_rerand_freeze(pool: &PgPool) -> Result> { + let generation = uuid::Uuid::now_v7().to_string(); + match sqlx::query( + "UPDATE rerand_control SET freeze_requested = TRUE, freeze_generation = $1, frozen_generation = NULL WHERE id = 1", + ) + .bind(&generation) + .execute(pool) + .await + { + Ok(_) => Ok(Some(generation)), + Err(e) if !rerand_control_exists(&e) => { + tracing::info!("rerand_control table missing; skipping freeze (pre-migration)"); + Ok(None) + } + Err(e) => Err(e.into()), + } +} + +/// Wait until the rerand worker acknowledges the freeze by writing +/// `frozen_generation = generation`. Fails closed on timeout. +pub async fn wait_for_rerand_frozen(pool: &PgPool, generation: &str) -> Result<()> { + let deadline = tokio::time::Instant::now() + FREEZE_TIMEOUT; + loop { + let row: Option<(Option,)> = + sqlx::query_as("SELECT frozen_generation FROM rerand_control WHERE id = 1") + .fetch_optional(pool) + .await?; + + if let Some((Some(frozen_gen),)) = row { + if frozen_gen == generation { + tracing::info!("Rerand worker confirmed freeze (generation={})", generation); + return Ok(()); + } + } + + if tokio::time::Instant::now() >= deadline { + eyre::bail!( + "Rerand worker did not acknowledge freeze after {:?} (generation={}). \ + Ensure the rerand worker is running and healthy.", + FREEZE_TIMEOUT, + generation, + ); + } + tokio::time::sleep(FREEZE_POLL).await; + } +} + +/// Called by the rerand worker between chunks. If a freeze is requested, +/// acknowledge it and block until the freeze is lifted. Returns `true` if +/// the worker should continue, `false` if cancelled while frozen. +pub async fn check_and_handle_freeze( + pool: &PgPool, + cancel: Option<&tokio_util::sync::CancellationToken>, +) -> Result { + let row: Option<(bool, Option)> = match sqlx::query_as( + "SELECT freeze_requested, freeze_generation FROM rerand_control WHERE id = 1", + ) + .fetch_optional(pool) + .await + { + Ok(r) => r, + Err(e) if !rerand_control_exists(&e) => return Ok(true), + Err(e) => return Err(e.into()), + }; + + let Some((true, Some(generation))) = row else { + return Ok(true); + }; + + tracing::info!( + "Rerand freeze requested (generation={}), pausing...", + generation + ); + + // Acknowledge the freeze. + sqlx::query("UPDATE rerand_control SET frozen_generation = $1 WHERE id = 1") + .bind(&generation) + .execute(pool) + .await?; + + let mut current_gen = generation.to_string(); + + // Block until freeze is lifted. Re-read freeze_generation each iteration + // so that if the requesting server crashes and restarts with a new + // generation, we re-acknowledge instead of deadlocking. + loop { + if cancel.is_some_and(|c| c.is_cancelled()) { + return Ok(false); + } + + let row: Option<(bool, Option)> = sqlx::query_as( + "SELECT freeze_requested, freeze_generation FROM rerand_control WHERE id = 1", + ) + .fetch_optional(pool) + .await?; + + match row { + Some((false, _)) | None => { + tracing::info!("Rerand freeze lifted, resuming"); + return Ok(true); + } + Some((true, Some(ref new_gen))) if *new_gen != current_gen => { + tracing::info!( + "Rerand freeze generation changed ({} -> {}), re-acknowledging", + current_gen, + new_gen + ); + sqlx::query("UPDATE rerand_control SET frozen_generation = $1 WHERE id = 1") + .bind(new_gen) + .execute(pool) + .await?; + current_gen = new_gen.clone(); + } + _ => {} + } + + tokio::time::sleep(FREEZE_POLL).await; + } +} + +/// Lift the freeze and clear the generation. Called after `load_iris_db`. +/// Retries on transient DB errors to avoid leaving the worker permanently frozen. +/// Silently succeeds if the `rerand_control` table doesn't exist (pre-migration). +pub async fn release_rerand_freeze(pool: &PgPool) -> Result<()> { + for attempt in 0..5 { + match sqlx::query( + "UPDATE rerand_control SET freeze_requested = FALSE, freeze_generation = NULL, frozen_generation = NULL WHERE id = 1", + ) + .execute(pool) + .await + { + Ok(_) => { + tracing::info!("Rerand freeze released"); + return Ok(()); + } + Err(e) if !rerand_control_exists(&e) => { + return Ok(()); + } + Err(e) => { + tracing::warn!( + "Failed to release rerand freeze (attempt {}): {:?}", + attempt + 1, + e + ); + tokio::time::sleep(FREEZE_POLL).await; + } + } + } + eyre::bail!("Failed to release rerand freeze after 5 attempts — worker may be stuck frozen"); +} + +/// Acquire `RERAND_APPLY_LOCK` on a detached connection. The lock is held +/// through `load_iris_db` to prevent any concurrent rerand applies (belt +/// and suspenders — the freeze should already have paused the worker). +pub async fn acquire_apply_lock(pool: &PgPool) -> Result> { + let mut conn = pool.acquire().await?; + + // If rerand tables don't exist yet, skip. + match sqlx::query_as::<_, (i64,)>("SELECT COUNT(*) FROM rerand_progress LIMIT 1") + .fetch_one(&mut *conn) + .await + { + Err(e) if is_undefined_table_sqlx(&e) => return Ok(None), + Err(e) => return Err(e.into()), + Ok(_) => {} + } + + sqlx::query("SELECT pg_advisory_lock($1)") + .bind(RERAND_APPLY_LOCK) + .execute(&mut *conn) + .await?; + + Ok(Some(conn.detach())) +} + +/// Release the advisory lock and close the connection. +pub async fn release_apply_lock(lock_conn: Option) -> Result<()> { + if let Some(mut conn) = lock_conn { + let _ = sqlx::query("SELECT pg_advisory_unlock($1)") + .bind(RERAND_APPLY_LOCK) + .execute(&mut conn) + .await; + drop(conn); + tracing::info!("RERAND_APPLY_LOCK released after DB load"); + } + Ok(()) +} + +/// Get the local applied watermark: `(epoch, max_chunk_id)` where +/// `live_applied = TRUE`. Returns `None` pre-migration or if no chunks +/// have been applied. +pub async fn get_applied_watermark_from_pool(pool: &PgPool) -> Result> { + let row: Option<(i32, i32)> = match sqlx::query_as( + "SELECT epoch, chunk_id FROM rerand_progress \ + WHERE live_applied = TRUE \ + ORDER BY epoch DESC, chunk_id DESC \ + LIMIT 1", + ) + .fetch_optional(pool) + .await + { + Ok(row) => row, + Err(e) if is_undefined_table_sqlx(&e) => return Ok(None), + Err(e) => return Err(e.into()), + }; + Ok(row) +} + +async fn fetch_peer_watermark(host: &str, port: usize) -> Result> { + let url = format!("http://{}:{}/rerand-watermark", host, port); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(10)) + .build()?; + let resp = client + .get(&url) + .send() + .await + .map_err(|e| eyre::eyre!("Failed to reach {} for watermark: {}", url, e))?; + if !resp.status().is_success() { + eyre::bail!("Peer {} returned HTTP {} for watermark", url, resp.status()); + } + let body = resp + .text() + .await + .map_err(|e| eyre::eyre!("Failed to read watermark from {}: {}", url, e))?; + + if body.trim() == "null" { + return Ok(None); + } + let v: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| eyre::eyre!("Failed to parse watermark from {}: {}", url, e))?; + Ok(Some(( + v["epoch"] + .as_i64() + .ok_or_else(|| eyre::eyre!("Missing epoch in watermark from {}", url))? as i32, + v["max_applied_chunk"] + .as_i64() + .ok_or_else(|| eyre::eyre!("Missing max_applied_chunk in watermark from {}", url))? + as i32, + ))) +} + +/// Freeze the local rerand worker, then verify all peers report the exact +/// same applied watermark. If this party is behind, release the freeze +/// briefly so the worker can catch up, then re-freeze and re-check. +/// If this party is at the max, stay frozen and wait for peers to catch up. +/// +/// Guarantees: when this returns `Ok(())`, the local worker is frozen and +/// all parties have the same `(epoch, max_applied_chunk)`. +/// On any error, the freeze is released before the error propagates. +pub async fn freeze_and_verify_watermarks(pool: &PgPool, peers: &[(&str, usize)]) -> Result<()> { + if peers.is_empty() { + eyre::bail!("freeze_and_verify_watermarks called with no peers"); + } + + let result = freeze_and_verify_inner(pool, peers).await; + if result.is_err() { + if let Err(release_err) = release_rerand_freeze(pool).await { + tracing::error!( + "Failed to release rerand freeze during error cleanup: {:?}. \ + Worker may be stuck frozen until next successful startup.", + release_err + ); + } + } + result +} + +async fn freeze_and_verify_inner(pool: &PgPool, peers: &[(&str, usize)]) -> Result<()> { + let deadline = tokio::time::Instant::now() + FREEZE_TIMEOUT; + + loop { + let gen = match request_rerand_freeze(pool).await? { + Some(g) => g, + None => return Ok(()), // pre-migration, no rerand tables + }; + + // Bound the ack wait by whatever remains of the outer deadline. The + // helper's own `FREEZE_TIMEOUT` resets per call, so without this the + // total elapsed time across repeated catchup iterations can exceed the + // advertised `FREEZE_TIMEOUT` by a large factor. + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + let _ = release_rerand_freeze(pool).await; + eyre::bail!( + "Rerand freeze+convergence timeout after {:?}. \ + Ensure all rerand workers and main servers are healthy.", + FREEZE_TIMEOUT, + ); + } + match tokio::time::timeout(remaining, wait_for_rerand_frozen(pool, &gen)).await { + Ok(r) => r?, + Err(_) => { + let _ = release_rerand_freeze(pool).await; + eyre::bail!( + "Rerand freeze+convergence timeout after {:?} (ack wait; generation={}). \ + Ensure all rerand workers and main servers are healthy.", + FREEZE_TIMEOUT, + gen, + ); + } + } + + loop { + if tokio::time::Instant::now() >= deadline { + let _ = release_rerand_freeze(pool).await; + eyre::bail!( + "Rerand watermark convergence timeout after {:?}. \ + Ensure all rerand workers and main servers are healthy.", + FREEZE_TIMEOUT, + ); + } + + let local = get_applied_watermark_from_pool(pool).await?; + let mut all_equal = true; + let mut max_wm = local; + + for (host, port) in peers { + let peer = fetch_peer_watermark(host, *port).await?; + if peer != local { + all_equal = false; + } + if watermark_lt(max_wm, peer) { + max_wm = peer; + } + } + + if all_equal { + tracing::info!( + "Rerand watermark equality confirmed across all parties: {:?}", + local + ); + return Ok(()); + } + + if watermark_lt(local, max_wm) { + tracing::info!( + "Local watermark {:?} behind max {:?}, releasing freeze to catch up", + local, + max_wm + ); + release_rerand_freeze(pool).await?; + tokio::time::sleep(Duration::from_secs(5)).await; + break; // outer loop will re-freeze and re-check + } + + tracing::info!( + "Local watermark {:?} at max, waiting for peers to catch up...", + local + ); + tokio::time::sleep(FREEZE_POLL).await; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_staging_schema_name() { + assert_eq!(staging_schema_name("public"), "public_rerand_staging"); + } + + #[test] + fn test_validate_identifier_ok() { + assert!(validate_identifier("public_rerand_staging").is_ok()); + } + + #[test] + fn test_validate_identifier_rejects_injection() { + assert!(validate_identifier("public; DROP TABLE irises").is_err()); + } + + #[test] + fn test_watermark_lt_both_none_is_false() { + assert!(!watermark_lt(None, None)); + } + + #[test] + fn test_watermark_lt_none_lt_some() { + assert!(watermark_lt(None, Some((0, 0)))); + } + + #[test] + fn test_watermark_lt_some_not_lt_none() { + // Regression guard for the epoch-boundary bug: a local `Some((E, k))` + // must not be considered behind a peer reporting `None`, since peers + // reporting `None` are strictly at earlier progress than any applied + // chunk. + assert!(!watermark_lt(Some((0, 0)), None)); + assert!(!watermark_lt(Some((3, 42)), None)); + } + + #[test] + fn test_watermark_lt_some_some_uses_lexicographic() { + assert!(watermark_lt(Some((0, 4)), Some((1, 0)))); + assert!(watermark_lt(Some((1, 0)), Some((1, 1)))); + assert!(!watermark_lt(Some((1, 1)), Some((1, 1)))); + assert!(!watermark_lt(Some((1, 1)), Some((1, 0)))); + assert!(!watermark_lt(Some((1, 0)), Some((0, 99)))); + } +} diff --git a/iris-mpc-upgrade/Cargo.toml b/iris-mpc-upgrade/Cargo.toml index 93eeb19589..e9f4c61b19 100644 --- a/iris-mpc-upgrade/Cargo.toml +++ b/iris-mpc-upgrade/Cargo.toml @@ -12,18 +12,24 @@ ark-bls12-381 = "0.5.0" ark-ff = "0.5.0" ark-ec = "0.5.0" ark-serialize = "0.5.0" +aws-config.workspace = true +aws-sdk-s3.workspace = true +aws-sdk-secretsmanager.workspace = true axum.workspace = true iris-mpc-common = { path = "../iris-mpc-common" } iris-mpc-store = { path = "../iris-mpc-store" } clap = { workspace = true, features = ["env"] } eyre.workspace = true bytemuck.workspace = true +base64.workspace = true serde.workspace = true +serde_json.workspace = true serde-big-array = "0.5" tracing.workspace = true itertools.workspace = true rand.workspace = true rand_chacha = "0.3" +sqlx.workspace = true tokio.workspace = true tracing-subscriber.workspace = true @@ -36,11 +42,17 @@ prost = "0.13.3" sha2.workspace = true thiserror.workspace = true blake3 = "1.8.2" +futures.workspace = true +tokio-util.workspace = true [dev-dependencies] criterion = "0.5" rayon = "1.10.0" +dotenvy.workspace = true + +[features] +db_dependent = [] [build-dependencies] diff --git a/iris-mpc-upgrade/src/config.rs b/iris-mpc-upgrade/src/config.rs index 3444eb3011..49568a2aff 100644 --- a/iris-mpc-upgrade/src/config.rs +++ b/iris-mpc-upgrade/src/config.rs @@ -240,6 +240,7 @@ pub enum ReRandomizeDbSubCommand { KeyGen(KeyGenConfig), RerandomizeDb(ReRandomizeConfig), RerandomizeCheck(ReRandomizeCheckConfig), + RerandomizeContinuous(RerandomizeContinuousConfig), } #[derive(Args)] @@ -341,3 +342,45 @@ pub struct ReRandomizeCheckConfig { #[clap(long, env = "NEW_SCHEMA_NAME_PARTY_2")] pub new_schema_name_party_2: String, } + +#[derive(Args, Debug)] +pub struct RerandomizeContinuousConfig { + #[clap(long, env = "PARTY_ID")] + pub party_id: u8, + + #[clap(long, env = "DB_URL")] + pub db_url: String, + + #[clap(long, env = "ENVIRONMENT")] + pub env: String, + + #[clap(long, env = "SMPC__SERVICE__SERVICE_NAME")] + pub service_name: String, + + #[clap(long, env = "RERAND_S3_BUCKET")] + pub s3_bucket: String, + + #[clap(long, env = "SCHEMA_NAME")] + pub schema_name: String, + + #[clap(long, default_value = "10000", env = "CHUNK_SIZE")] + pub chunk_size: u64, + + #[clap(long, default_value = "5", env = "CHUNK_DELAY_SECS")] + pub chunk_delay_secs: u64, + + #[clap(long, default_value = "0", env = "SAFETY_BUFFER_IDS")] + pub safety_buffer_ids: u64, + + #[clap(long, default_value = "5000", env = "S3_POLL_INTERVAL_MS")] + pub s3_poll_interval_ms: u64, + + #[clap(long, default_value = "3000", env = "HEALTHCHECK_PORT")] + pub healthcheck_port: usize, + + /// Must match the server's `SMPC__RERAND_ENABLED` setting. The worker + /// refuses to start when this is `false` so that an accidentally-deployed + /// worker can't run against a server that doesn't expect it. + #[clap(long, default_value = "false", env = "RERAND_ENABLED")] + pub rerand_enabled: bool, +} diff --git a/iris-mpc-upgrade/src/continuous_rerand.rs b/iris-mpc-upgrade/src/continuous_rerand.rs new file mode 100644 index 0000000000..5e1e381df5 --- /dev/null +++ b/iris-mpc-upgrade/src/continuous_rerand.rs @@ -0,0 +1,466 @@ +use aws_sdk_s3::Client as S3Client; +use aws_sdk_secretsmanager::Client as SecretsManagerClient; +use bytemuck::cast_slice; +use eyre::Result; +use futures::StreamExt; +use iris_mpc_store::rerand::{ + apply_confirmed_chunk, check_and_handle_freeze, delete_rerand_progress_for_old_epochs, + delete_staging_chunk, delete_staging_for_old_epochs, get_current_epoch, + get_max_applied_chunk_for_epoch, get_rerand_progress, get_staging_version_map, + insert_staging_irises, set_all_confirmed, set_staging_written, staging_schema_name, + upsert_rerand_progress, StagingIrisEntry, +}; +use iris_mpc_store::Store; +use sqlx::PgPool; +use std::future::Future; +use std::time::Duration; +use tokio::time::{sleep, timeout, Instant}; +use tokio_util::sync::CancellationToken; + +use crate::config::RerandomizeContinuousConfig; +use crate::epoch; +use crate::rerandomization::randomize_iris; +use crate::s3_coordination::{self, Manifest}; + +const INTERRUPTIBLE_POLL_TIMEOUT: Duration = Duration::from_secs(30 * 60); +const MIN_POLL_SLICE: Duration = Duration::from_secs(2); +const MAX_POLL_SLICE: Duration = Duration::from_secs(30); + +enum PollOutcome { + Completed(T), + Cancelled, +} + +/// Run the continuous rerandomization loop. +/// +/// If `cancel` is provided, the loop checks for cancellation between chunk +/// stages and exits cleanly with `Ok(())` when cancelled. Pass `None` for +/// production use where the loop runs until the process is killed. +pub async fn run_continuous_rerand( + config: &RerandomizeContinuousConfig, + s3: &S3Client, + sm: &SecretsManagerClient, + store: &Store, + cancel: Option<&CancellationToken>, +) -> Result<()> { + if config.chunk_size == 0 { + eyre::bail!("chunk_size must be > 0"); + } + if config.s3_poll_interval_ms == 0 { + eyre::bail!("s3_poll_interval_ms must be > 0"); + } + + if !config.rerand_enabled { + eyre::bail!( + "RERAND_ENABLED is false — continuous rerand worker exiting. \ + Set RERAND_ENABLED=true on the worker (and SMPC__RERAND_ENABLED=true on the server) to enable." + ); + } + + let pool = &store.pool; + let staging_schema = staging_schema_name(&store.schema_name); + let poll_interval = Duration::from_millis(config.s3_poll_interval_ms); + let chunk_delay = Duration::from_secs(config.chunk_delay_secs); + + loop { + if is_cancelled(cancel) { + return Ok(()); + } + + if !check_and_handle_freeze(pool, cancel).await? { + return Ok(()); + } + + let epoch_hint = get_current_epoch(pool).await?.unwrap_or(0) as u32; + let active_epoch = epoch::determine_active_epoch(s3, &config.s3_bucket, epoch_hint).await?; + tracing::info!("Active epoch: {}", active_epoch); + + let shared_secret = epoch::derive_shared_secret( + sm, + s3, + &config.s3_bucket, + &config.env, + &config.service_name, + active_epoch, + config.party_id, + poll_interval, + ) + .await?; + + let manifest = + get_or_create_manifest(s3, store, config, active_epoch, poll_interval).await?; + tracing::info!( + "Epoch {} manifest: chunk_size={}, max_id_inclusive={}", + active_epoch, + manifest.chunk_size, + manifest.max_id_inclusive + ); + + let cleaned = + delete_staging_for_old_epochs(pool, &staging_schema, active_epoch as i32).await?; + if cleaned > 0 { + tracing::info!( + "Epoch {}: cleaned {} orphaned staging rows from prior epochs", + active_epoch, + cleaned + ); + } + let cleaned_progress = + delete_rerand_progress_for_old_epochs(pool, active_epoch as i32).await?; + if cleaned_progress > 0 { + tracing::info!( + "Epoch {}: cleaned {} rerand_progress rows from prior epochs", + active_epoch, + cleaned_progress + ); + } + + let start_chunk_id = get_max_applied_chunk_for_epoch(pool, active_epoch as i32) + .await? + .map(|max_chunk| (max_chunk + 1) as u32) + .unwrap_or(0); + + let mut chunk_id: u32 = start_chunk_id; + loop { + if is_cancelled(cancel) { + return Ok(()); + } + + // Honor startup freeze requests between chunks. + if !check_and_handle_freeze(pool, cancel).await? { + return Ok(()); + } + + if manifest.chunk_is_empty(chunk_id) { + break; + } + + let progress = get_rerand_progress(pool, active_epoch as i32, chunk_id as i32).await?; + upsert_rerand_progress(pool, active_epoch as i32, chunk_id as i32).await?; + + // --- Stage --- + if !progress.as_ref().is_some_and(|p| p.staging_written) { + process_chunk_staging( + pool, + store, + &staging_schema, + &shared_secret, + config.party_id, + active_epoch, + chunk_id, + &manifest, + ) + .await?; + set_staging_written(pool, active_epoch as i32, chunk_id as i32).await?; + } + + // Load the version map once; used for the hash upload below and + // for on-demand full-map upload if hashes diverge across parties. + let version_map = get_staging_version_map( + pool, + &staging_schema, + active_epoch as i32, + chunk_id as i32, + ) + .await?; + + // --- Upload version hash + staged marker (both idempotent) --- + if !progress.as_ref().is_some_and(|p| p.all_confirmed) { + s3_coordination::upload_chunk_version_hash( + s3, + &config.s3_bucket, + active_epoch, + config.party_id, + chunk_id, + &version_map, + ) + .await?; + s3_coordination::upload_chunk_staged( + s3, + &config.s3_bucket, + active_epoch, + config.party_id, + chunk_id, + ) + .await?; + tracing::info!( + "Epoch {} chunk {}: version hash + staged marker uploaded", + active_epoch, + chunk_id + ); + } + + if is_cancelled(cancel) { + return Ok(()); + } + + // --- Wait for all parties to confirm staging --- + if !progress.as_ref().is_some_and(|p| p.all_confirmed) { + match run_interruptible_poll( + pool, + cancel, + poll_interval, + "chunk staged confirmation", + || { + s3_coordination::poll_chunk_staged_all( + s3, + &config.s3_bucket, + active_epoch, + chunk_id, + poll_interval, + ) + }, + ) + .await? + { + PollOutcome::Completed(()) => {} + PollOutcome::Cancelled => return Ok(()), + } + set_all_confirmed(pool, active_epoch as i32, chunk_id as i32).await?; + tracing::info!( + "Epoch {} chunk {}: all parties confirmed", + active_epoch, + chunk_id + ); + } + + if is_cancelled(cancel) { + return Ok(()); + } + + // --- Apply --- + // 1. Compute staging-time cross-party disagreements from version maps. + // This is pure S3 reads — no DB lock held. + let staging_divergent = match run_interruptible_poll( + pool, + cancel, + poll_interval, + "cross-party version-map convergence", + || { + s3_coordination::compute_cross_party_divergent_ids( + s3, + &config.s3_bucket, + active_epoch, + chunk_id, + config.party_id, + &version_map, + poll_interval, + ) + }, + ) + .await? + { + PollOutcome::Completed(ids) => ids, + PollOutcome::Cancelled => return Ok(()), + }; + + // 2. Apply under lock. The function acquires RERAND_MODIFY_LOCK + + // RERAND_APPLY_LOCK, deletes staging_divergent, applies via + // version_id CAS, cleans up staging, and commits. + // No S3 I/O happens while the lock is held. + let rows = apply_confirmed_chunk( + pool, + &staging_schema, + active_epoch as i32, + chunk_id as i32, + &staging_divergent, + ) + .await?; + + tracing::info!( + "Epoch {} chunk {}: applied to live DB ({} rows updated, {} staging-divergent skipped)", + active_epoch, + chunk_id, + rows, + staging_divergent.len(), + ); + + chunk_id += 1; + + if chunk_delay > Duration::ZERO { + sleep(chunk_delay).await; + } + } + + if chunk_id == 0 { + let empty_epoch_sleep = chunk_delay.max(Duration::from_secs(30)); + tracing::info!( + "Epoch {} is empty (max_id_inclusive={}), sleeping {:.0}s to avoid spinning", + active_epoch, + manifest.max_id_inclusive, + empty_epoch_sleep.as_secs_f64(), + ); + sleep(empty_epoch_sleep).await; + } + + epoch::complete_epoch( + sm, + s3, + &config.s3_bucket, + &config.env, + &config.service_name, + active_epoch, + config.party_id, + poll_interval, + ) + .await?; + tracing::info!("Epoch {} completed, moving to next epoch", active_epoch); + + if is_cancelled(cancel) { + return Ok(()); + } + + if chunk_delay > Duration::ZERO { + sleep(chunk_delay).await; + } + } +} + +fn is_cancelled(cancel: Option<&CancellationToken>) -> bool { + cancel.is_some_and(|c| c.is_cancelled()) +} + +fn poll_slice_duration(poll_interval: Duration) -> Duration { + poll_interval + .saturating_add(poll_interval) + .max(MIN_POLL_SLICE) + .min(MAX_POLL_SLICE) +} + +async fn run_interruptible_poll( + pool: &PgPool, + cancel: Option<&CancellationToken>, + poll_interval: Duration, + stage_name: &str, + mut op: F, +) -> Result> +where + F: FnMut() -> Fut, + Fut: Future>, +{ + let deadline = Instant::now() + INTERRUPTIBLE_POLL_TIMEOUT; + let slice = poll_slice_duration(poll_interval); + + loop { + if is_cancelled(cancel) { + return Ok(PollOutcome::Cancelled); + } + if !check_and_handle_freeze(pool, cancel).await? { + return Ok(PollOutcome::Cancelled); + } + if Instant::now() >= deadline { + eyre::bail!( + "Timeout after {:?} while waiting for {}", + INTERRUPTIBLE_POLL_TIMEOUT, + stage_name + ); + } + + match timeout(slice, op()).await { + Ok(result) => return Ok(PollOutcome::Completed(result?)), + Err(_) => { + tracing::debug!("Still waiting for {}; rechecking freeze/cancel", stage_name); + } + } + } +} + +async fn get_or_create_manifest( + s3: &S3Client, + store: &Store, + config: &RerandomizeContinuousConfig, + epoch: u32, + poll_interval: Duration, +) -> Result { + if s3_coordination::manifest_exists(s3, &config.s3_bucket, epoch).await? { + return s3_coordination::download_manifest(s3, &config.s3_bucket, epoch, poll_interval) + .await; + } + + let local_max = store.get_max_serial_id().await? as u64; + s3_coordination::upload_max_id(s3, &config.s3_bucket, epoch, config.party_id, local_max) + .await?; + + if config.party_id == 0 { + let all_max_ids = + s3_coordination::download_all_max_ids(s3, &config.s3_bucket, epoch, poll_interval) + .await?; + let min_max = *all_max_ids.iter().min().unwrap(); + let max_id_inclusive = min_max.saturating_sub(config.safety_buffer_ids); + if max_id_inclusive == 0 { + tracing::warn!( + "Epoch {}: max_id_inclusive is 0 (min_max={}, safety_buffer_ids={}). \ + Epoch will be empty.", + epoch, + min_max, + config.safety_buffer_ids + ); + } + + let manifest = Manifest { + epoch, + chunk_size: config.chunk_size, + max_id_inclusive, + }; + s3_coordination::upload_manifest(s3, &config.s3_bucket, epoch, &manifest).await?; + tracing::info!( + "Epoch {}: manifest created (max_id_inclusive={}, chunk_size={})", + epoch, + max_id_inclusive, + config.chunk_size + ); + Ok(manifest) + } else { + s3_coordination::download_manifest(s3, &config.s3_bucket, epoch, poll_interval).await + } +} + +#[allow(clippy::too_many_arguments)] +async fn process_chunk_staging( + pool: &PgPool, + store: &Store, + staging_schema: &str, + shared_secret: &[u8; 32], + party_id: u8, + epoch: u32, + chunk_id: u32, + manifest: &Manifest, +) -> Result<()> { + delete_staging_chunk(pool, staging_schema, epoch as i32, chunk_id as i32).await?; + + let (start, end) = manifest.chunk_range(chunk_id); + + const BATCH_SIZE: usize = 500; + + let mut stream = store.stream_irises_in_range(start..end); + let mut batch: Vec = Vec::with_capacity(BATCH_SIZE); + + while let Some(iris) = stream.next().await.transpose()? { + let version_id = iris.version_id(); + let iris_id = iris.id(); + let (_, lc, lm, rc, rm) = randomize_iris(iris, shared_secret, party_id as usize); + + batch.push(StagingIrisEntry { + epoch: epoch as i32, + id: iris_id, + chunk_id: chunk_id as i32, + left_code: cast_slice::(&lc.coefs).to_vec(), + left_mask: cast_slice::(&lm.coefs).to_vec(), + right_code: cast_slice::(&rc.coefs).to_vec(), + right_mask: cast_slice::(&rm.coefs).to_vec(), + original_version_id: version_id, + rerand_epoch: (epoch + 1) as i32, + }); + + if batch.len() >= BATCH_SIZE { + insert_staging_irises(pool, staging_schema, &batch).await?; + batch.clear(); + } + } + + if !batch.is_empty() { + insert_staging_irises(pool, staging_schema, &batch).await?; + } + + Ok(()) +} diff --git a/iris-mpc-upgrade/src/epoch.rs b/iris-mpc-upgrade/src/epoch.rs new file mode 100644 index 0000000000..6600c4fa1d --- /dev/null +++ b/iris-mpc-upgrade/src/epoch.rs @@ -0,0 +1,278 @@ +use aws_sdk_s3::Client as S3Client; +use aws_sdk_secretsmanager::Client as SecretsManagerClient; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use eyre::{eyre, Result}; +use std::time::Duration; + +use crate::s3_coordination; +use crate::tripartite_dh; + +fn service_prefix(service_name: &str) -> &str { + service_name + .rsplit_once('-') + .map(|(prefix, _)| prefix) + .unwrap_or(service_name) +} + +fn secret_id(env: &str, service_name: &str, epoch: u32, party_id: u8) -> String { + format!( + "{}/{}/epoch-{}/private-key-party-{}", + env, + service_prefix(service_name), + epoch, + party_id + ) +} + +/// Check if a private key for this epoch already exists in Secrets Manager. +async fn load_private_key_from_sm( + sm: &SecretsManagerClient, + env: &str, + service_name: &str, + epoch: u32, + party_id: u8, +) -> Result> { + let sid = secret_id(env, service_name, epoch, party_id); + match sm + .get_secret_value() + .secret_id(&sid) + .version_stage("AWSCURRENT") + .send() + .await + { + Ok(output) => { + let b64 = output + .secret_string() + .ok_or_else(|| eyre!("Secret {} has no string value", sid))?; + let bytes = STANDARD.decode(b64)?; + let key = tripartite_dh::PrivateKey::deserialize(&bytes) + .map_err(|e| eyre!("Failed to deserialize private key from SM: {:?}", e))?; + Ok(Some(key)) + } + Err(e) => { + let svc = e.into_service_error(); + if svc.is_resource_not_found_exception() { + Ok(None) + } else { + Err(eyre!("SM GetSecretValue failed for {}: {}", sid, svc)) + } + } + } +} + +async fn save_private_key_to_sm( + sm: &SecretsManagerClient, + env: &str, + service_name: &str, + epoch: u32, + party_id: u8, + key: &tripartite_dh::PrivateKey, +) -> Result { + let sid = secret_id(env, service_name, epoch, party_id); + let b64 = STANDARD.encode(key.serialize()); + + match sm + .create_secret() + .name(&sid) + .secret_string(&b64) + .send() + .await + { + Ok(_) => Ok(true), + Err(e) => { + let svc = e.into_service_error(); + if svc.is_resource_exists_exception() { + Ok(false) + } else { + Err(eyre!("SM CreateSecret failed for {}: {}", sid, svc)) + } + } + } +} + +async fn delete_private_key_from_sm( + sm: &SecretsManagerClient, + env: &str, + service_name: &str, + epoch: u32, + party_id: u8, +) -> Result<()> { + let sid = secret_id(env, service_name, epoch, party_id); + sm.delete_secret() + .secret_id(&sid) + .force_delete_without_recovery(true) + .send() + .await + .map_err(|e| eyre!("SM DeleteSecret failed for {}: {}", sid, e))?; + tracing::info!("Deleted epoch {} private key from SM", epoch); + Ok(()) +} + +/// Idempotent key generation for an epoch. +/// +/// 1. Best-effort cleanup of previous epoch's key (covers crash between +/// `poll_epoch_complete_all` and `delete_private_key_from_sm`) +/// 2. Check SM for existing private key +/// 3. If found: load it, derive public key, re-upload to S3 (covers crash between SM write and S3 upload) +/// 4. If not found: generate new keypair, write to SM first, then upload public key to S3 +pub async fn idempotent_keygen( + sm: &SecretsManagerClient, + s3: &S3Client, + bucket: &str, + env: &str, + service_name: &str, + epoch: u32, + party_id: u8, +) -> Result { + if epoch > 0 { + if let Err(e) = delete_private_key_from_sm(sm, env, service_name, epoch - 1, party_id).await + { + tracing::debug!("Cleanup of epoch {} key (best-effort): {}", epoch - 1, e); + } + } + + if let Some(existing) = load_private_key_from_sm(sm, env, service_name, epoch, party_id).await? + { + tracing::info!( + "Epoch {}: private key found in SM, re-uploading public key to S3", + epoch + ); + let public_key = existing.public_key(); + let pk_b64 = STANDARD.encode(public_key.serialize()); + s3_coordination::upload_public_key(s3, bucket, epoch, party_id, &pk_b64).await?; + return Ok(existing); + } + + tracing::info!( + "Epoch {}: generating fresh BLS12-381 keypair for party {}", + epoch, + party_id + ); + let mut rng = rand::rngs::OsRng; + let private_key = tripartite_dh::PrivateKey::random(&mut rng); + + let saved = + save_private_key_to_sm(sm, env, service_name, epoch, party_id, &private_key).await?; + let private_key = if saved { + private_key + } else { + // This branch is hit when two instances of the binary race during a rolling deployment, + // which should not happen. It only exists for defensive purposes. + tracing::warn!( + "Epoch {}: private key already exists in SM (likely concurrent start); reloading it", + epoch + ); + load_private_key_from_sm(sm, env, service_name, epoch, party_id) + .await? + .ok_or_else(|| { + eyre!( + "Secret existed but could not be loaded: {}", + secret_id(env, service_name, epoch, party_id) + ) + })? + }; + + let public_key = private_key.public_key(); + let pk_b64 = STANDARD.encode(public_key.serialize()); + s3_coordination::upload_public_key(s3, bucket, epoch, party_id, &pk_b64).await?; + + Ok(private_key) +} + +/// Derive the shared secret for an epoch: keygen + download peer keys + BLS pairing. +#[allow(clippy::too_many_arguments)] +pub async fn derive_shared_secret( + sm: &SecretsManagerClient, + s3: &S3Client, + bucket: &str, + env: &str, + service_name: &str, + epoch: u32, + party_id: u8, + poll_interval: Duration, +) -> Result<[u8; 32]> { + let private_key = idempotent_keygen(sm, s3, bucket, env, service_name, epoch, party_id).await?; + + let next_id = (party_id + 1) % 3; + let prev_id = (party_id + 2) % 3; + + let pk_next_b64 = + s3_coordination::download_public_key_for_party(s3, bucket, epoch, next_id, poll_interval) + .await?; + let pk_next = + tripartite_dh::PublicKeys::deserialize(&STANDARD.decode(&pk_next_b64)?).map_err(|e| { + eyre!( + "Failed to deserialize public key for party {}: {:?}", + next_id, + e + ) + })?; + + let pk_prev_b64 = + s3_coordination::download_public_key_for_party(s3, bucket, epoch, prev_id, poll_interval) + .await?; + let pk_prev = + tripartite_dh::PublicKeys::deserialize(&STANDARD.decode(&pk_prev_b64)?).map_err(|e| { + eyre!( + "Failed to deserialize public key for party {}: {:?}", + prev_id, + e + ) + })?; + + let shared_secret = private_key.derive_shared_secret(&pk_next, &pk_prev); + let hash = blake3::hash(&shared_secret); + tracing::info!( + "Epoch {}: derived shared secret (blake3 fingerprint: {})", + epoch, + hash.to_hex() + ); + Ok(shared_secret) +} + +/// Determine the active epoch by scanning S3 for the highest epoch with a +/// manifest but without all three `complete` markers. +/// +/// `start_hint` allows callers to skip already-completed epochs (e.g. from +/// `get_current_epoch`). Use `0` when no prior epoch information is available. +pub async fn determine_active_epoch(s3: &S3Client, bucket: &str, start_hint: u32) -> Result { + let mut epoch: u32 = start_hint; + loop { + if !s3_coordination::manifest_exists(s3, bucket, epoch).await? { + break; + } + if s3_coordination::all_parties_complete(s3, bucket, epoch).await? { + epoch += 1; + continue; + } + return Ok(epoch); + } + Ok(epoch) +} + +/// Upload completion marker, poll for all three, then delete the epoch key from SM. +#[allow(clippy::too_many_arguments)] +pub async fn complete_epoch( + sm: &SecretsManagerClient, + s3: &S3Client, + bucket: &str, + env: &str, + service_name: &str, + epoch: u32, + party_id: u8, + poll_interval: Duration, +) -> Result<()> { + s3_coordination::upload_epoch_complete(s3, bucket, epoch, party_id).await?; + tracing::info!( + "Epoch {}: uploaded completion marker for party {}", + epoch, + party_id + ); + + s3_coordination::poll_epoch_complete_all(s3, bucket, epoch, poll_interval).await?; + tracing::info!("Epoch {}: all parties completed", epoch); + + delete_private_key_from_sm(sm, env, service_name, epoch, party_id).await?; + Ok(()) +} diff --git a/iris-mpc-upgrade/src/lib.rs b/iris-mpc-upgrade/src/lib.rs index 2ab2e2018d..7d32241d74 100644 --- a/iris-mpc-upgrade/src/lib.rs +++ b/iris-mpc-upgrade/src/lib.rs @@ -8,10 +8,13 @@ use std::{ }; pub mod config; +pub mod continuous_rerand; +pub mod epoch; pub mod packets; pub mod proto; pub mod rerandomization; pub mod reshare; +pub mod s3_coordination; pub mod tripartite_dh; pub mod utils; diff --git a/iris-mpc-upgrade/src/rerandomization.rs b/iris-mpc-upgrade/src/rerandomization.rs index 3b1e93ea3a..bc8533ff26 100644 --- a/iris-mpc-upgrade/src/rerandomization.rs +++ b/iris-mpc-upgrade/src/rerandomization.rs @@ -1,10 +1,12 @@ use std::io::Read; use iris_mpc_common::{ - galois::degree4::{basis::Monomial, GaloisRingElement}, + galois::degree4::{basis::Monomial, GaloisRingElement, ShamirGaloisRingShare}, galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare}, + id::PartyID, }; use iris_mpc_store::DbStoredIris; +use itertools::Itertools; pub fn randomize_iris( iris: DbStoredIris, @@ -87,6 +89,81 @@ fn randomize_galois_ring_coefs(coefs: &mut [u16], xof: &mut blake3::OutputReader } } +/// Which pair(s) of parties disagree during reconstruction. +#[derive(Debug)] +pub struct ReconstructionMismatch { + pub pairs_01_vs_12: bool, + pub pairs_01_vs_02: bool, +} + +/// Reconstruct the plaintext from 3 Shamir shares using Lagrange interpolation. +/// +/// Returns `Ok(plaintext)` when all 3 pair-wise reconstructions agree, or +/// `Err(mismatch)` indicating which pairs diverge. +pub fn try_reconstruct_shares( + share0: &[u16], + share1: &[u16], + share2: &[u16], +) -> Result, ReconstructionMismatch> { + let lag_01 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID0, PartyID::ID1); + let lag_10 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID1, PartyID::ID0); + let lag_02 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID0, PartyID::ID2); + let lag_20 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID2, PartyID::ID0); + let lag_12 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID1, PartyID::ID2); + let lag_21 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID2, PartyID::ID1); + + assert!(share0.len() == share1.len() && share1.len() == share2.len()); + + let recon01 = share0 + .chunks_exact(4) + .zip_eq(share1.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + let c = a * lag_01 + b * lag_10; + c.coefs + }) + .collect_vec(); + let recon12 = share1 + .chunks_exact(4) + .zip_eq(share2.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + let c = a * lag_12 + b * lag_21; + c.coefs + }) + .collect_vec(); + let recon02 = share0 + .chunks_exact(4) + .zip_eq(share2.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + let c = a * lag_02 + b * lag_20; + c.coefs + }) + .collect_vec(); + + let mismatch_01_12 = recon01 != recon12; + let mismatch_01_02 = recon01 != recon02; + if mismatch_01_12 || mismatch_01_02 { + return Err(ReconstructionMismatch { + pairs_01_vs_12: mismatch_01_12, + pairs_01_vs_02: mismatch_01_02, + }); + } + Ok(recon01) +} + +/// Reconstruct the plaintext from 3 Shamir shares using Lagrange interpolation. +/// Verifies consistency by reconstructing from all 3 pairs (0-1, 1-2, 0-2) and +/// asserting they agree. +pub fn reconstruct_shares(share0: &[u16], share1: &[u16], share2: &[u16]) -> Vec { + try_reconstruct_shares(share0, share1, share2) + .expect("Reconstruction mismatch: shares are inconsistent across party pairs") +} + #[cfg(test)] mod tests { use iris_mpc_common::{ diff --git a/iris-mpc-upgrade/src/s3_coordination.rs b/iris-mpc-upgrade/src/s3_coordination.rs new file mode 100644 index 0000000000..ad3e0de930 --- /dev/null +++ b/iris-mpc-upgrade/src/s3_coordination.rs @@ -0,0 +1,463 @@ +use aws_sdk_s3::error::ProvideErrorMetadata; +use aws_sdk_s3::Client as S3Client; +use eyre::{eyre, Result}; +use futures::future::try_join_all; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tokio::time::{sleep, Instant}; + +const NUM_PARTIES: u8 = 3; +const DEFAULT_POLL_TIMEOUT: Duration = Duration::from_secs(30 * 60); + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Manifest { + pub epoch: u32, + pub chunk_size: u64, + pub max_id_inclusive: u64, +} + +impl Manifest { + /// Returns (start_id_inclusive, end_id_exclusive) for a given chunk_id. + /// IDs are 1-based. + pub fn chunk_range(&self, chunk_id: u32) -> (u64, u64) { + let start = 1 + (chunk_id as u64) * self.chunk_size; + let end = std::cmp::min( + start + self.chunk_size, + self.max_id_inclusive.saturating_add(1), + ); + (start, end) + } + + pub fn chunk_is_empty(&self, chunk_id: u32) -> bool { + let (start, end) = self.chunk_range(chunk_id); + start >= end + } +} + +fn epoch_party_prefix(epoch: u32, party: u8) -> String { + format!("rerand/epoch-{}/party-{}", epoch, party) +} + +pub async fn upload_marker(s3: &S3Client, bucket: &str, key: &str, body: Vec) -> Result<()> { + s3.put_object() + .bucket(bucket) + .key(key) + .body(body.into()) + .send() + .await + .map_err(|e| eyre!("S3 PutObject failed for key {}: {}", key, e))?; + Ok(()) +} + +pub async fn marker_exists(s3: &S3Client, bucket: &str, key: &str) -> Result { + tracing::debug!(bucket = bucket, key = key, "S3 HeadObject request"); + match s3.head_object().bucket(bucket).key(key).send().await { + Ok(_) => { + tracing::debug!(key = key, "S3 HeadObject: exists"); + Ok(true) + } + Err(e) => { + let status = e.raw_response().map(|r| r.status().as_u16()); + let err_display = format!("{e}"); + let debug_display = format!("{e:?}"); + let svc_err = e.into_service_error(); + let code = svc_err.code().unwrap_or(""); + let message = svc_err.message().unwrap_or(""); + tracing::warn!( + bucket = bucket, + key = key, + http_status = ?status, + is_not_found = svc_err.is_not_found(), + error_code = code, + error_message = message, + error_display = %err_display, + error_debug = %debug_display, + "S3 HeadObject error details" + ); + if svc_err.is_not_found() { + Ok(false) + } else { + Err(eyre!( + "S3 HeadObject failed for key {} in bucket {}: status={:?} code={} message={}", + key, + bucket, + status, + code, + message, + )) + } + } + } +} + +pub async fn download_marker(s3: &S3Client, bucket: &str, key: &str) -> Result> { + let resp = s3 + .get_object() + .bucket(bucket) + .key(key) + .send() + .await + .map_err(|e| eyre!("S3 GetObject failed for key {}: {}", key, e))?; + let bytes = resp + .body + .collect() + .await + .map_err(|e| eyre!("Failed to read S3 body for key {}: {}", key, e))?; + Ok(bytes.to_vec()) +} + +pub async fn poll_until_marker_exists( + s3: &S3Client, + bucket: &str, + key: &str, + poll_interval: Duration, +) -> Result<()> { + let deadline = Instant::now() + DEFAULT_POLL_TIMEOUT; + loop { + if marker_exists(s3, bucket, key).await? { + return Ok(()); + } + if Instant::now() > deadline { + eyre::bail!( + "Timeout after {:?} waiting for S3 marker: {}", + DEFAULT_POLL_TIMEOUT, + key + ); + } + tracing::debug!("Waiting for S3 marker: {}", key); + sleep(poll_interval).await; + } +} + +/// Polls until all three parties have uploaded a given marker suffix for an epoch. +pub async fn poll_until_all_parties_marker( + s3: &S3Client, + bucket: &str, + epoch: u32, + marker_suffix: &str, + poll_interval: Duration, +) -> Result<()> { + let deadline = Instant::now() + DEFAULT_POLL_TIMEOUT; + loop { + let mut all_present = true; + for party in 0..NUM_PARTIES { + let key = format!("{}/{}", epoch_party_prefix(epoch, party), marker_suffix); + if !marker_exists(s3, bucket, &key).await? { + all_present = false; + break; + } + } + if all_present { + return Ok(()); + } + if Instant::now() > deadline { + eyre::bail!( + "Timeout after {:?} waiting for all parties' {} markers for epoch {}", + DEFAULT_POLL_TIMEOUT, + marker_suffix, + epoch + ); + } + tracing::debug!( + "Waiting for all parties' {} markers for epoch {}", + marker_suffix, + epoch + ); + sleep(poll_interval).await; + } +} + +// ---- Public key ---- + +pub async fn upload_public_key( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + key_b64: &str, +) -> Result<()> { + let key = format!("{}/public-key", epoch_party_prefix(epoch, party)); + upload_marker(s3, bucket, &key, key_b64.as_bytes().to_vec()).await +} + +pub async fn download_public_key_for_party( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + poll_interval: Duration, +) -> Result { + let key = format!("{}/public-key", epoch_party_prefix(epoch, party)); + poll_until_marker_exists(s3, bucket, &key, poll_interval).await?; + let bytes = download_marker(s3, bucket, &key).await?; + Ok(String::from_utf8(bytes)?) +} + +// ---- Max ID watermark ---- + +pub async fn upload_max_id( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + max_id: u64, +) -> Result<()> { + let key = format!("{}/max-id", epoch_party_prefix(epoch, party)); + upload_marker(s3, bucket, &key, max_id.to_string().into_bytes()).await +} + +pub async fn download_all_max_ids( + s3: &S3Client, + bucket: &str, + epoch: u32, + poll_interval: Duration, +) -> Result<[u64; 3]> { + let keys: Vec = (0..NUM_PARTIES) + .map(|party| format!("{}/max-id", epoch_party_prefix(epoch, party))) + .collect(); + + for key in &keys { + poll_until_marker_exists(s3, bucket, key, poll_interval).await?; + } + + let all_bytes: Vec> = + try_join_all(keys.iter().map(|key| download_marker(s3, bucket, key))).await?; + let mut ids = [0u64; 3]; + for (party, bytes) in all_bytes.into_iter().enumerate() { + let s = String::from_utf8(bytes)?; + ids[party] = s + .trim() + .parse() + .map_err(|e| eyre!("Failed to parse max-id from party {}: {}", party, e))?; + } + Ok(ids) +} + +// ---- Manifest ---- + +pub async fn upload_manifest( + s3: &S3Client, + bucket: &str, + epoch: u32, + manifest: &Manifest, +) -> Result<()> { + let key = format!("{}/manifest.json", epoch_party_prefix(epoch, 0)); + let body = serde_json::to_vec(manifest)?; + upload_marker(s3, bucket, &key, body).await +} + +pub async fn download_manifest( + s3: &S3Client, + bucket: &str, + epoch: u32, + poll_interval: Duration, +) -> Result { + let key = format!("{}/manifest.json", epoch_party_prefix(epoch, 0)); + poll_until_marker_exists(s3, bucket, &key, poll_interval).await?; + let bytes = download_marker(s3, bucket, &key).await?; + let manifest: Manifest = serde_json::from_slice(&bytes)?; + Ok(manifest) +} + +pub async fn manifest_exists(s3: &S3Client, bucket: &str, epoch: u32) -> Result { + let key = format!("{}/manifest.json", epoch_party_prefix(epoch, 0)); + marker_exists(s3, bucket, &key).await +} + +// ---- Chunk staged markers ---- + +pub async fn upload_chunk_staged( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + chunk_id: u32, +) -> Result<()> { + let key = format!( + "{}/chunk-{}/staged", + epoch_party_prefix(epoch, party), + chunk_id + ); + upload_marker(s3, bucket, &key, b"ok".to_vec()).await +} + +pub async fn poll_chunk_staged_all( + s3: &S3Client, + bucket: &str, + epoch: u32, + chunk_id: u32, + poll_interval: Duration, +) -> Result<()> { + let suffix = format!("chunk-{}/staged", chunk_id); + poll_until_all_parties_marker(s3, bucket, epoch, &suffix, poll_interval).await +} + +// ---- Chunk version map (modification fence) ---- + +fn version_map_hash(version_map: &[(i64, i16)]) -> [u8; 32] { + let mut hasher = blake3::Hasher::new(); + for (id, ver) in version_map { + hasher.update(&id.to_le_bytes()); + hasher.update(&ver.to_le_bytes()); + } + *hasher.finalize().as_bytes() +} + +/// Upload only the blake3 hash of a chunk's version map. The full map is +/// deferred and only uploaded when a cross-party hash mismatch is detected +/// (see [`compute_cross_party_divergent_ids`]), avoiding per-chunk S3 storage +/// on the happy path. +pub async fn upload_chunk_version_hash( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + chunk_id: u32, + version_map: &[(i64, i16)], +) -> Result<()> { + let prefix = format!("{}/chunk-{}", epoch_party_prefix(epoch, party), chunk_id); + let hash = version_map_hash(version_map); + upload_marker(s3, bucket, &format!("{prefix}/version-hash"), hash.to_vec()).await +} + +async fn upload_chunk_version_map_body( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + chunk_id: u32, + version_map: &[(i64, i16)], +) -> Result<()> { + let prefix = format!("{}/chunk-{}", epoch_party_prefix(epoch, party), chunk_id); + let body = serde_json::to_vec(version_map)?; + upload_marker(s3, bucket, &format!("{prefix}/version-map"), body).await +} + +async fn download_chunk_version_hash( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + chunk_id: u32, + poll_interval: Duration, +) -> Result<[u8; 32]> { + let key = format!( + "{}/chunk-{}/version-hash", + epoch_party_prefix(epoch, party), + chunk_id + ); + poll_until_marker_exists(s3, bucket, &key, poll_interval).await?; + let bytes = download_marker(s3, bucket, &key).await?; + let hash: [u8; 32] = bytes + .try_into() + .map_err(|b: Vec| eyre!("version-hash has wrong length: {}", b.len()))?; + Ok(hash) +} + +async fn download_chunk_version_map( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, + chunk_id: u32, + poll_interval: Duration, +) -> Result> { + let key = format!( + "{}/chunk-{}/version-map", + epoch_party_prefix(epoch, party), + chunk_id + ); + poll_until_marker_exists(s3, bucket, &key, poll_interval).await?; + let bytes = download_marker(s3, bucket, &key).await?; + let map: Vec<(i64, i16)> = serde_json::from_slice(&bytes)?; + Ok(map) +} + +/// Compare version maps across all 3 parties and return IDs where any +/// party disagrees on the `original_version_id`. +/// +/// Fast path: download only the 32-byte blake3 hashes concurrently. If all +/// match, return empty (no disagreements). Slow path (hash mismatch): +/// upload this party's full version map, then download all maps concurrently +/// and compute the exact disagreement set. All three parties independently +/// detect the mismatch and upload, so polling converges without extra +/// signaling. +pub async fn compute_cross_party_divergent_ids( + s3: &S3Client, + bucket: &str, + epoch: u32, + chunk_id: u32, + party: u8, + version_map: &[(i64, i16)], + poll_interval: Duration, +) -> Result> { + let hashes: Vec<[u8; 32]> = try_join_all((0..NUM_PARTIES).map(|party| { + download_chunk_version_hash(s3, bucket, epoch, party, chunk_id, poll_interval) + })) + .await?; + + if hashes[0] == hashes[1] && hashes[1] == hashes[2] { + return Ok(Vec::new()); + } + + tracing::info!( + "Epoch {} chunk {}: version-map hashes differ, uploading full map and downloading peers", + epoch, + chunk_id, + ); + + upload_chunk_version_map_body(s3, bucket, epoch, party, chunk_id, version_map).await?; + + use std::collections::HashMap; + let all_maps: Vec> = try_join_all((0..NUM_PARTIES).map(|party| { + download_chunk_version_map(s3, bucket, epoch, party, chunk_id, poll_interval) + })) + .await? + .into_iter() + .map(|v| v.into_iter().collect::>()) + .collect(); + + let mut divergent = Vec::new(); + let all_ids: std::collections::BTreeSet = + all_maps.iter().flat_map(|m| m.keys().copied()).collect(); + + for id in all_ids { + let versions: Vec> = all_maps.iter().map(|m| m.get(&id)).collect(); + let first = versions[0]; + if !versions.iter().all(|v| *v == first) { + divergent.push(id); + } + } + Ok(divergent) +} + +// ---- Epoch completion ---- + +pub async fn upload_epoch_complete( + s3: &S3Client, + bucket: &str, + epoch: u32, + party: u8, +) -> Result<()> { + let key = format!("{}/complete", epoch_party_prefix(epoch, party)); + upload_marker(s3, bucket, &key, b"done".to_vec()).await +} + +pub async fn all_parties_complete(s3: &S3Client, bucket: &str, epoch: u32) -> Result { + for party in 0..NUM_PARTIES { + let key = format!("{}/complete", epoch_party_prefix(epoch, party)); + if !marker_exists(s3, bucket, &key).await? { + return Ok(false); + } + } + Ok(true) +} + +pub async fn poll_epoch_complete_all( + s3: &S3Client, + bucket: &str, + epoch: u32, + poll_interval: Duration, +) -> Result<()> { + poll_until_all_parties_marker(s3, bucket, epoch, "complete", poll_interval).await +} diff --git a/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs b/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs new file mode 100644 index 0000000000..626f319538 --- /dev/null +++ b/iris-mpc-upgrade/tests/continuous_rerand_e2e.rs @@ -0,0 +1,746 @@ +#![cfg(feature = "db_dependent")] + +mod test_utils; + +use eyre::Result; +use iris_mpc_store::rerand as rerand_store; +use serde_json::json; +use std::sync::Mutex; +use test_utils::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; + +const STACK_SIZE: usize = 16 * 1024 * 1024; + +/// Tests share 3 Postgres instances and a global advisory lock constant, so +/// they must run sequentially. This mutex enforces that even without +/// `--test-threads=1`. +static SERIAL: Mutex<()> = Mutex::new(()); + +fn run_async(f: impl std::future::Future> + Send + 'static) { + let _guard = SERIAL.lock().unwrap_or_else(|e| e.into_inner()); + let result = std::thread::Builder::new() + .stack_size(STACK_SIZE) + .name("e2e".into()) + .spawn(move || { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .thread_stack_size(STACK_SIZE) + .enable_all() + .build() + .unwrap() + .block_on(f) + }) + .unwrap() + .join() + .unwrap(); + result.unwrap(); +} + +async fn set_live_applied_chunk(pool: &sqlx::PgPool, epoch: i32, max_chunk: i32) -> Result<()> { + for chunk in 0..=max_chunk { + rerand_store::upsert_rerand_progress(pool, epoch, chunk).await?; + sqlx::query( + "UPDATE rerand_progress SET live_applied = TRUE WHERE epoch = $1 AND chunk_id = $2", + ) + .bind(epoch) + .bind(chunk) + .execute(pool) + .await?; + } + Ok(()) +} + +fn spawn_checking_worker(pool: sqlx::PgPool) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + while let Ok(true) = rerand_store::check_and_handle_freeze(&pool, None).await {} + }) +} + +async fn simulate_server_startup_with_freeze( + pool: &sqlx::PgPool, + peer_addrs: &[(&str, usize)], +) -> Result<()> { + rerand_store::freeze_and_verify_watermarks(pool, peer_addrs).await?; + + // Mimic startup DB load behind apply lock. + let startup_lock = rerand_store::acquire_apply_lock(pool).await?; + let _: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM irises") + .fetch_one(pool) + .await?; + rerand_store::release_apply_lock(startup_lock).await?; + + Ok(()) +} + +async fn start_peer_watermark_server( + pool: &sqlx::PgPool, +) -> Result<(usize, tokio::task::JoinHandle<()>)> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let port = listener.local_addr()?.port() as usize; + let pool = pool.clone(); + let handle = tokio::spawn(async move { + loop { + let (mut socket, _) = match listener.accept().await { + Ok(value) => value, + Err(_) => return, + }; + + let pool = pool.clone(); + tokio::spawn(async move { + let mut buf = vec![0u8; 2048]; + let _ = socket.read(&mut buf).await; + + let wm = match rerand_store::get_applied_watermark_from_pool(&pool).await { + Ok(Some((epoch, chunk_id))) => json!({ + "epoch": epoch, + "max_applied_chunk": chunk_id, + }) + .to_string(), + Ok(None) => "null".to_string(), + Err(e) => { + let body = format!("{{\"error\":\"{}\"}}", e); + let response = format!( + "HTTP/1.1 500 Internal Server Error\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + body.len(), + body + ); + let _ = socket.write_all(response.as_bytes()).await; + return; + } + }; + + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + wm.len(), + wm + ); + let _ = socket.write_all(response.as_bytes()).await; + }); + } + }); + + Ok((port, handle)) +} + +// ============================================================================ +// Phase 1: Clean epoch -- run one full epoch, verify crypto correctness +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase1_clean_epoch() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 1] Clean epoch..."); + + let all_ids: Vec = (1..=DB_SIZE as i64).collect(); + let pre_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + + let (h, t) = env.spawn_all(); + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness, &[]).await?; + assert!(ep >= 1, "Expected rerand_epoch >= 1, got {}", ep); + assert_rerand_epoch_equals_for_ids(&env.harness, &all_ids, 1).await?; + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + + let post_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + for &id in &all_ids { + assert_ne!( + &pre_shares[&id], &post_shares[&id], + "Shares for id={} should differ after rerandomization", + id + ); + } + println!("[phase 1] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 2: Kill-and-resume -- kill mid-epoch, restart, verify recovery +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase2_kill_and_resume() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 2] Kill-and-resume..."); + + let all_ids: Vec = (1..=DB_SIZE as i64).collect(); + let pre_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + + // Run epoch 0, let 2 chunks stage, then kill + let (h, t) = env.spawn_all(); + wait_chunks_staged(&env.harness, 0, 2).await?; + println!("[phase 2] killing after 2 chunks staged"); + stop_all(t, h).await; + + // Restart -- should resume from where it left off + println!("[phase 2] restarting..."); + let (h, t) = env.spawn_all(); + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness, &[]).await?; + assert!(ep >= 1); + assert_rerand_epoch_equals_for_ids(&env.harness, &all_ids, 1).await?; + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + + let post_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + for &id in &all_ids { + assert_ne!( + &pre_shares[&id], &post_shares[&id], + "Shares for id={} should differ after rerandomization", + id + ); + } + println!("[phase 2] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 3: Concurrent modifications -- bump version_id mid-epoch, verify +// optimistic lock skips those rows +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase3_concurrent_modifications() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + let modified_ids: Vec = vec![5, 10, 15]; + let all_ids: Vec = (1..=DB_SIZE as i64).collect(); + let non_modified_ids: Vec = all_ids + .iter() + .copied() + .filter(|id| !modified_ids.contains(id)) + .collect(); + let pre_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + println!("[phase 3] Concurrent modifications..."); + + let (h, t) = env.spawn_all(); + wait_chunks_staged(&env.harness, 0, 1).await?; + + // Bump version_id on a few rows (simulates a reauth) + for &id in &modified_ids { + for party in &env.harness.parties { + let (before,): (i16,) = + sqlx::query_as("SELECT version_id FROM irises WHERE id = $1") + .bind(id) + .fetch_one(&party.store.pool) + .await?; + + // Flip one byte so the `increment_version_id` trigger fires. + sqlx::query( + r#" + UPDATE irises + SET left_code = set_byte(left_code, 0, get_byte(left_code, 0) # 1) + WHERE id = $1 + "#, + ) + .bind(id) + .execute(&party.store.pool) + .await?; + + let (after,): (i16,) = + sqlx::query_as("SELECT version_id FROM irises WHERE id = $1") + .bind(id) + .fetch_one(&party.store.pool) + .await?; + eyre::ensure!( + after > before, + "Expected version_id to increase for id={id}" + ); + } + } + println!("[phase 3] bumped version_id on {:?}", modified_ids); + + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness, &modified_ids).await?; + assert!(ep >= 1); + // Modified IDs can be 0 or 1 depending on whether each party staged the + // chunk before or after the local version bump. We skip them in strict + // cross-party checks and focus on non-modified IDs + fingerprint safety. + assert_rerand_epoch_equals_for_ids(&env.harness, &non_modified_ids, 1).await?; + verify_fingerprints(&env.harness, &env.fingerprints, &modified_ids).await?; + + let post_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + for &id in &non_modified_ids { + assert_ne!( + &pre_shares[&id], &post_shares[&id], + "Shares for id={} should differ after rerandomization", + id + ); + } + println!("[phase 3] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 4: Server restart during rerand -- simulate main server startup while +// rerand is running, verify advisory lock serializes access +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase4_server_restart_during_rerand() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 4] Server restart during rerand..."); + + let all_ids: Vec = (1..=DB_SIZE as i64).collect(); + let pre_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + + let (h, t) = env.spawn_all(); + wait_chunks_staged(&env.harness, 0, 1).await?; + + for p in 0..NUM_PARTIES { + let r = simulate_server_startup(&env.harness, p).await; + println!("[phase 4] party {} server startup: {:?}", p, r.is_ok()); + } + + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness, &[]).await?; + assert!(ep >= 1); + assert_rerand_epoch_equals_for_ids(&env.harness, &all_ids, 1).await?; + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + + let post_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + for &id in &all_ids { + assert_ne!( + &pre_shares[&id], &post_shares[&id], + "Shares for id={} should differ after rerandomization", + id + ); + } + println!("[phase 4] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 6: Multiple Epochs -- let the system run continuously across multiple +// epochs, verify seamless transition and correct rerandomization +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase6_multiple_epochs() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 6] Multiple epochs..."); + + let all_ids: Vec = (1..=DB_SIZE as i64).collect(); + let pre_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + + let (h, t) = env.spawn_all(); + + // Wait for epoch 0 to finish + wait_epoch_done(&env.harness, 0).await?; + println!("[phase 6] epoch 0 completed"); + + // The continuous rerand servers should automatically move to epoch 1 + wait_epoch_done(&env.harness, 1).await?; + println!("[phase 6] epoch 1 completed"); + + // The inter-epoch delay (chunk_delay) gives the cancel token time to + // be observed before the next epoch's work begins. + stop_all(t, h).await; + + let ep = assert_consistent_rerand_epoch(&env.harness, &[]).await?; + assert!(ep >= 2, "Expected rerand_epoch >= 2, got {}", ep); + assert_rerand_epoch_at_least_for_ids(&env.harness, &all_ids, 2).await?; + verify_fingerprints(&env.harness, &env.fingerprints, &[]).await?; + + let post_shares = snapshot_raw_shares(&env.harness, &all_ids).await?; + for &id in &all_ids { + assert_ne!( + &pre_shares[&id], &post_shares[&id], + "Shares for id={} should differ after rerandomization", + id + ); + } + println!("[phase 6] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 7: Startup validation rejects fatal desync and accepts in-sync state. +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase7_startup_validation() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 7] Startup validation..."); + + // Fatal desync (gap > 1) → immediate bail + for p in 0..NUM_PARTIES { + let pool = &env.harness.parties[p].store.pool; + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (0, 0, TRUE, TRUE, TRUE)") + .execute(pool).await.unwrap(); + } + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (2, 0, TRUE, TRUE, TRUE)") + .execute(&env.harness.parties[0].store.pool).await.unwrap(); + + let r_fatal = simulate_server_startup_with_rerand_validation(&env.harness, 1).await; + assert!(r_fatal.is_err(), "Fatal epoch gap should bail immediately"); + + // In-sync → startup succeeds immediately + for p in 0..NUM_PARTIES { + let pool = &env.harness.parties[p].store.pool; + sqlx::query("DELETE FROM rerand_progress") + .execute(pool) + .await + .unwrap(); + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (0, 0, TRUE, TRUE, TRUE)") + .execute(pool).await.unwrap(); + } + + let r_ok = simulate_server_startup_with_rerand_validation(&env.harness, 0).await; + assert!(r_ok.is_ok(), "In-sync startup should succeed"); + + println!("[phase 7] PASSED"); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 8: Disallow loading mismatched peers +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase8_reject_desync() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 8] Reject desync..."); + + // Setup the exact boundary desync state in DB manually + // P1 is on Epoch 0 (has max epoch 0) + // P0 and P2 are on Epoch 2 (have max epoch 2) + // If a peer is *more than 1 epoch ahead*, we should panic/reject + for p in 0..NUM_PARTIES { + let pool = &env.harness.parties[p].store.pool; + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (0, 0, TRUE, TRUE, TRUE)") + .execute(pool).await.unwrap(); + } + + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (2, 0, TRUE, TRUE, FALSE)") + .execute(&env.harness.parties[0].store.pool).await.unwrap(); + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (2, 0, TRUE, TRUE, FALSE)") +.execute(&env.harness.parties[2].store.pool).await.unwrap(); + + let r1 = simulate_server_startup_with_rerand_validation(&env.harness, 1).await; + assert!( + r1.is_err(), + "P1 startup should have failed due to large epoch gap" + ); + + // Now test the new chunk gap logic + // P1 has chunk 0 confirmed, P0 has chunk 2 confirmed (gap > 1) in the same epoch + for p in 0..NUM_PARTIES { + let pool = &env.harness.parties[p].store.pool; + sqlx::query("DELETE FROM rerand_progress") + .execute(pool) + .await + .unwrap(); + } + + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (3, 0, TRUE, TRUE, TRUE)") +.execute(&env.harness.parties[1].store.pool).await.unwrap(); + + sqlx::query("INSERT INTO rerand_progress (epoch, chunk_id, staging_written, all_confirmed, live_applied) VALUES (3, 2, TRUE, TRUE, FALSE)") +.execute(&env.harness.parties[0].store.pool).await.unwrap(); + + let r1_chunk_desync = simulate_server_startup_with_rerand_validation(&env.harness, 1).await; + assert!( + r1_chunk_desync.is_err(), + "P1 startup should have failed due to large chunk gap" + ); + + println!("[phase 8] PASSED"); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 9: Asymmetric modification — a modification landing on only one +// party's DB must NOT cause cross-party share divergence. +// The modification fence (version-map exchange + skip-set union) +// detects the asymmetry and excludes the affected row. +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase9_asymmetric_modification_consistency() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + let target_id: i64 = 20; + let non_target_ids: Vec = (1..=DB_SIZE as i64).filter(|id| *id != target_id).collect(); + let pre_shares = snapshot_raw_shares(&env.harness, &[target_id]).await?; + println!("[phase 9] Asymmetric modification consistency..."); + + // Modify iris on P0 ONLY — simulates a reauth that propagated to + // P0 via SQS but hasn't reached P1/P2 yet. + sqlx::query( + r#" + UPDATE irises + SET left_code = set_byte(left_code, 0, get_byte(left_code, 0) # 1) + WHERE id = $1 + "#, + ) + .bind(target_id) + .execute(&env.harness.parties[0].store.pool) + .await?; + println!("[phase 9] modified id={} on P0 only", target_id); + + // Run a full epoch across all 3 parties. + let (h, t) = env.spawn_all(); + wait_epoch_done(&env.harness, 0).await?; + stop_all(t, h).await; + + // The modification fence should have detected the asymmetric + // version_id and excluded id=20 from the apply on all parties. + // Non-modified rows should still be rerandomized consistently. + let ep = assert_consistent_rerand_epoch(&env.harness, &[target_id]).await?; + assert!(ep >= 1); + assert_rerand_epoch_equals_for_ids(&env.harness, &non_target_ids, 1).await?; + + // The modified ID should have been skipped (rerand_epoch stays 0) + // on ALL parties, OR applied consistently. Either way shares must + // reconstruct. + let epochs = get_rerand_epochs_for_id(&env.harness, target_id).await?; + let epochs_consistent = epochs[0] == epochs[1] && epochs[1] == epochs[2]; + println!( + "[phase 9] rerand_epochs for id={}: {:?} (consistent={})", + target_id, epochs, epochs_consistent + ); + assert!( + epochs_consistent, + "rerand_epoch diverged for id={}: {:?}", + target_id, epochs + ); + + // The excluded iris must have rerand_epoch == 0 (skipped everywhere). + assert_eq!( + epochs[0], 0, + "Excluded iris id={} should have rerand_epoch=0, got {}", + target_id, epochs[0] + ); + + // The iris was skipped by rerand, so P1 and P2 (unmodified) should + // have identical shares to before the test. P0 was byte-flipped. + let post_shares = snapshot_raw_shares(&env.harness, &[target_id]).await?; + for party in 1..NUM_PARTIES { + assert_eq!( + &pre_shares[&target_id][party], &post_shares[&target_id][party], + "P{} shares for excluded id={} should be unchanged", + party, target_id + ); + } + assert_ne!( + &pre_shares[&target_id][0], &post_shares[&target_id][0], + "P0 shares for id={} should differ (byte-flip modification)", + target_id + ); + + println!( + "[phase 9] excluded iris id={} correctly skipped (rerand_epoch=0, P1/P2 unchanged)", + target_id + ); + + // Verify non-modified rows reconstruct correctly. + verify_fingerprints(&env.harness, &env.fingerprints, &[target_id]).await?; + + println!("[phase 9] PASSED (epoch={})", ep); + + env.teardown().await + }); +} + +// ============================================================================ +// Phase 10: Startup freeze catchup path — local party is behind peers and +// advances while freeze is released and re-acquired. +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase10_startup_freeze_local_catchup() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 10] Startup freeze catchup..."); + + let p0_pool = &env.harness.parties[0].store.pool; + let p1_pool = &env.harness.parties[1].store.pool; + let p2_pool = &env.harness.parties[2].store.pool; + + // Local is behind peers in this epoch. + set_live_applied_chunk(p0_pool, 0, 0).await?; + set_live_applied_chunk(p1_pool, 0, 4).await?; + set_live_applied_chunk(p2_pool, 0, 4).await?; + + let (p1_port, p1_server) = start_peer_watermark_server(p1_pool).await?; + let (p2_port, p2_server) = start_peer_watermark_server(p2_pool).await?; + let worker = spawn_checking_worker(p0_pool.clone()); + + // Simulate a main-server startup sequence where this party releases freeze + // so catchup can happen, then re-enters freeze logic. + let catchup = tokio::spawn({ + let p0_pool = p0_pool.clone(); + async move { + loop { + let (freeze_requested,): (bool,) = + sqlx::query_as("SELECT freeze_requested FROM rerand_control WHERE id = 1") + .fetch_one(&p0_pool) + .await?; + if freeze_requested { + set_live_applied_chunk(&p0_pool, 0, 4).await?; + return Ok::<_, eyre::Report>(()); + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + } + }); + + let startup = tokio::time::timeout( + std::time::Duration::from_secs(25), + simulate_server_startup_with_freeze( + p0_pool, + &[("127.0.0.1", p1_port), ("127.0.0.1", p2_port)], + ), + ) + .await; + assert!(startup.is_ok(), "startup freeze converge timed out"); + startup.unwrap()?; + + assert_eq!( + rerand_store::get_applied_watermark_from_pool(p0_pool).await?, + Some((0, 4)) + ); + rerand_store::release_rerand_freeze(p0_pool).await?; + catchup.await?.unwrap(); + + let control = sqlx::query_as::<_, (bool, Option)>( + "SELECT freeze_requested, freeze_generation FROM rerand_control WHERE id = 1", + ) + .fetch_one(p0_pool) + .await?; + assert!( + !control.0, + "freeze should be released after startup converge" + ); + assert!( + control.1.is_none(), + "stale freeze generation should be cleared" + ); + + worker.abort(); + p1_server.abort(); + p2_server.abort(); + env.teardown().await + }); +} + +// ============================================================================ +// Phase 11: Startup freeze wait path — local party is at max and peers catch up. +// ============================================================================ + +#[test] +#[ignore = "Requires 3 local Postgres instances (6200-6202) and localstack; run via run-rerand-e2e-tests.sh"] +fn phase11_startup_freeze_waits_for_peers() { + run_async(async { + let _ = tracing_subscriber::fmt::try_init(); + let env = TestEnv::setup().await?; + println!("[phase 11] Startup freeze peer catchup..."); + + let p0_pool = &env.harness.parties[0].store.pool; + let p1_pool = &env.harness.parties[1].store.pool; + let p2_pool = &env.harness.parties[2].store.pool; + + // Local is fully caught up initially; peers lag at chunk 0. + set_live_applied_chunk(p0_pool, 0, 4).await?; + set_live_applied_chunk(p1_pool, 0, 0).await?; + set_live_applied_chunk(p2_pool, 0, 0).await?; + + let (p1_port, p1_server) = start_peer_watermark_server(p1_pool).await?; + let (p2_port, p2_server) = start_peer_watermark_server(p2_pool).await?; + let worker = spawn_checking_worker(p0_pool.clone()); + + let advance_peers = tokio::spawn({ + let p1_pool = p1_pool.clone(); + let p2_pool = p2_pool.clone(); + async move { + tokio::time::sleep(std::time::Duration::from_millis(300)).await; + set_live_applied_chunk(&p1_pool, 0, 4).await?; + set_live_applied_chunk(&p2_pool, 0, 4).await?; + Result::<(), eyre::Report>::Ok(()) + } + }); + + let startup = tokio::time::timeout( + std::time::Duration::from_secs(25), + simulate_server_startup_with_freeze( + p0_pool, + &[("127.0.0.1", p1_port), ("127.0.0.1", p2_port)], + ), + ) + .await; + assert!(startup.is_ok(), "startup freeze converge timed out"); + startup.unwrap()?; + + assert_eq!( + rerand_store::get_applied_watermark_from_pool(p0_pool).await?, + Some((0, 4)) + ); + rerand_store::release_rerand_freeze(p0_pool).await?; + advance_peers.await??; + + let control = sqlx::query_as::<_, (bool, Option)>( + "SELECT freeze_requested, freeze_generation FROM rerand_control WHERE id = 1", + ) + .fetch_one(p0_pool) + .await?; + assert!( + !control.0, + "freeze should be released after startup converge" + ); + assert!( + control.1.is_none(), + "stale freeze generation should be cleared" + ); + + worker.abort(); + p1_server.abort(); + p2_server.abort(); + env.teardown().await + }); +} diff --git a/iris-mpc-upgrade/tests/test_utils.rs b/iris-mpc-upgrade/tests/test_utils.rs new file mode 100644 index 0000000000..dbb94b7ad3 --- /dev/null +++ b/iris-mpc-upgrade/tests/test_utils.rs @@ -0,0 +1,712 @@ +#![allow(dead_code)] + +use eyre::Result; +use iris_mpc_common::{ + galois_engine::degree4::FullGaloisRingIrisCodeShare, + iris_db::iris::IrisCode, + postgres::{AccessMode, PostgresClient}, +}; +use iris_mpc_store::rerand as rerand_store; +use iris_mpc_store::{Store, StoredIrisRef}; +use iris_mpc_upgrade::config::RerandomizeContinuousConfig; +use iris_mpc_upgrade::continuous_rerand::run_continuous_rerand; +use iris_mpc_upgrade::rerandomization::reconstruct_shares; +use rand::{rngs::StdRng, SeedableRng}; +use std::collections::HashMap; +use std::time::Duration; +use tokio_util::sync::CancellationToken; + +pub const NUM_PARTIES: usize = 3; +pub const DB_SIZE: usize = 50; +pub const CHUNK_SIZE: u64 = 25; + +fn db_urls() -> Vec { + (0..3) + .map(|i| format!("postgres://postgres:postgres@localhost:{}", 6200 + i)) + .collect() +} + +pub struct PartyDb { + pub store: Store, + pub schema_name: String, +} + +pub struct TestHarness { + pub parties: Vec, +} + +impl TestHarness { + pub async fn new(db_urls: &[&str], schema_prefix: &str) -> Result { + let mut parties = Vec::new(); + for (i, url) in db_urls.iter().enumerate() { + let schema = format!("{}_{}", schema_prefix, i); + let pg = PostgresClient::new(url, &schema, AccessMode::ReadWrite).await?; + let store = Store::new(&pg).await?; + parties.push(PartyDb { + store, + schema_name: schema, + }); + } + Ok(Self { parties }) + } + + pub fn store(&self, party: usize) -> &Store { + &self.parties[party].store + } +} + +/// Full test environment: harness + AWS clients + unique prefix + unique S3 bucket. +pub struct TestEnv { + pub harness: TestHarness, + pub s3: aws_sdk_s3::Client, + pub sm: aws_sdk_secretsmanager::Client, + pub prefix: String, + pub bucket: String, + pub fingerprints: PlaintextFingerprints, +} + +impl TestEnv { + pub async fn setup() -> Result { + let id = rand::random::(); + let prefix = format!("SMPC_e2e_{}", id); + let bucket = format!("rerand-e2e-{}", id); + let urls = db_urls(); + let url_refs: Vec<&str> = urls.iter().map(|s| s.as_str()).collect(); + let harness = TestHarness::new(&url_refs, &prefix).await?; + + let sdk = aws_config::from_env().load().await; + let s3 = aws_sdk_s3::Client::new(&sdk); + let sm = aws_sdk_secretsmanager::Client::new(&sdk); + + s3.create_bucket() + .bucket(&bucket) + .send() + .await + .map_err(|e| eyre::eyre!("Failed to create bucket {}: {}", bucket, e))?; + + println!( + " [setup] Seeding {} irises (prefix={}, bucket={})", + DB_SIZE, prefix, bucket + ); + seed_three_party_db(&harness, DB_SIZE).await?; + let fingerprints = snapshot_all_fingerprints(&harness, &[]).await?; + + Ok(Self { + harness, + s3, + sm, + prefix, + bucket, + fingerprints, + }) + } + + pub async fn teardown(&self) -> Result<()> { + cleanup(&self.harness).await?; + // Delete all objects in the bucket then delete the bucket + let mut token = None; + loop { + let mut req = self.s3.list_objects_v2().bucket(&self.bucket); + if let Some(t) = &token { + req = req.continuation_token(t); + } + let resp = req.send().await?; + for obj in resp.contents() { + if let Some(key) = obj.key() { + self.s3 + .delete_object() + .bucket(&self.bucket) + .key(key) + .send() + .await?; + } + } + if resp.is_truncated() == Some(true) { + token = resp.next_continuation_token().map(|s| s.to_string()); + } else { + break; + } + } + let _ = self.s3.delete_bucket().bucket(&self.bucket).send().await; + Ok(()) + } + + pub fn make_config(&self, party_id: u8) -> RerandomizeContinuousConfig { + RerandomizeContinuousConfig { + party_id, + db_url: format!( + "postgres://postgres:postgres@localhost:{}", + 6200 + party_id as u16 + ), + env: "testing".to_string(), + service_name: format!("iris-mpc-{}", party_id), + s3_bucket: self.bucket.clone(), + schema_name: format!("{}_{}", self.prefix, party_id), + chunk_size: CHUNK_SIZE, + chunk_delay_secs: 1, + safety_buffer_ids: 0, + s3_poll_interval_ms: 200, + healthcheck_port: 3020 + party_id as usize, + rerand_enabled: true, + } + } + + pub fn spawn_rerand( + &self, + party_id: u8, + ) -> (tokio::task::JoinHandle>, CancellationToken) { + let config = self.make_config(party_id); + let s3 = self.s3.clone(); + let sm = self.sm.clone(); + let store = self.harness.store(party_id as usize).clone(); + let token = CancellationToken::new(); + let tc = token.clone(); + let h = tokio::spawn(async move { + run_continuous_rerand(&config, &s3, &sm, &store, Some(&tc)).await + }); + (h, token) + } + + pub fn spawn_all( + &self, + ) -> ( + Vec>>, + Vec, + ) { + let mut handles = Vec::new(); + let mut tokens = Vec::new(); + for p in 0u8..3 { + let (h, t) = self.spawn_rerand(p); + handles.push(h); + tokens.push(t); + } + (handles, tokens) + } +} + +pub async fn stop_all( + tokens: Vec, + handles: Vec>>, +) { + for t in &tokens { + t.cancel(); + } + for h in &handles { + h.abort(); + } + for h in handles { + let _ = h.await; + } +} + +// ---- DB seeding ---- + +pub async fn seed_three_party_db(harness: &TestHarness, count: usize) -> Result<()> { + let mut rng = StdRng::seed_from_u64(42); + for chunk_start in (1..=count).step_by(100) { + let chunk_end = std::cmp::min(chunk_start + 100, count + 1); + + struct S { + id: i64, + lc: Vec, + lm: Vec, + rc: Vec, + rm: Vec, + } + + let mut party_data: Vec> = (0..NUM_PARTIES).map(|_| Vec::new()).collect(); + for serial_id in chunk_start..chunk_end { + let il = IrisCode::random_rng(&mut rng); + let ir = IrisCode::random_rng(&mut rng); + let [l0, l1, l2] = FullGaloisRingIrisCodeShare::encode_iris_code(&il, &mut rng); + let [r0, r1, r2] = FullGaloisRingIrisCodeShare::encode_iris_code(&ir, &mut rng); + for (pi, (left, right)) in [(l0, r0), (l1, r1), (l2, r2)].into_iter().enumerate() { + party_data[pi].push(S { + id: serial_id as i64, + lc: left.code.coefs.to_vec(), + lm: left.mask.coefs.to_vec(), + rc: right.code.coefs.to_vec(), + rm: right.mask.coefs.to_vec(), + }); + } + } + for (pi, shares) in party_data.iter().enumerate() { + let refs: Vec = shares + .iter() + .map(|s| StoredIrisRef { + id: s.id, + left_code: &s.lc, + left_mask: &s.lm, + right_code: &s.rc, + right_mask: &s.rm, + }) + .collect(); + let store = harness.store(pi); + let mut tx = store.tx().await?; + store.insert_irises_overriding(&mut tx, &refs).await?; + tx.commit().await?; + } + } + Ok(()) +} + +// ---- Fingerprint verification ---- + +/// blake3 hash of the concatenated reconstructed plaintext (left_code ++ left_mask +/// ++ right_code ++ right_mask) for every iris ID. +pub type PlaintextFingerprints = HashMap; + +/// Compute a fingerprint for every iris in the DB by reconstructing shares +/// from all 3 parties. IDs in `skip_ids` are excluded (their shares may be +/// inconsistent across parties due to concurrent modifications). +pub async fn snapshot_all_fingerprints( + harness: &TestHarness, + skip_ids: &[i64], +) -> Result { + let ids: Vec<(i64,)> = sqlx::query_as("SELECT id FROM irises ORDER BY id") + .fetch_all(&harness.store(0).pool) + .await?; + + let mut fps = PlaintextFingerprints::new(); + for (id,) in ids { + if skip_ids.contains(&id) { + continue; + } + let mut shares = Vec::new(); + for party in 0..NUM_PARTIES { + shares.push(harness.store(party).get_iris_data_by_id(id).await?); + } + let mut hasher = blake3::Hasher::new(); + let fields: Vec<[&[u16]; 3]> = vec![ + [ + shares[0].left_code(), + shares[1].left_code(), + shares[2].left_code(), + ], + [ + shares[0].left_mask(), + shares[1].left_mask(), + shares[2].left_mask(), + ], + [ + shares[0].right_code(), + shares[1].right_code(), + shares[2].right_code(), + ], + [ + shares[0].right_mask(), + shares[1].right_mask(), + shares[2].right_mask(), + ], + ]; + for [s0, s1, s2] in &fields { + let recon = reconstruct_shares(s0, s1, s2); + hasher.update(bytemuck::cast_slice::(&recon)); + } + fps.insert(id, *hasher.finalize().as_bytes()); + } + Ok(fps) +} + +/// Verify that current shares reconstruct to the same plaintexts as the +/// snapshot. `skip_ids` are excluded (modified during test). +pub async fn verify_fingerprints( + harness: &TestHarness, + expected: &PlaintextFingerprints, + skip_ids: &[i64], +) -> Result<()> { + let current = snapshot_all_fingerprints(harness, skip_ids).await?; + let mut checked = 0; + for (id, exp) in expected { + if skip_ids.contains(id) { + continue; + } + let cur = current + .get(id) + .unwrap_or_else(|| panic!("ID {} missing from current DB", id)); + assert_eq!(exp, cur, "Plaintext fingerprint mismatch for id {}", id); + checked += 1; + } + println!( + " verified {}/{} iris fingerprints", + checked, + expected.len() + ); + Ok(()) +} + +// ---- Polling helpers ---- + +pub async fn wait_epoch_done(harness: &TestHarness, epoch: i32) -> Result<()> { + let deadline = tokio::time::Instant::now() + Duration::from_secs(120); + let start = std::time::Instant::now(); + let mut last_print = start; + let expected_chunks: i64 = ((DB_SIZE as i64) + (CHUNK_SIZE as i64) - 1) / (CHUNK_SIZE as i64); + loop { + if tokio::time::Instant::now() > deadline { + eyre::bail!("Timeout waiting for epoch {}", epoch); + } + let mut done = true; + let mut applied = [0usize; 3]; + let mut totals = [0usize; 3]; + for (i, party) in harness.parties.iter().enumerate() { + let (total, applied_count): (i64, i64) = sqlx::query_as( + "SELECT COUNT(*), COUNT(*) FILTER (WHERE live_applied = TRUE) \ + FROM rerand_progress WHERE epoch = $1", + ) + .bind(epoch) + .fetch_one(&party.store.pool) + .await?; + + totals[i] = total as usize; + applied[i] = applied_count as usize; + + if total < expected_chunks || applied_count < expected_chunks { + done = false; + } + } + if done { + println!( + " epoch {} done in {:.1}s", + epoch, + start.elapsed().as_secs_f64() + ); + return Ok(()); + } + if last_print.elapsed() > Duration::from_secs(5) { + println!( + " waiting epoch {}: applied {:?} / totals {:?} ({:.0}s)", + epoch, + applied, + totals, + start.elapsed().as_secs_f64() + ); + last_print = std::time::Instant::now(); + } + tokio::time::sleep(Duration::from_millis(500)).await; + } +} + +pub async fn wait_chunks_staged(harness: &TestHarness, epoch: i32, n: i32) -> Result<()> { + let deadline = tokio::time::Instant::now() + Duration::from_secs(60); + let start = std::time::Instant::now(); + loop { + if tokio::time::Instant::now() > deadline { + eyre::bail!("Timeout waiting for {} chunks staged in epoch {}", n, epoch); + } + let mut max_count = 0i64; + for party in &harness.parties { + let (count,): (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM rerand_progress WHERE epoch = $1 AND staging_written = TRUE", + ) + .bind(epoch) + .fetch_one(&party.store.pool) + .await?; + max_count = max_count.max(count); + } + if max_count >= n as i64 { + println!( + " {} chunks staged for epoch {} in {:.1}s", + max_count, + epoch, + start.elapsed().as_secs_f64() + ); + return Ok(()); + } + tokio::time::sleep(Duration::from_millis(200)).await; + } +} + +// ---- Server simulation ---- + +pub async fn simulate_server_startup(harness: &TestHarness, party: usize) -> Result<()> { + let pool = &harness.parties[party].store.pool; + let lock_conn = rerand_store::acquire_apply_lock(pool).await?; + let query_result: Result<(i64,), sqlx::Error> = sqlx::query_as("SELECT COUNT(*) FROM irises") + .fetch_one(pool) + .await; + rerand_store::release_apply_lock(lock_conn).await?; + let _count = query_result?; + Ok(()) +} + +pub async fn simulate_server_startup_with_rerand_validation( + harness: &TestHarness, + party: usize, +) -> Result<()> { + simulate_server_startup(harness, party).await?; + validate_rerand_startup_safety(harness).await +} + +async fn validate_rerand_startup_safety(harness: &TestHarness) -> Result<()> { + let mut epochs = Vec::with_capacity(harness.parties.len()); + let mut confirmed_chunks = Vec::with_capacity(harness.parties.len()); + + for party in &harness.parties { + let (epoch,): (Option,) = sqlx::query_as("SELECT MAX(epoch) FROM rerand_progress") + .fetch_one(&party.store.pool) + .await?; + let epoch = epoch.unwrap_or(0); + + let (max_confirmed_chunk,): (Option,) = sqlx::query_as( + "SELECT MAX(chunk_id) FROM rerand_progress WHERE epoch = $1 AND all_confirmed = TRUE", + ) + .bind(epoch) + .fetch_one(&party.store.pool) + .await?; + + epochs.push(epoch); + confirmed_chunks.push(max_confirmed_chunk.unwrap_or(-1)); + } + + let min_epoch = *epochs + .iter() + .min() + .ok_or_else(|| eyre::eyre!("No parties found for rerand startup validation"))?; + let max_epoch = *epochs + .iter() + .max() + .ok_or_else(|| eyre::eyre!("No parties found for rerand startup validation"))?; + + if max_epoch - min_epoch > 1 { + eyre::bail!( + "Startup cannot proceed: rerand epoch gap is too large (min={}, max={}).", + min_epoch, + max_epoch + ); + } + + let max_epoch_parties: Vec<_> = epochs + .iter() + .zip(confirmed_chunks.iter()) + .filter(|(e, _)| **e == max_epoch) + .map(|(_, c)| *c) + .collect(); + + if let (Some(min_chunk), Some(max_chunk)) = ( + max_epoch_parties.iter().min().cloned(), + max_epoch_parties.iter().max().cloned(), + ) { + if max_chunk - min_chunk > 1 { + eyre::bail!( + "Startup cannot proceed: rerand confirmed-chunk gap is too large at epoch {} (min={}, max={}).", + max_epoch, + min_chunk, + max_chunk + ); + } + } + + Ok(()) +} + +pub async fn assert_consistent_rerand_epoch( + harness: &TestHarness, + skip_ids: &[i64], +) -> Result { + let mut all: Vec> = Vec::new(); + for party in &harness.parties { + all.push( + sqlx::query_as("SELECT id, rerand_epoch FROM irises ORDER BY id") + .fetch_all(&party.store.pool) + .await?, + ); + } + assert_eq!(all[0].len(), all[1].len()); + assert_eq!(all[1].len(), all[2].len()); + for i in 0..all[0].len() { + if skip_ids.contains(&all[0][i].0) { + continue; + } + assert_eq!( + all[0][i].1, all[1][i].1, + "epoch mismatch id {} p0 vs p1", + all[0][i].0 + ); + assert_eq!( + all[0][i].1, all[2][i].1, + "epoch mismatch id {} p0 vs p2", + all[0][i].0 + ); + } + Ok(all[0] + .iter() + .find(|(id, _)| !skip_ids.contains(id)) + .map(|(_, e)| *e) + .unwrap_or(0)) +} + +/// Check whether shares for a specific iris ID reconstruct consistently +/// across all 3 party-pair combinations. Returns false if the shares are +/// divergent (reconstruction from different pairs disagrees). +pub async fn shares_are_consistent(harness: &TestHarness, id: i64) -> Result { + let mut shares = Vec::new(); + for party in 0..NUM_PARTIES { + shares.push(harness.store(party).get_iris_data_by_id(id).await?); + } + + let pairs: Vec<[&[u16]; 3]> = vec![ + [ + shares[0].left_code(), + shares[1].left_code(), + shares[2].left_code(), + ], + [ + shares[0].left_mask(), + shares[1].left_mask(), + shares[2].left_mask(), + ], + [ + shares[0].right_code(), + shares[1].right_code(), + shares[2].right_code(), + ], + [ + shares[0].right_mask(), + shares[1].right_mask(), + shares[2].right_mask(), + ], + ]; + + use iris_mpc_common::galois::degree4::ShamirGaloisRingShare; + use iris_mpc_common::galois::degree4::{basis::Monomial, GaloisRingElement}; + use iris_mpc_common::id::PartyID; + use itertools::Itertools; + + let lag_01 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID0, PartyID::ID1); + let lag_10 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID1, PartyID::ID0); + let lag_12 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID1, PartyID::ID2); + let lag_21 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID2, PartyID::ID1); + let lag_02 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID0, PartyID::ID2); + let lag_20 = ShamirGaloisRingShare::deg_1_lagrange_polys_at_zero(PartyID::ID2, PartyID::ID0); + + for [s0, s1, s2] in &pairs { + let recon01: Vec = s0 + .chunks_exact(4) + .zip_eq(s1.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + (a * lag_01 + b * lag_10).coefs + }) + .collect(); + let recon12: Vec = s1 + .chunks_exact(4) + .zip_eq(s2.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + (a * lag_12 + b * lag_21).coefs + }) + .collect(); + let recon02: Vec = s0 + .chunks_exact(4) + .zip_eq(s2.chunks_exact(4)) + .flat_map(|(a, b)| { + let a = GaloisRingElement::::from_coefs(a.try_into().unwrap()); + let b = GaloisRingElement::::from_coefs(b.try_into().unwrap()); + (a * lag_02 + b * lag_20).coefs + }) + .collect(); + if recon01 != recon12 || recon01 != recon02 { + return Ok(false); + } + } + Ok(true) +} + +/// Snapshot raw share bytes for a set of IDs (all parties). +/// Returns a map from id to Vec of (left_code, left_mask, right_code, right_mask) per party. +pub async fn snapshot_raw_shares( + harness: &TestHarness, + ids: &[i64], +) -> Result, Vec, Vec, Vec)>>> { + let mut result = HashMap::new(); + for &id in ids { + let mut party_shares = Vec::new(); + for party in 0..NUM_PARTIES { + let iris = harness.store(party).get_iris_data_by_id(id).await?; + party_shares.push(( + bytemuck::cast_slice::(iris.left_code()).to_vec(), + bytemuck::cast_slice::(iris.left_mask()).to_vec(), + bytemuck::cast_slice::(iris.right_code()).to_vec(), + bytemuck::cast_slice::(iris.right_mask()).to_vec(), + )); + } + result.insert(id, party_shares); + } + Ok(result) +} + +/// Get the rerand_epoch for a specific iris ID across all parties. +pub async fn get_rerand_epochs_for_id(harness: &TestHarness, id: i64) -> Result<[i32; 3]> { + let mut epochs = [0i32; 3]; + for (i, party) in harness.parties.iter().enumerate() { + let (ep,): (i32,) = sqlx::query_as("SELECT rerand_epoch FROM irises WHERE id = $1") + .bind(id) + .fetch_one(&party.store.pool) + .await?; + epochs[i] = ep; + } + Ok(epochs) +} + +/// Assert that all parties have the exact expected rerand_epoch for every id. +pub async fn assert_rerand_epoch_equals_for_ids( + harness: &TestHarness, + ids: &[i64], + expected_epoch: i32, +) -> Result<()> { + for &id in ids { + let epochs = get_rerand_epochs_for_id(harness, id).await?; + for (party, epoch) in epochs.iter().enumerate() { + eyre::ensure!( + *epoch == expected_epoch, + "Expected rerand_epoch={} for id={} on party {}, got {}", + expected_epoch, + id, + party, + epoch + ); + } + } + Ok(()) +} + +/// Assert that all parties have rerand_epoch >= min_epoch for every id. +pub async fn assert_rerand_epoch_at_least_for_ids( + harness: &TestHarness, + ids: &[i64], + min_epoch: i32, +) -> Result<()> { + for &id in ids { + let epochs = get_rerand_epochs_for_id(harness, id).await?; + for (party, epoch) in epochs.iter().enumerate() { + eyre::ensure!( + *epoch >= min_epoch, + "Expected rerand_epoch>={} for id={} on party {}, got {}", + min_epoch, + id, + party, + epoch + ); + } + } + Ok(()) +} + +async fn cleanup(harness: &TestHarness) -> Result<()> { + for party in &harness.parties { + let staging = rerand_store::staging_schema_name(&party.schema_name); + let _ = sqlx::query(&format!(r#"DROP SCHEMA IF EXISTS "{}" CASCADE"#, staging)) + .execute(&party.store.pool) + .await; + let _ = sqlx::query(&format!( + r#"DROP SCHEMA IF EXISTS "{}" CASCADE"#, + party.schema_name + )) + .execute(&party.store.pool) + .await; + } + Ok(()) +} diff --git a/iris-mpc/Cargo.toml b/iris-mpc/Cargo.toml index 1f56c59d4f..e1cd8e4524 100644 --- a/iris-mpc/Cargo.toml +++ b/iris-mpc/Cargo.toml @@ -33,6 +33,7 @@ chrono.workspace = true sqlx.workspace = true bincode.workspace = true pprof = { version = "0.15.0", features = ["flamegraph", "prost-codec"] } +axum.workspace = true ampc-anon-stats.workspace = true ampc-server-utils.workspace = true diff --git a/iris-mpc/src/server/mod.rs b/iris-mpc/src/server/mod.rs index c02901087c..9ac9fac5c6 100644 --- a/iris-mpc/src/server/mod.rs +++ b/iris-mpc/src/server/mod.rs @@ -13,8 +13,8 @@ use ampc_server_utils::batch_sync::{CURRENT_BATCH_SHA, CURRENT_BATCH_VALID_ENTRI use ampc_server_utils::shutdown_handler::ShutdownHandler; use ampc_server_utils::{ delete_messages_until_sequence_num, get_next_sns_seq_num, get_others_sync_state, - init_heartbeat_task, set_node_ready, start_coordination_server, wait_for_others_ready, - wait_for_others_unready, BatchSyncSharedState, TaskMonitor, + init_heartbeat_task, set_node_ready, start_coordination_server_with_extra_routes, + wait_for_others_ready, wait_for_others_unready, BatchSyncSharedState, TaskMonitor, }; use chrono::Utc; use eyre::{bail, eyre, Report, Result}; @@ -38,6 +38,7 @@ use iris_mpc_cpu::execution::hawk_main::{ use iris_mpc_cpu::hawkers::aby3::aby3_store::Aby3Store; use iris_mpc_cpu::hnsw::graph::graph_store::GraphPg; use iris_mpc_store::loader::load_iris_db; +use iris_mpc_store::rerand::{self as rerand_store}; use iris_mpc_store::Store; use pprof::protos::Message; use pprof::ProfilerGuardBuilder; @@ -93,13 +94,46 @@ pub async fn server_main(config: Config) -> Result<()> { server_coord_config.healthcheck_ports ); - // Start coordination server - let (is_ready_flag, verified_peers, my_uuid) = start_coordination_server( + // Build a /rerand-watermark route that queries the DB live on each request. + let rerand_watermark_route = { + let pool = iris_store.pool.clone(); + axum::Router::new().route( + "/rerand-watermark", + axum::routing::get(move || { + let pool = pool.clone(); + async move { + let wm = rerand_store::get_applied_watermark_from_pool(&pool).await; + match wm { + Ok(Some((epoch, chunk))) => ( + axum::http::StatusCode::OK, + serde_json::to_string(&serde_json::json!({ + "epoch": epoch, + "max_applied_chunk": chunk, + })) + .unwrap(), + ), + Ok(None) => (axum::http::StatusCode::OK, "null".to_string()), + Err(e) => { + tracing::warn!("rerand-watermark query failed: {:?}", e); + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("DB error: {}", e), + ) + } + } + } + }), + ) + }; + + // Start coordination server with the live watermark route injected. + let (is_ready_flag, verified_peers, my_uuid) = start_coordination_server_with_extra_routes( &server_coord_config, &mut background_tasks, &shutdown_handler, &my_state, Some(batch_sync_shared_state.clone()), + Some(rerand_watermark_route), ) .await; tracing::info!("Coordination server started"); @@ -139,33 +173,102 @@ pub async fn server_main(config: Config) -> Result<()> { sync_sqs_queues(&config, &sync_result, &aws_clients).await?; - if shutdown_handler.is_shutting_down() { - tracing::warn!("Shutting down has been triggered"); - return Ok(()); + // --- Coordinated rerand freeze with watermark convergence --- + if config.rerand_enabled { + let sc = config.server_coordination.as_ref().unwrap(); + eyre::ensure!( + sc.node_hostnames.len() == sc.healthcheck_ports.len(), + "node_hostnames ({}) and healthcheck_ports ({}) must have the same length", + sc.node_hostnames.len(), + sc.healthcheck_ports.len(), + ); + let peer_addrs: Vec<(&str, usize)> = sc + .node_hostnames + .iter() + .zip(sc.healthcheck_ports.iter()) + .enumerate() + .filter(|(i, _)| *i != config.party_id) + .map(|(_, (h, p))| -> eyre::Result<_> { Ok((h.as_str(), p.parse::()?)) }) + .collect::>>()?; + rerand_store::freeze_and_verify_watermarks(&iris_store.pool, &peer_addrs).await?; + } else if rerand_store::is_worker_alive(&iris_store.pool).await? { + // Worker heartbeat is fresh but this server is configured with rerand + // off. Starting up now would skip freeze/watermark coordination and + // risk loading a cross-party-inconsistent DB snapshot. Fail closed. + eyre::bail!( + "rerand_enabled=false in config but the rerand worker is alive \ + (heartbeat within the last {:?}). Either set SMPC__RERAND_ENABLED=true, \ + or stop the rerand worker on all parties before restarting this server.", + rerand_store::WORKER_HEARTBEAT_STALE_AFTER, + ); + } else { + tracing::info!( + "rerand_enabled=false and no fresh worker heartbeat — skipping rerand coordination" + ); } + // Worker is now frozen with verified equal watermarks. + // Everything from here until freeze release must be wrapped so that + // errors always release the freeze. + let frozen_result = async { + // Acquire the apply lock to prevent concurrent startup DB loads. + // This should theoretically not be needed since the freeze should have + // prevented concurrent startup DB loads, but it's here for extra safety. + let rerand_lock_conn = rerand_store::acquire_apply_lock(&iris_store.pool).await?; + + if shutdown_handler.is_shutting_down() { + rerand_store::release_apply_lock(rerand_lock_conn).await?; + return Ok::<_, eyre::Report>(None); + } - let mut hawk_actor = init_hawk_actor(&config, &shutdown_handler).await?; + let startup_result = async { + let mut hawk_actor = init_hawk_actor(&config, &shutdown_handler).await?; - if let Some(url) = config.get_anon_stats_db_url() { - let schema = config.get_anon_stats_db_schema(); - let anon_client = - AnonStatsPgClient::new(&url, &schema, AnonStatsAccessMode::ReadWrite).await?; - let anon_store = AnonStatsStore::new(&anon_client).await?; - hawk_actor.set_anon_stats_store(Some(anon_store)); - } else { - tracing::warn!( - "Anon stats persistence enabled but no anon stats database configured; skipping DB writes" - ); + if let Some(url) = config.get_anon_stats_db_url() { + let schema = config.get_anon_stats_db_schema(); + let anon_client = + AnonStatsPgClient::new(&url, &schema, AnonStatsAccessMode::ReadWrite).await?; + let anon_store = AnonStatsStore::new(&anon_client).await?; + hawk_actor.set_anon_stats_store(Some(anon_store)); + } else { + tracing::warn!( + "Anon stats persistence enabled but no anon stats database configured; skipping DB writes" + ); + } + + load_database( + &config, + &iris_store, + &graph_store, + &shutdown_handler, + &mut hawk_actor, + ) + .await?; + Ok::<_, eyre::Report>(hawk_actor) + } + .await; + + rerand_store::release_apply_lock(rerand_lock_conn).await?; + Ok(Some(startup_result)) } + .await; - load_database( - &config, - &iris_store, - &graph_store, - &shutdown_handler, - &mut hawk_actor, - ) - .await?; + // Always attempt freeze release, but never let its failure undo a + // successful startup. `release_rerand_freeze` already retries internally, + // and a subsequent startup will re-issue a freeze with a new generation + // that the worker re-acknowledges (see `check_and_handle_freeze` generation + // change handling). + if let Err(e) = rerand_store::release_rerand_freeze(&iris_store.pool).await { + tracing::error!( + "Failed to release rerand freeze after startup: {:?}. \ + Worker will re-acknowledge on next startup freeze.", + e + ); + } + + let hawk_actor = match frozen_result? { + None => return Ok(()), + Some(r) => r?, + }; background_tasks.check_tasks(); @@ -392,11 +495,14 @@ async fn build_sync_state( tracing::info!("Database store length is: {}", db_len); + let rerand_state = rerand_store::build_rerand_sync_state(&store.pool).await?; + Ok(SyncState { db_len, modifications, next_sns_sequence_num, common_config, + rerand_state, }) } diff --git a/iris-mpc/src/services/processors/batch.rs b/iris-mpc/src/services/processors/batch.rs index 9be6f150ec..5199f7a357 100644 --- a/iris-mpc/src/services/processors/batch.rs +++ b/iris-mpc/src/services/processors/batch.rs @@ -422,8 +422,6 @@ impl<'a> BatchProcessor<'a> { .string_value() .ok_or(ReceiveRequestError::NoMessageTypeAttribute)?; - self.delete_message(&sqs_message).await?; - let res = match request_type { IDENTITY_DELETION_MESSAGE_TYPE => { self.process_identity_deletion(&message, batch_metadata) @@ -435,6 +433,13 @@ impl<'a> BatchProcessor<'a> { } REAUTH_MESSAGE_TYPE => self.process_reauth_request(&message, batch_metadata).await, RECOVERY_CHECK_MESSAGE_TYPE => { + if !self.config.enable_recovery { + metrics::counter!("request.skipped", "type" => "recovery_check").increment(1); + tracing::warn!("Recovery checks are disabled, skipping recovery check request"); + self.delete_message(&sqs_message).await?; + self.msg_counter += 1; + return Ok(()); + } self.process_identity_match_check_request( &message, batch_metadata, @@ -444,6 +449,13 @@ impl<'a> BatchProcessor<'a> { .await } RESET_CHECK_MESSAGE_TYPE => { + if !self.config.enable_reset { + metrics::counter!("request.skipped", "type" => "reset_check").increment(1); + tracing::warn!("Resets are disabled, skipping reset request"); + self.delete_message(&sqs_message).await?; + self.msg_counter += 1; + return Ok(()); + } self.process_identity_match_check_request( &message, batch_metadata, @@ -472,12 +484,19 @@ impl<'a> BatchProcessor<'a> { } _ => { tracing::error!("Error: {}", ReceiveRequestError::InvalidMessageType); - Ok(()) + self.delete_message(&sqs_message).await?; + self.msg_counter += 1; + return Ok(()); } }; + // Only delete from SQS after the message has been successfully + // processed and the modification row is durably persisted. If we + // crash before this point, SQS will redeliver the message. + res?; + self.delete_message(&sqs_message).await?; self.msg_counter += 1; - res + Ok(()) } async fn process_identity_deletion( @@ -1108,10 +1127,13 @@ impl<'a> BatchProcessor<'a> { &self, sqs_message: &aws_sdk_sqs::types::Message, ) -> Result<(), ReceiveRequestError> { + let receipt_handle = sqs_message.receipt_handle.as_deref().ok_or_else(|| { + ReceiveRequestError::FailedToMarkRequestAsDeleted(eyre::eyre!("Missing receipt handle")) + })?; self.client .delete_message() .queue_url(&self.config.requests_queue_url) - .receipt_handle(sqs_message.receipt_handle.as_ref().unwrap()) + .receipt_handle(receipt_handle) .send() .await .map_err(ReceiveRequestError::from)?; diff --git a/iris-mpc/src/services/processors/job.rs b/iris-mpc/src/services/processors/job.rs index e0d00858e5..534ee14488 100644 --- a/iris-mpc/src/services/processors/job.rs +++ b/iris-mpc/src/services/processors/job.rs @@ -297,6 +297,10 @@ pub async fn process_job_result( let persist_total_start = Instant::now(); let mut iris_tx = store.tx().await?; + if !config.disable_persistence { + iris_mpc_store::rerand::acquire_modify_lock(&mut iris_tx).await?; + } + if !codes_and_masks.is_empty() && !config.disable_persistence { let step_start = Instant::now(); let db_serial_ids = store.insert_irises(&mut iris_tx, &codes_and_masks).await?; diff --git a/iris-mpc/src/services/processors/modifications_sync.rs b/iris-mpc/src/services/processors/modifications_sync.rs index 8bde260a35..e932981c89 100644 --- a/iris-mpc/src/services/processors/modifications_sync.rs +++ b/iris-mpc/src/services/processors/modifications_sync.rs @@ -40,20 +40,31 @@ pub async fn sync_modifications( // Sort modifications in id order to_update.sort_by_key(|m| m.id); - // Update node_id for each modification and collect &refs - let to_update_refs: Vec<&Modification> = to_update - .iter_mut() - .map(|modification| { - if let Err(e) = modification.update_result_message_node_id(config.party_id) { - tracing::error!("Failed to update modification node_id: {:?}", e); - } - &*modification - }) - .collect(); + // Update node_id for each modification (mutable pass) + for modification in &mut to_update { + if let Err(e) = modification.update_result_message_node_id(config.party_id) { + tracing::error!("Failed to update modification node_id: {:?}", e); + } + } let mut iris_tx = store.tx().await?; + // Acquire the modification lock to serialize with rerand apply. + iris_mpc_store::rerand::acquire_modify_lock(&mut iris_tx).await?; + + // Ensure recovered modification rows exist locally (completed on peers + // but missing here). Inserted with persisted=false so the loop below + // fetches shares and writes iris data before marking persisted=true. + for m in &to_update { + let mut staging = m.clone(); + staging.persisted = false; + store + .upsert_recovered_modification(&mut iris_tx, &staging) + .await?; + } + // Persist changes into modifications table + let to_update_refs: Vec<&Modification> = to_update.iter().collect(); store .update_modifications(&mut iris_tx, &to_update_refs) .await?; @@ -86,16 +97,19 @@ pub async fn sync_modifications( | RESET_UPDATE_MESSAGE_TYPE | RECOVERY_UPDATE_MESSAGE_TYPE | UNIQUENESS_MESSAGE_TYPE => { + let s3_url = modification.s3_url.clone().ok_or_else(|| { + eyre!("Persisted modification missing s3_url: {:?}", modification) + })?; + let (left_shares, right_shares) = get_iris_shares_parse_task( config.party_id, shares_encryption_key_pair.clone(), Arc::clone(&semaphore), aws_clients.s3_client.clone(), config.shares_bucket_name.clone(), - modification.clone().s3_url.unwrap(), + s3_url, )? - .await? - .unwrap(); + .await??; ( left_shares.code, left_shares.mask, @@ -104,7 +118,7 @@ pub async fn sync_modifications( ) } _ => { - panic!("Unknown modification type: {:?}", modification); + return Err(eyre!("Unknown modification type: {:?}", modification)); } }; @@ -125,7 +139,7 @@ pub async fn sync_modifications( if let Some(serialized) = &modification.graph_mutation { let single_mutation: SingleHawkMutation = bincode::deserialize::(serialized) - .expect("Failed to deserialize SingleHawkMutation"); + .map_err(|e| eyre!("Failed to deserialize SingleHawkMutation: {}", e))?; graph_mutations.push(single_mutation.clone()); } } @@ -167,7 +181,6 @@ pub async fn send_last_modifications_to_sns( let recovery_check_message_attributes = create_message_type_attribute_map(RECOVERY_CHECK_MESSAGE_TYPE); - // Fetch the last modifications from the database let last_modifications = store.last_modifications(lookback).await?; tracing::info!( "Replaying last {} modification results to SNS", @@ -179,7 +192,6 @@ pub async fn send_last_modifications_to_sns( return Ok(()); } - // Collect messages by type let mut deletion_messages = Vec::new(); let mut reauth_messages = Vec::new(); let mut reset_update_messages = Vec::new(); diff --git a/migrations/20260226000001_add_rerand_epoch.down.sql b/migrations/20260226000001_add_rerand_epoch.down.sql new file mode 100644 index 0000000000..97fde78bcd --- /dev/null +++ b/migrations/20260226000001_add_rerand_epoch.down.sql @@ -0,0 +1,14 @@ +ALTER TABLE irises DROP COLUMN IF EXISTS rerand_epoch; + +CREATE OR REPLACE FUNCTION increment_version_id() +RETURNS TRIGGER AS $$ +BEGIN + IF (OLD.left_code IS DISTINCT FROM NEW.left_code OR + OLD.left_mask IS DISTINCT FROM NEW.left_mask OR + OLD.right_code IS DISTINCT FROM NEW.right_code OR + OLD.right_mask IS DISTINCT FROM NEW.right_mask) THEN + NEW.version_id = COALESCE(OLD.version_id, 0) + 1; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; diff --git a/migrations/20260226000001_add_rerand_epoch.up.sql b/migrations/20260226000001_add_rerand_epoch.up.sql new file mode 100644 index 0000000000..ac3822e7e3 --- /dev/null +++ b/migrations/20260226000001_add_rerand_epoch.up.sql @@ -0,0 +1,15 @@ +ALTER TABLE irises ADD COLUMN IF NOT EXISTS rerand_epoch INTEGER NOT NULL DEFAULT 0; + +CREATE OR REPLACE FUNCTION increment_version_id() +RETURNS TRIGGER AS $$ +BEGIN + IF (OLD.left_code IS DISTINCT FROM NEW.left_code OR + OLD.left_mask IS DISTINCT FROM NEW.left_mask OR + OLD.right_code IS DISTINCT FROM NEW.right_code OR + OLD.right_mask IS DISTINCT FROM NEW.right_mask) + AND NEW.rerand_epoch IS NOT DISTINCT FROM OLD.rerand_epoch THEN + NEW.version_id = COALESCE(OLD.version_id, 0) + 1; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; diff --git a/migrations/20260226000002_create_rerand_progress.down.sql b/migrations/20260226000002_create_rerand_progress.down.sql new file mode 100644 index 0000000000..791f86c6c2 --- /dev/null +++ b/migrations/20260226000002_create_rerand_progress.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS rerand_progress; diff --git a/migrations/20260226000002_create_rerand_progress.up.sql b/migrations/20260226000002_create_rerand_progress.up.sql new file mode 100644 index 0000000000..214c4da9c9 --- /dev/null +++ b/migrations/20260226000002_create_rerand_progress.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS rerand_progress ( + epoch INTEGER NOT NULL, + chunk_id INTEGER NOT NULL, + staging_written BOOLEAN NOT NULL DEFAULT FALSE, + all_confirmed BOOLEAN NOT NULL DEFAULT FALSE, + live_applied BOOLEAN NOT NULL DEFAULT FALSE, + PRIMARY KEY (epoch, chunk_id) +); diff --git a/migrations/20260226000003_create_rerand_staging.down.sql b/migrations/20260226000003_create_rerand_staging.down.sql new file mode 100644 index 0000000000..f5df9a0640 --- /dev/null +++ b/migrations/20260226000003_create_rerand_staging.down.sql @@ -0,0 +1,7 @@ +DO $$ +DECLARE + staging_schema TEXT; +BEGIN + staging_schema := current_schema() || '_rerand_staging'; + EXECUTE format('DROP SCHEMA IF EXISTS %I CASCADE', staging_schema); +END $$; diff --git a/migrations/20260226000003_create_rerand_staging.up.sql b/migrations/20260226000003_create_rerand_staging.up.sql new file mode 100644 index 0000000000..9f944816a6 --- /dev/null +++ b/migrations/20260226000003_create_rerand_staging.up.sql @@ -0,0 +1,23 @@ +DO $$ +DECLARE + staging_schema TEXT; +BEGIN + staging_schema := current_schema() || '_rerand_staging'; + EXECUTE format('CREATE SCHEMA IF NOT EXISTS %I', staging_schema); + EXECUTE format('CREATE TABLE IF NOT EXISTS %I.irises ( + epoch INTEGER NOT NULL, + id BIGINT NOT NULL, + chunk_id INTEGER NOT NULL, + left_code BYTEA, + left_mask BYTEA, + right_code BYTEA, + right_mask BYTEA, + original_version_id SMALLINT, + rerand_epoch INTEGER, + PRIMARY KEY (epoch, id) + )', staging_schema); + EXECUTE format( + 'CREATE INDEX IF NOT EXISTS idx_staging_irises_epoch_chunk ON %I.irises (epoch, chunk_id)', + staging_schema + ); +END $$; diff --git a/migrations/20260226000004_create_rerand_control.down.sql b/migrations/20260226000004_create_rerand_control.down.sql new file mode 100644 index 0000000000..831afe63fe --- /dev/null +++ b/migrations/20260226000004_create_rerand_control.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS rerand_control; diff --git a/migrations/20260226000004_create_rerand_control.up.sql b/migrations/20260226000004_create_rerand_control.up.sql new file mode 100644 index 0000000000..a970f19231 --- /dev/null +++ b/migrations/20260226000004_create_rerand_control.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS rerand_control ( + id INTEGER PRIMARY KEY DEFAULT 1 CHECK (id = 1), + freeze_requested BOOLEAN NOT NULL DEFAULT FALSE, + freeze_generation TEXT, + frozen_generation TEXT +); + +INSERT INTO rerand_control (id) VALUES (1) ON CONFLICT DO NOTHING; diff --git a/migrations/20260226000005_add_worker_heartbeat.down.sql b/migrations/20260226000005_add_worker_heartbeat.down.sql new file mode 100644 index 0000000000..f8b72d943c --- /dev/null +++ b/migrations/20260226000005_add_worker_heartbeat.down.sql @@ -0,0 +1 @@ +ALTER TABLE rerand_control DROP COLUMN IF EXISTS worker_last_heartbeat; diff --git a/migrations/20260226000005_add_worker_heartbeat.up.sql b/migrations/20260226000005_add_worker_heartbeat.up.sql new file mode 100644 index 0000000000..73c4c14e1a --- /dev/null +++ b/migrations/20260226000005_add_worker_heartbeat.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE rerand_control + ADD COLUMN IF NOT EXISTS worker_last_heartbeat TIMESTAMPTZ; diff --git a/spec/iris_mpc_server.qnt b/spec/iris_mpc_server.qnt new file mode 100644 index 0000000000..efee8cee29 --- /dev/null +++ b/spec/iris_mpc_server.qnt @@ -0,0 +1,1235 @@ +// Formal specification of the iris-mpc server synchronization protocol. +// +// Models 3 MPC server nodes coordinating to process iris biometric queries. +// Each node has its own database of iris shares and a modification log. +// The spec covers: +// - Server lifecycle (startup, sync, processing, shutdown) +// - Startup sync protocol (compare_modifications algorithm) +// - Batch processing with cross-node batch sync and valid_entries filtering +// - Crash/restart recovery and graceful shutdown +// - Deletion as tombstone (overwrite, not removal) +// - Reauth success/failure (persisted only on success) +// +// MPC computations are abstracted as a nondeterministic oracle. +// +// Reference implementation: +// - sync.rs: compare_modifications (lines 218-321) +// - server/mod.rs: server_main lifecycle (lines 60-211) +// - modifications_sync.rs: sync_modifications (lines 23-149) +// - batch.rs: receive_batch, sync_batch_entries, valid_entries AND logic +// - job.rs: process_job_result, deletion tombstones, reauth success/fail + +module iris_mpc_server { + + // --------------------------------------------------------------------------- + // Constants + // --------------------------------------------------------------------------- + + const NODES: Set[int] + const MAX_DB_SIZE: int + const MAX_BATCH_SIZE: int + const LOOKBACK: int + + // --------------------------------------------------------------------------- + // Types + // --------------------------------------------------------------------------- + + type RequestType = + | Uniqueness + | Reauth + | Deletion + | Update + + type ModStatus = + | InProgress + | Completed + + type Modification = { + id: int, + serial_id: int, // 0 = unassigned (for uniqueness before MPC) + request_type: RequestType, + status: ModStatus, + persisted: bool + } + + type BatchRequest = { + seq_num: int, + request_id: int, + request_type: RequestType, + target_serial_id: int // 0 for uniqueness, >0 for others + } + + type SyncStateRec = { + db_len: int, + mods: Set[Modification], + next_seq_num: int // 0 = empty queue + } + + type NodePhase = + | Down + | WaitingPeersSync + | SyncingMods + | SyncingQueue + | LoadingDB + | WaitingPeersReady + | Ready + | Processing + | PersistingResults + | ShuttingDown // graceful shutdown: finish pending, reject new + + type RequestResult = { + is_match: bool, // for uniqueness: matched existing iris? + assigned_serial_id: int, // for non-match uniqueness: newly assigned serial_id + success: bool // for reauth: did it succeed? + } + + /// Iris entry status: live data or tombstone (deletion overwrites with dummy). + type IrisStatus = + | Live + | Tombstoned + + // --------------------------------------------------------------------------- + // State variables + // --------------------------------------------------------------------------- + + var phase: int -> NodePhase + /// Iris database per node: serial_id -> status (Live or Tombstoned). + /// Deletions overwrite with dummy data, keeping the slot occupied. + var db: int -> (int -> IrisStatus) + var db_len: int -> int + var mods: int -> Set[Modification] + var queue_cursor: int -> int + var queue: List[BatchRequest] + var next_seq_num: int + var published_sync: int -> SyncStateRec + var node_ready: int -> bool + var current_batch: int -> List[BatchRequest] + /// Per-node valid_entries: which batch entries passed decryption. + /// After AND across nodes, invalid entries are filtered out. + var valid_entries: int -> List[bool] + var batch_results: int -> List[RequestResult] + var global_mod_id: int + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + pure def max_of(s: Set[int], default: int): int = + s.fold(default, (acc, x) => if (x > acc) x else acc) + + pure def min_of(s: Set[int], default: int): int = + s.fold(default, (acc, x) => if (x < acc) x else acc) + + /// All occupied serial IDs (both Live and Tombstoned). + pure def occupied_ids(iris_db: int -> IrisStatus): Set[int] = + iris_db.keys() + + val empty_sync: SyncStateRec = { db_len: 0, mods: Set(), next_seq_num: 0 } + + // --------------------------------------------------------------------------- + // Initialization + // --------------------------------------------------------------------------- + + action init = all { + phase' = NODES.mapBy(_ => Down), + db' = NODES.mapBy(_ => Map()), + db_len' = NODES.mapBy(_ => 0), + mods' = NODES.mapBy(_ => Set()), + queue_cursor' = NODES.mapBy(_ => 0), + queue' = List(), + next_seq_num' = 1, + published_sync' = NODES.mapBy(_ => empty_sync), + node_ready' = NODES.mapBy(_ => false), + current_batch' = NODES.mapBy(_ => List()), + valid_entries' = NODES.mapBy(_ => List()), + batch_results' = NODES.mapBy(_ => List()), + global_mod_id' = 1, + } + + // --------------------------------------------------------------------------- + // Environment: external requests arrive in the queue + // --------------------------------------------------------------------------- + + action enqueue_uniqueness = all { + queue.length() < MAX_DB_SIZE, + queue' = queue.append({ + seq_num: next_seq_num, + request_id: next_seq_num, + request_type: Uniqueness, + target_serial_id: 0 + }), + next_seq_num' = next_seq_num + 1, + phase' = phase, db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + action enqueue_reauth(target: int): bool = all { + target >= 1, + queue.length() < MAX_DB_SIZE, + queue' = queue.append({ + seq_num: next_seq_num, + request_id: next_seq_num, + request_type: Reauth, + target_serial_id: target + }), + next_seq_num' = next_seq_num + 1, + phase' = phase, db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + action enqueue_deletion(target: int): bool = all { + target >= 1, + queue.length() < MAX_DB_SIZE, + queue' = queue.append({ + seq_num: next_seq_num, + request_id: next_seq_num, + request_type: Deletion, + target_serial_id: target + }), + next_seq_num' = next_seq_num + 1, + phase' = phase, db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Startup: node starts up, builds sync state, publishes it + // --------------------------------------------------------------------------- + + action node_start(node: int): bool = + val my_mods = mods.get(node) + val max_mod_id = max_of(my_mods.map(m => m.id), 0) + val lookback_mods = my_mods.filter(m => m.id > max_mod_id - LOOKBACK) + // In the real system, get_next_sns_seq_num peeks at the shared SQS queue. + val queue_next = if (queue.length() > 0) queue[0].seq_num else 0 + val sync_state: SyncStateRec = { + db_len: db_len.get(node), + mods: lookback_mods, + next_seq_num: queue_next + } + all { + phase.get(node) == Down, + phase' = phase.set(node, WaitingPeersSync), + published_sync' = published_sync.set(node, sync_state), + node_ready' = node_ready.set(node, false), + current_batch' = current_batch.set(node, List()), + valid_entries' = valid_entries.set(node, List()), + batch_results' = batch_results.set(node, List()), + db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Wait for peers, then begin sync + // --------------------------------------------------------------------------- + + action begin_sync(node: int): bool = all { + phase.get(node) == WaitingPeersSync, + NODES.filter(n => phase.get(n) != Down).forall(n => phase.get(n) != Down), + phase' = phase.set(node, SyncingMods), + db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Sync modifications: the core compare_modifications algorithm + // Mirrors sync.rs lines 218-321 and modifications_sync.rs lines 23-149 + // --------------------------------------------------------------------------- + + action sync_modifications(node: int): bool = + val running = NODES.filter(n => phase.get(n) != Down) + val all_published_mods: Set[Modification] = running.fold(Set(), (acc, n) => + acc.union(published_sync.get(n).mods) + ) + val all_mod_ids: Set[int] = all_published_mods.map(m => m.id) + // Safety check: max completed mod_id diff <= LOOKBACK (sync.rs 227-249) + val completed_max_per_node: Set[int] = running.map(n => + max_of( + published_sync.get(n).mods.filter(m => m.status == Completed).map(m => m.id), + 0 + ) + ).filter(x => x > 0) + val safe = if (completed_max_per_node.size() <= 1) true + else max_of(completed_max_per_node, 0) - min_of(completed_max_per_node, 0) <= LOOKBACK + val my_mods = mods.get(node) + val local_mod_ids = my_mods.map(m => m.id) + // to_delete: mod IDs where ALL nodes have InProgress + val to_delete_ids: Set[int] = all_mod_ids.filter(mid => + running.forall(n => + published_sync.get(n).mods.forall(m => m.id != mid or m.status == InProgress) + ) + ) + // to_update: mod IDs where ANY node has Completed + val to_update_ids: Set[int] = all_mod_ids.filter(mid => + running.exists(n => + published_sync.get(n).mods.exists(m => m.id == mid and m.status == Completed) + ) + ) + // For each to_update mod, find a completed copy + val completed_copies: Set[Modification] = to_update_ids.map(mid => + all_published_mods.filter(m => m.id == mid and m.status == Completed).fold( + { id: mid, serial_id: 0, request_type: Uniqueness, status: Completed, persisted: false }, + (_, m) => m + ) + ) + val kept_mods = my_mods.filter(m => + not(to_delete_ids.contains(m.id)) and not(to_update_ids.contains(m.id)) + ) + val updated_mods = completed_copies.filter(m => local_mod_ids.contains(m.id)) + val new_mods = kept_mods.union(updated_mods) + // Apply iris changes for newly-persisted modifications + val new_uniqueness_inserts: Set[int] = updated_mods + .filter(m => m.persisted and m.request_type == Uniqueness and m.serial_id > 0) + .map(m => m.serial_id) + // Apply deletion tombstones + val new_deletion_targets: Set[int] = updated_mods + .filter(m => m.persisted and m.request_type == Deletion and m.serial_id > 0) + .map(m => m.serial_id) + val my_db = db.get(node) + // Add new uniqueness inserts as Live + val db_with_inserts = new_uniqueness_inserts.fold(my_db, (acc, sid) => + acc.put(sid, Live) + ) + // Mark deletions as Tombstoned + val db_with_deletions = new_deletion_targets.fold(db_with_inserts, (acc, sid) => + if (acc.keys().contains(sid)) acc.set(sid, Tombstoned) else acc + ) + val new_db_len = max_of(db_with_deletions.keys(), 0) + all { + phase.get(node) == SyncingMods, + safe, + mods' = mods.set(node, new_mods), + db' = db.set(node, db_with_deletions), + db_len' = db_len.set(node, new_db_len), + phase' = phase.set(node, SyncingQueue), + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Sync SQS queue: advance cursor to max across all nodes + // --------------------------------------------------------------------------- + + action sync_queue(node: int): bool = + val running = NODES.filter(n => phase.get(n) != Down) + val seq_nums: Set[int] = running.map(n => published_sync.get(n).next_seq_num) + val any_empty = seq_nums.contains(0) + val any_nonempty = seq_nums.exists(s => s > 0) + val max_seq = max_of(seq_nums, 0) + val cleaned_queue = queue.select(r => r.seq_num >= max_seq) + all { + phase.get(node) == SyncingQueue, + not(any_empty and any_nonempty), + queue_cursor' = queue_cursor.set(node, max_seq), + queue' = cleaned_queue, + phase' = phase.set(node, LoadingDB), + db' = db, db_len' = db_len, mods' = mods, + next_seq_num' = next_seq_num, + published_sync' = published_sync, node_ready' = node_ready, + current_batch' = current_batch, valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Load database and signal ready + // --------------------------------------------------------------------------- + + action finish_loading(node: int): bool = all { + phase.get(node) == LoadingDB, + phase' = phase.set(node, WaitingPeersReady), + node_ready' = node_ready.set(node, true), + db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + current_batch' = current_batch, valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + action all_nodes_ready(node: int): bool = all { + phase.get(node) == WaitingPeersReady, + NODES.filter(n => phase.get(n) != Down).forall(n => node_ready.get(n)), + phase' = phase.set(node, Ready), + db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Batch processing + // --------------------------------------------------------------------------- + + /// All ready nodes receive the next batch. Each node nondeterministically + /// determines which entries it can decrypt (valid_entries). The valid_entries + /// are AND-ed across all nodes (batch.rs sync_batch_entries, job.rs 241). + action receive_batch = + val ready_nodes = NODES.filter(n => phase.get(n) == Ready) + val min_cursor = min_of(ready_nodes.map(n => queue_cursor.get(n)), 0) + val available = queue.select(r => r.seq_num >= min_cursor) + val batch_size = if (available.length() > MAX_BATCH_SIZE) MAX_BATCH_SIZE + else available.length() + val batch = if (batch_size > 0) available.slice(0, batch_size) else List() + val new_cursor = if (batch_size > 0) batch[batch_size - 1].seq_num + 1 else min_cursor + // Create IN_PROGRESS modifications for each batch entry + val new_modifications: Set[Modification] = range(0, batch_size).foldl(Set(), (acc, i) => + acc.union(Set({ + id: global_mod_id + i, + serial_id: batch[i].target_serial_id, + request_type: batch[i].request_type, + status: InProgress, + persisted: false + })) + ) + all { + ready_nodes == NODES, + batch_size > 0, + // Nondeterministically choose valid_entries per node, then AND them. + // In the real system, decryption failures cause entries to be invalid. + nondet invalid_set = range(0, batch_size).foldl(Set(), (acc, i) => + acc.union(Set(i)) + ).powerset().oneOf() + // AND logic: entry valid only if ALL nodes consider it valid + val and_valid: List[bool] = range(0, batch_size).foldl(List(), (acc, i) => + acc.append(not(invalid_set.contains(i))) + ) + all { + current_batch' = ready_nodes.fold(current_batch, (acc, n) => + acc.set(n, batch) + ), + valid_entries' = ready_nodes.fold(valid_entries, (acc, n) => + acc.set(n, and_valid) + ), + queue_cursor' = ready_nodes.fold(queue_cursor, (acc, n) => + acc.set(n, new_cursor) + ), + phase' = ready_nodes.fold(phase, (acc, n) => + acc.set(n, Processing) + ), + mods' = ready_nodes.fold(mods, (acc, n) => + acc.set(n, acc.get(n).union(new_modifications)) + ), + global_mod_id' = global_mod_id + batch_size, + db' = db, db_len' = db_len, + queue' = queue, next_seq_num' = next_seq_num, + published_sync' = published_sync, node_ready' = node_ready, + batch_results' = batch_results, + } + } + + /// MPC processing: nondeterministic oracle decides results. + /// For uniqueness: nondeterministically choose match/no-match. + /// For reauth: nondeterministically choose success/failure. + /// Invalid entries (valid_entries[i] == false) get match=true to skip them. + action process_batch = + val processing_nodes = NODES.filter(n => phase.get(n) == Processing) + val ref_node = min_of(processing_nodes, 0) + val batch = current_batch.get(ref_node) + val ve = valid_entries.get(ref_node) + val batch_size = batch.length() + val uniqueness_indices: Set[int] = range(0, batch_size).foldl(Set(), (acc, i) => + if (batch[i].request_type == Uniqueness and ve[i]) acc.union(Set(i)) else acc + ) + val reauth_indices: Set[int] = range(0, batch_size).foldl(Set(), (acc, i) => + if (batch[i].request_type == Reauth and ve[i]) acc.union(Set(i)) else acc + ) + all { + processing_nodes == NODES, + batch_size > 0, + nondet match_set = uniqueness_indices.powerset().oneOf() + nondet reauth_fail_set = reauth_indices.powerset().oneOf() + val base_serial = max_of(processing_nodes.map(n => db_len.get(n)), 0) + val results: List[RequestResult] = range(0, batch_size).foldl( + { res: List(), next_id: base_serial + 1 }, + (state, i) => + if (not(ve[i])) { + // Invalid entry: skip (treated as match / no-op) + { res: state.res.append({ is_match: true, assigned_serial_id: 0, success: false }), + next_id: state.next_id } + } else if (batch[i].request_type == Uniqueness) { + if (match_set.contains(i)) { + { res: state.res.append({ is_match: true, assigned_serial_id: 0, success: true }), + next_id: state.next_id } + } else { + { res: state.res.append({ is_match: false, assigned_serial_id: state.next_id, success: true }), + next_id: state.next_id + 1 } + } + } else if (batch[i].request_type == Reauth) { + // Reauth can fail (e.g., no matching iris found) + val succeeded = not(reauth_fail_set.contains(i)) + { res: state.res.append({ is_match: false, assigned_serial_id: 0, success: succeeded }), + next_id: state.next_id } + } else { + // Deletion, Update: always succeed + { res: state.res.append({ is_match: false, assigned_serial_id: 0, success: true }), + next_id: state.next_id } + } + ).res + all { + batch_results' = processing_nodes.fold(batch_results, (acc, n) => + acc.set(n, results) + ), + phase' = processing_nodes.fold(phase, (acc, n) => + acc.set(n, PersistingResults) + ), + db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + global_mod_id' = global_mod_id, + } + } + + /// Persist results to database (atomic transaction per node). + /// - Non-matching uniqueness: insert new iris as Live + /// - Successful reauth: update existing iris (persisted=true) + /// - Failed reauth: no DB write (persisted=false) + /// - Deletion: overwrite iris with dummy data (Tombstoned), persisted=true + /// - Invalid entries: persisted=false + action persist_results(node: int): bool = + val batch = current_batch.get(node) + val results = batch_results.get(node) + val ve = valid_entries.get(node) + val batch_size = batch.length() + val batch_mod_base = global_mod_id - batch_size + // New iris IDs from non-matching uniqueness (insert as Live) + val new_iris_ids: Set[int] = range(0, batch_size).foldl(Set(), (acc, i) => + if (ve[i] and batch[i].request_type == Uniqueness and not(results[i].is_match) + and results[i].assigned_serial_id > 0) { + acc.union(Set(results[i].assigned_serial_id)) + } else { + acc + } + ) + // Deletion targets: overwrite with tombstone + val deletion_targets: Set[int] = range(0, batch_size).foldl(Set(), (acc, i) => + if (ve[i] and batch[i].request_type == Deletion and batch[i].target_serial_id > 0) { + acc.union(Set(batch[i].target_serial_id)) + } else { + acc + } + ) + // Update modifications to COMPLETED + val my_mods = mods.get(node) + val updated_mods: Set[Modification] = my_mods.map(m => + if (m.id >= batch_mod_base and m.id < global_mod_id and m.status == InProgress) { + val batch_idx = m.id - batch_mod_base + val res = results[batch_idx] + val entry_valid = ve[batch_idx] + val was_persisted = if (not(entry_valid)) false + else if (batch[batch_idx].request_type == Uniqueness) not(res.is_match) + else if (batch[batch_idx].request_type == Reauth) res.success + else true // Deletion, Update: always persisted + val new_serial = if (batch[batch_idx].request_type == Uniqueness and not(res.is_match) and entry_valid) + res.assigned_serial_id + else m.serial_id + { ...m, status: Completed, persisted: was_persisted, serial_id: new_serial } + } else { + m + } + ) + // Apply DB changes + val my_db = db.get(node) + val db_with_inserts = new_iris_ids.fold(my_db, (acc, sid) => acc.put(sid, Live)) + val db_with_tombstones = deletion_targets.fold(db_with_inserts, (acc, sid) => + if (acc.keys().contains(sid)) acc.set(sid, Tombstoned) else acc + ) + val new_db_len = max_of(db_with_tombstones.keys(), 0) + all { + phase.get(node) == PersistingResults, + batch_size > 0, + db' = db.set(node, db_with_tombstones), + db_len' = db_len.set(node, new_db_len), + mods' = mods.set(node, updated_mods), + phase' = phase.set(node, Ready), + current_batch' = current_batch.set(node, List()), + valid_entries' = valid_entries.set(node, List()), + batch_results' = batch_results.set(node, List()), + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + node_ready' = node_ready, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Crash, shutdown, restart + // --------------------------------------------------------------------------- + + /// A node crashes. In-memory state is lost but DB persists. + action crash(node: int): bool = all { + phase.get(node) != Down, + phase' = phase.set(node, Down), + node_ready' = node_ready.set(node, false), + current_batch' = current_batch.set(node, List()), + valid_entries' = valid_entries.set(node, List()), + batch_results' = batch_results.set(node, List()), + db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + global_mod_id' = global_mod_id, + } + + /// Graceful shutdown: node finishes pending batch then stops. + /// If in PersistingResults, persist first then go Down. + /// Otherwise, go Down immediately (main loop won't start new batches). + action graceful_shutdown(node: int): bool = all { + phase.get(node) == Ready or phase.get(node) == WaitingPeersReady, + phase' = phase.set(node, Down), + node_ready' = node_ready.set(node, false), + current_batch' = current_batch.set(node, List()), + valid_entries' = valid_entries.set(node, List()), + batch_results' = batch_results.set(node, List()), + db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Stuttering + // --------------------------------------------------------------------------- + + action stuttering = all { + phase' = phase, db' = db, db_len' = db_len, mods' = mods, + queue_cursor' = queue_cursor, queue' = queue, + next_seq_num' = next_seq_num, published_sync' = published_sync, + node_ready' = node_ready, current_batch' = current_batch, + valid_entries' = valid_entries, + batch_results' = batch_results, global_mod_id' = global_mod_id, + } + + // --------------------------------------------------------------------------- + // Step + // --------------------------------------------------------------------------- + + action step = any { + enqueue_uniqueness, + nondet target = 1.to(MAX_DB_SIZE).oneOf() + enqueue_reauth(target), + nondet target = 1.to(MAX_DB_SIZE).oneOf() + enqueue_deletion(target), + + nondet node = NODES.oneOf() + any { + node_start(node), + begin_sync(node), + sync_modifications(node), + sync_queue(node), + finish_loading(node), + all_nodes_ready(node), + persist_results(node), + crash(node), + graceful_shutdown(node), + }, + + receive_batch, + process_batch, + stuttering, + } + + // --------------------------------------------------------------------------- + // Invariants + // --------------------------------------------------------------------------- + + /// Serial ID contiguity: occupied IDs form range 1..db_len. + /// Tombstoned entries still occupy their slot (deletion does not compact). + val serial_id_contiguity: bool = + NODES.forall(n => db.get(n).keys() == 1.to(db_len.get(n))) + + /// DB length equals count of occupied slots (including tombstones). + val db_len_consistent: bool = + NODES.forall(n => db_len.get(n) == db.get(n).keys().size()) + + /// DB size bounded. + val db_size_bounded: bool = + NODES.forall(n => db_len.get(n) <= MAX_DB_SIZE) + + /// Modification ID safety window across running nodes. + val mod_id_safety_window: bool = { + val running = NODES.filter(n => phase.get(n) != Down) + val completed_maxes: Set[int] = running.map(n => + max_of(mods.get(n).filter(m => m.status == Completed).map(m => m.id), 0) + ).filter(x => x > 0) + if (completed_maxes.size() <= 1) true + else max_of(completed_maxes, 0) - min_of(completed_maxes, 0) <= LOOKBACK + } + + /// Batch consistency: all batch-phase nodes have same batch. + val batch_consistency: bool = { + val batch_nodes = NODES.filter(n => + phase.get(n) == Processing or phase.get(n) == PersistingResults + ) + batch_nodes.forall(n1 => batch_nodes.forall(n2 => + current_batch.get(n1) == current_batch.get(n2) + )) + } + + /// When all nodes are Ready, their completed mods agree. + val ready_state_consistency: bool = { + val ready_nodes = NODES.filter(n => phase.get(n) == Ready) + if (ready_nodes.size() < 2) true + else ready_nodes.forall(n1 => ready_nodes.forall(n2 => + mods.get(n1).filter(m => m.status == Completed).map(m => m.id) + == + mods.get(n2).filter(m => m.status == Completed).map(m => m.id) + )) + } + + /// When all nodes are Ready, their databases are identical. + val ready_db_consistency: bool = { + val ready_nodes = NODES.filter(n => phase.get(n) == Ready) + if (ready_nodes.size() < 2) true + else ready_nodes.forall(n1 => ready_nodes.forall(n2 => + db.get(n1) == db.get(n2) and db_len.get(n1) == db_len.get(n2) + )) + } + + /// Completed persisted uniqueness mods must have their iris in db. + val persisted_mods_have_iris: bool = + NODES.forall(n => + mods.get(n).forall(m => + not(m.status == Completed and m.persisted and m.request_type == Uniqueness and m.serial_id > 0) + or + db.get(n).keys().contains(m.serial_id) + ) + ) + + /// If two nodes both completed the same modification (by ID), + /// they must agree on serial_id and persisted flag. + val completed_mods_agree: bool = + NODES.forall(n1 => NODES.forall(n2 => + mods.get(n1).forall(m1 => + mods.get(n2).forall(m2 => + not(m1.id == m2.id and m1.status == Completed and m2.status == Completed) + or + (m1.serial_id == m2.serial_id and m1.persisted == m2.persisted) + ) + ) + )) + + /// Same mod ID across nodes must have same request_type. + val mod_ids_consistent: bool = + NODES.forall(n1 => NODES.forall(n2 => + mods.get(n1).forall(m1 => + mods.get(n2).forall(m2 => + not(m1.id == m2.id) + or + m1.request_type == m2.request_type + ) + ) + )) + + /// No two persisted uniqueness mods on same node have same serial_id. + val no_duplicate_serial_ids: bool = + NODES.forall(n => + mods.get(n).forall(m1 => + mods.get(n).forall(m2 => + not(m1.id != m2.id + and m1.request_type == Uniqueness and m2.request_type == Uniqueness + and m1.persisted and m2.persisted + and m1.serial_id > 0 and m2.serial_id > 0) + or + m1.serial_id != m2.serial_id + ) + ) + ) + + /// Deleted iris entries must be Tombstoned, not Live. + /// If a completed persisted deletion mod references a serial_id, + /// that entry should be Tombstoned in the node's DB. + val deletions_are_tombstoned: bool = + NODES.forall(n => + mods.get(n).forall(m => + not(m.status == Completed and m.persisted and m.request_type == Deletion + and m.serial_id > 0 and db.get(n).keys().contains(m.serial_id)) + or + db.get(n).get(m.serial_id) == Tombstoned + ) + ) + + /// Failed reauths must NOT be persisted. + val failed_reauth_not_persisted: bool = + NODES.forall(n => + mods.get(n).forall(m => + // If it's a completed reauth with persisted=false, that's fine (failed). + // What we check: if persisted=true, there must be a matching iris. + not(m.status == Completed and m.request_type == Reauth + and m.persisted and m.serial_id > 0) + or + db.get(n).keys().contains(m.serial_id) + ) + ) + + /// Valid entries consistency: during batch processing, all nodes have + /// the same valid_entries vector (AND logic ensures consensus). + val valid_entries_consistency: bool = { + val batch_nodes = NODES.filter(n => + phase.get(n) == Processing or phase.get(n) == PersistingResults + ) + batch_nodes.forall(n1 => batch_nodes.forall(n2 => + valid_entries.get(n1) == valid_entries.get(n2) + )) + } + + val all_invariants: bool = and { + serial_id_contiguity, + db_len_consistent, + db_size_bounded, + mod_id_safety_window, + batch_consistency, + ready_state_consistency, + ready_db_consistency, + persisted_mods_have_iris, + completed_mods_agree, + mod_ids_consistent, + no_duplicate_serial_ids, + deletions_are_tombstoned, + failed_reauth_not_persisted, + valid_entries_consistency, + } + + // --------------------------------------------------------------------------- + // Test runs + // --------------------------------------------------------------------------- + + /// Happy path: 3 nodes start, sync (empty state), process one batch. + run happy_path_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + .then(all { + assert(all_invariants), + stuttering, + }) + } + + /// Two sequential batches, second includes a deletion. + run two_batches_with_deletion_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + // First batch: insert 2 irises + .then(enqueue_uniqueness) + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + // Second batch: delete iris 1 + .then(enqueue_deletion(1)) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + .then(all { + assert(all_invariants), + // Serial ID 1 should still exist but be Tombstoned + assert(deletions_are_tombstoned), + stuttering, + }) + } + + /// Node 2 crashes after MPC but before persist. Recovery via sync. + run crash_after_processing_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(crash(2)) + .then(persist_results(0)) + .then(persist_results(1)) + .then(crash(0)) + .then(crash(1)) + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(all { + assert(ready_db_consistency), + assert(ready_state_consistency), + assert(persisted_mods_have_iris), + stuttering, + }) + } + + /// Crash recovery: node 2 crashes after persist_results of 0,1. + run crash_recovery_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(crash(2)) + .then(crash(0)) + .then(crash(1)) + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(all { + assert(ready_db_consistency), + assert(ready_state_consistency), + stuttering, + }) + } + + /// Multiple batches with a crash between them. + run multi_batch_crash_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + // Batch 1: all succeed + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + // Batch 2: node 2 crashes after persist of 0,1 + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(crash(2)) + // Full restart and sync + .then(crash(0)) + .then(crash(1)) + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(all { + assert(all_invariants), + stuttering, + }) + } + + /// Graceful shutdown test: node shuts down cleanly. + run graceful_shutdown_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + // Graceful shutdown of node 0 + .then(graceful_shutdown(0)) + // Restart all + .then(crash(1)) + .then(crash(2)) + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(all { + assert(all_invariants), + stuttering, + }) + } + /// Mixed batch: uniqueness + reauth + deletion in sequence. + run mixed_operations_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + // Insert 3 irises + .then(enqueue_uniqueness) + .then(enqueue_uniqueness) + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + // Reauth iris 1, delete iris 2 + .then(enqueue_reauth(1)) + .then(enqueue_deletion(2)) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + .then(all { + assert(all_invariants), + stuttering, + }) + } + + /// Crash after mixed batch with deletion, verify tombstone survives sync. + run deletion_crash_recovery_test = { + init + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + // Insert iris + .then(enqueue_uniqueness) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(persist_results(2)) + // Delete iris 1, node 2 crashes before persist + .then(enqueue_deletion(1)) + .then(receive_batch) + .then(process_batch) + .then(persist_results(0)) + .then(persist_results(1)) + .then(crash(2)) + // Restart and sync + .then(crash(0)) + .then(crash(1)) + .then(node_start(0)) + .then(node_start(1)) + .then(node_start(2)) + .then(begin_sync(0)) + .then(begin_sync(1)) + .then(begin_sync(2)) + .then(sync_modifications(0)) + .then(sync_modifications(1)) + .then(sync_modifications(2)) + .then(sync_queue(0)) + .then(sync_queue(1)) + .then(sync_queue(2)) + .then(finish_loading(0)) + .then(finish_loading(1)) + .then(finish_loading(2)) + .then(all_nodes_ready(0)) + .then(all_nodes_ready(1)) + .then(all_nodes_ready(2)) + .then(all { + assert(ready_db_consistency), + assert(ready_state_consistency), + assert(deletions_are_tombstoned), + stuttering, + }) + } +} + +// --------------------------------------------------------------------------- +// Concrete instance for model checking +// --------------------------------------------------------------------------- + +module iris_mpc_server_3 { + import iris_mpc_server( + NODES = Set(0, 1, 2), + MAX_DB_SIZE = 8, + MAX_BATCH_SIZE = 3, + LOOKBACK = 6 + ).* +}