From 35fc298e54262951862d003b8a9747c1a7007da8 Mon Sep 17 00:00:00 2001 From: Copybara Bot Date: Mon, 22 Jun 2026 11:40:56 +0000 Subject: [PATCH] Sync 9538b42b574165e963e5cbd095ce5193a173a1d9 FolderOrigin-RevId: 9538b42b574165e963e5cbd095ce5193a173a1d9 --- Cargo.lock | 40 ++++ pkg/aggregator-cli/Cargo.toml | 32 +++ pkg/aggregator-cli/src/config.rs | 196 ++++++++++++++++ pkg/aggregator-cli/src/error.rs | 37 +++ pkg/aggregator-cli/src/main.rs | 204 +++++++++++++++++ pkg/aggregator-cli/src/runner.rs | 272 ++++++++++++++++++++++ pkg/aggregator-cli/src/tests.rs | 371 +++++++++++++++++++++++++++++++ 7 files changed, 1152 insertions(+) create mode 100644 pkg/aggregator-cli/Cargo.toml create mode 100644 pkg/aggregator-cli/src/config.rs create mode 100644 pkg/aggregator-cli/src/error.rs create mode 100644 pkg/aggregator-cli/src/main.rs create mode 100644 pkg/aggregator-cli/src/runner.rs create mode 100644 pkg/aggregator-cli/src/tests.rs diff --git a/Cargo.lock b/Cargo.lock index d69c599..5967688 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -360,6 +360,35 @@ dependencies = [ "zk-primitives", ] +[[package]] +name = "aggregator-cli" +version = "0.1.0" +dependencies = [ + "aggregator", + "aggregator-interface", + "async-fn-stream", + "async-trait", + "barretenberg-api-client", + "barretenberg-cli", + "barretenberg-interface", + "clap", + "contextful", + "contracts", + "element", + "futures", + "node-client-http", + "node-interface", + "primitives", + "rpc", + "thiserror 1.0.69", + "tokio", + "tracing", + "unimock", + "url", + "zk-circuits", + "zk-primitives", +] + [[package]] name = "aggregator-interface" version = "0.1.0" @@ -1478,6 +1507,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-fn-stream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4ba0c4baf81a0d8ab31618ffa3ae29ceeb970a6d0d82f76130753462e39d0ea" +dependencies = [ + "futures-util", + "pin-project-lite", + "smallvec", +] + [[package]] name = "async-io" version = "2.6.0" diff --git a/pkg/aggregator-cli/Cargo.toml b/pkg/aggregator-cli/Cargo.toml new file mode 100644 index 0000000..f55f57b --- /dev/null +++ b/pkg/aggregator-cli/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "aggregator-cli" +version = "0.1.0" +edition = "2024" +publish = false + +[dependencies] +aggregator = { workspace = true } +aggregator-interface = { workspace = true } +barretenberg-api-client = { workspace = true } +barretenberg-cli = { workspace = true } +barretenberg-interface = { workspace = true } +node-client-http = { workspace = true } +node-interface = { workspace = true } +contracts = { workspace = true } +primitives = { workspace = true } +element = { workspace = true } +zk-circuits = { workspace = true } +clap = { workspace = true } +contextful = { workspace = true } +async-fn-stream = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] } +tracing = { workspace = true } +rpc = { workspace = true } +url = { workspace = true } + +[dev-dependencies] +unimock = { workspace = true } +zk-primitives = { workspace = true } diff --git a/pkg/aggregator-cli/src/config.rs b/pkg/aggregator-cli/src/config.rs new file mode 100644 index 0000000..28e568e --- /dev/null +++ b/pkg/aggregator-cli/src/config.rs @@ -0,0 +1,196 @@ +use clap::Parser; +use element::Element; +use rpc::tracing::{LogFormat, LogLevel}; +use std::{collections::BTreeMap, str::FromStr}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct RootHeightOverride { + pub(crate) root_hash: Element, + pub(crate) height: u64, +} + +pub(crate) fn parse_root_height_override( + raw: &str, +) -> std::result::Result { + let Some((raw_root_hash, raw_height)) = raw.split_once('=') else { + return Err( + "invalid override format, expected ROOT_HASH=HEIGHT (e.g. 0xabc123=42)".to_owned(), + ); + }; + + let root_hash = Element::from_str(raw_root_hash.trim()) + .map_err(|_| format!("invalid root hash `{}` in override", raw_root_hash.trim()))?; + let height = raw_height + .trim() + .parse::() + .map_err(|_| format!("invalid height `{}` in override", raw_height.trim()))?; + + Ok(RootHeightOverride { root_hash, height }) +} + +pub(crate) fn select_rollup_tree_seed_height( + contract_height: u64, + contract_root: Element, + overrides: &[RootHeightOverride], +) -> u64 { + overrides + .iter() + .find_map(|entry| (entry.root_hash == contract_root).then_some(entry.height)) + .unwrap_or(contract_height) +} + +pub(crate) fn validate_rollup_root_height_overrides( + overrides: &[RootHeightOverride], +) -> std::result::Result<(), String> { + let mut heights_by_root = BTreeMap::::new(); + + for override_entry in overrides { + if let Some(existing_height) = heights_by_root.get(&override_entry.root_hash) + && *existing_height != override_entry.height + { + return Err(format!( + "conflicting heights for root hash 0x{}: {} and {}", + override_entry.root_hash.to_hex(), + existing_height, + override_entry.height + )); + } + + heights_by_root.insert(override_entry.root_hash, override_entry.height); + } + + Ok(()) +} + +#[derive(Debug, Parser)] +#[command(author, version, about = "Polybase Aggregator CLI", long_about = None)] +pub struct Config { + /// Node RPC URL + #[arg(long, env = "NODE_RPC_URL", default_value = "http://localhost:8080")] + pub node_rpc_url: String, + + /// Log level + #[arg(long, env = "LOG_LEVEL", default_value = "INFO")] + pub log_level: LogLevel, + + /// Log format + #[arg(long, env = "LOG_FORMAT", default_value = "PRETTY")] + pub log_format: LogFormat, + + /// Environment name + #[arg(long, env = "ENV_NAME", default_value = "dev")] + pub env_name: String, + + /// Ethereum RPC URL used to interact with the rollup contract + #[arg(long, env = "EVM_RPC_URL", default_value = "http://localhost:8545")] + pub evm_rpc_url: String, + + /// Hex-encoded rollup contract address + #[arg( + long, + env = "ROLLUP_CONTRACT_ADDRESS", + default_value = "0xdc64a140aa3e981100a9beca4e685f962f0cf6c9" + )] + pub rollup_contract_address: String, + + /// Hex-encoded (0x-prefixed) secret key for submitting transactions + #[arg( + long, + env = "EVM_SECRET_KEY", + default_value = "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80" + )] + pub evm_secret_key: String, + + /// Ethereum chain id for the rollup contract + #[arg(long, env = "CHAIN_ID", default_value_t = 1337)] + pub chain_id: u64, + + /// Optional minimum gas price (in gwei) to use when submitting rollup proofs + #[arg(long, env = "MINIMUM_GAS_PRICE_GWEI")] + pub minimum_gas_price_gwei: Option, + + /// Poll interval (milliseconds) between aggregation attempts when running continuously + #[arg(long, env = "POLL_INTERVAL_MS", default_value_t = 5_000)] + pub poll_interval_ms: u64, + + /// Number of blocks to aggregate per step (must be a power of two) + #[arg(long, env = "BLOCK_BATCH_SIZE", default_value_t = 2)] + pub block_batch_size: usize, + + /// Gas to spend per burn call when submitting rollups + #[arg(long, env = "GAS_PER_BURN_CALL", default_value_t = 1_000_000)] + pub gas_per_burn_call: u128, + + /// Rollup transaction receipt timeout in milliseconds + #[arg(long, env = "ROLLUP_RECEIPT_TIMEOUT_MS", default_value_t = 300_000)] + pub rollup_receipt_timeout_ms: u64, + + /// Poll interval in milliseconds while waiting for rollup receipt + #[arg(long, env = "ROLLUP_RECEIPT_POLL_INTERVAL_MS", default_value_t = 1_000)] + pub rollup_receipt_poll_interval_ms: u64, + + /// Number of times to retry rollup submission on failure + #[arg(long, env = "ROLLUP_SUBMIT_RETRY_ATTEMPTS", default_value_t = 3)] + pub rollup_submit_retry_attempts: usize, + + /// Delay in milliseconds between rollup submission retries + #[arg(long, env = "ROLLUP_SUBMIT_RETRY_DELAY_MS", default_value_t = 1_000)] + pub rollup_submit_retry_delay_ms: u64, + + /// When set, execute a single aggregation step and exit + #[arg(long, env = "RUN_ONCE", default_value_t = false)] + pub run_once: bool, + + /// Optional barretenberg API server base URL; when provided aggregator-cli uses it instead of bb CLI + #[arg(long, env = "BARRETENBERG_API_URL")] + pub barretenberg_api_url: Option, + + /// Maximum number of concurrent barretenberg proofs/verifications (defaults to 1 to reduce RAM usage) + #[arg(long, env = "BB_MAX_CONCURRENCY", default_value_t = 1)] + pub bb_max_concurrency: usize, + + /// Barretenberg API request timeout in milliseconds + #[arg(long, env = "BB_TIMEOUT_MS", default_value_t = 300_000)] + pub bb_timeout_ms: u64, + + /// Barretenberg API TCP connect timeout in milliseconds + #[arg(long, env = "BB_CONNECT_TIMEOUT_MS", default_value_t = 1_000)] + pub bb_connect_timeout_ms: u64, + + /// Barretenberg API permit acquisition timeout in milliseconds + #[arg(long, env = "BB_PERMIT_TIMEOUT_MS", default_value_t = 100)] + pub bb_permit_timeout_ms: u64, + + /// Buffer time in milliseconds to wait for 100-continue response from Barretenberg API + #[arg( + long, + env = "BB_EXPECT_CONTINUE_TIMEOUT_BUFFER_MS", + default_value_t = 500 + )] + pub bb_expect_continue_timeout_buffer_ms: u64, + + /// Delay in milliseconds between Barretenberg API retries + #[arg(long, env = "BB_RETRY_DELAY_MS", default_value_t = 500)] + pub bb_retry_delay_ms: u64, + + /// Maximum duration in milliseconds to retry Barretenberg API requests + #[arg(long, env = "BB_MAX_RETRY_DURATION_MS", default_value_t = 3_600_000)] + pub bb_max_retry_duration_ms: u64, + + /// Number of steps to pipeline (prepare next step while proving current step) + #[arg(long, env = "PIPELINE_DEPTH", default_value_t = 1)] + pub pipeline_depth: usize, + + /// Optional root-hash keyed overrides for rollup tree seed height. + /// + /// Format per entry: `ROOT_HASH=HEIGHT`, for example `0xabc123=42`. + /// Repeat the flag or provide a comma-separated list. + #[arg( + long, + env = "ROLLUP_ROOT_HEIGHT_OVERRIDES", + value_delimiter = ',', + value_name = "ROOT_HASH=HEIGHT", + value_parser = parse_root_height_override + )] + pub rollup_root_height_overrides: Vec, +} diff --git a/pkg/aggregator-cli/src/error.rs b/pkg/aggregator-cli/src/error.rs new file mode 100644 index 0000000..396b17f --- /dev/null +++ b/pkg/aggregator-cli/src/error.rs @@ -0,0 +1,37 @@ +use aggregator_interface::{BlockProverError, Error as AggregatorError}; +use barretenberg_interface::Error as BarretenbergError; +use contextful::Contextful; +use element::Element; +use node_client_http::Error as NodeError; +use thiserror::Error; + +pub type Result = std::result::Result; + +#[derive(Debug, Error)] +pub enum Error { + #[error("[aggregator-cli] invalid configuration: {0}")] + InvalidConfig(String), + #[error("[aggregator-cli] invalid secret key: {0}")] + InvalidSecretKey(String), + #[error( + "[aggregator-cli] rollup tree mismatch (node {node_root:?} vs contract {contract_root:?})" + )] + RootMismatch { + node_root: Element, + contract_root: Element, + }, + #[error("[aggregator-cli] block prover error")] + BlockProver(#[from] Contextful), + #[error("[aggregator-cli] aggregator error")] + Aggregator(#[from] Contextful), + #[error("[aggregator-cli] contracts error")] + Contracts(#[from] Contextful), + #[error("[aggregator-cli] node rpc error")] + Node(#[from] Contextful), + #[error("[aggregator-cli] barretenberg backend error")] + Barretenberg(#[from] Contextful), + #[error("[aggregator-cli] url parse error")] + Url(#[from] Contextful), + #[error("[aggregator-cli] tokio join error: {0}")] + TokioJoin(#[from] Contextful), +} diff --git a/pkg/aggregator-cli/src/main.rs b/pkg/aggregator-cli/src/main.rs new file mode 100644 index 0000000..48b07cc --- /dev/null +++ b/pkg/aggregator-cli/src/main.rs @@ -0,0 +1,204 @@ +// lint-long-file-override allow-max-lines=300 +mod config; +mod error; +mod runner; + +#[cfg(test)] +mod tests; + +use std::{str::FromStr, sync::Arc, time::Duration}; + +use aggregator::{ + AggAggCircuitInterface, AggFinalCircuitInterface, AggUtxoCircuitInterface, Aggregator, + BlockProver, ContractsRollupContract, LimitedBbBackend, RetryableRollupContract, +}; +use aggregator_interface::{BlockProver as BlockProverTrait, PrioritizableBbBackend}; +use barretenberg_api_client::ClientBackend; +use barretenberg_cli::CliBackend; +use barretenberg_interface::BbBackend; +use clap::Parser; +use contextful::ResultContextExt; +use contracts::{Client, SecretKey}; +use element::Element; +use node_client_http::{NodeClientHttp, Url as NodeUrl}; +use node_interface::NodeClient; +use rpc::tracing::setup_tracing; +use tokio::signal; +use url::Url as ApiUrl; +use zk_circuits::{AggAggCircuit, AggFinalCircuit, AggUtxoCircuit}; + +use crate::config::{ + Config, select_rollup_tree_seed_height, validate_rollup_root_height_overrides, +}; +use crate::error::{Error, Result}; +use crate::runner::{build_rollup_tree, run_loop, run_once}; + +#[tokio::main] +async fn main() -> Result<()> { + let config = Config::parse(); + let _guard = setup_tracing( + &[ + "aggregator", + "aggregator_cli", + "aggregator_interface", + "notes", + ], + &config.log_level, + &config.log_format, + None, + config.env_name.clone(), + ) + .map_err(|err| Error::InvalidConfig(format!("failed to setup tracing: {err}")))?; + + if config.bb_max_concurrency == 0 { + return Err(Error::InvalidConfig( + "BB_MAX_CONCURRENCY must be greater than zero".to_string(), + )); + } + if config.pipeline_depth == 0 { + return Err(Error::InvalidConfig( + "PIPELINE_DEPTH must be greater than zero".to_string(), + )); + } + validate_rollup_root_height_overrides(&config.rollup_root_height_overrides) + .map_err(Error::InvalidConfig)?; + + let secret_key = parse_secret_key(&config.evm_secret_key)?; + let client = + Client::new(&config.evm_rpc_url, config.minimum_gas_price_gwei).with_latest_nonce(); + let rollup_contract = Arc::new( + contracts::RollupContract::load( + client, + config.chain_id as u128, + &config.rollup_contract_address, + secret_key, + ) + .await + .context("load rollup contract")?, + ); + + let node_url = NodeUrl::parse(&config.node_rpc_url).context("parse node rpc url")?; + let http_client = NodeClientHttp::new(node_url); + + let contract_height = rollup_contract + .block_height() + .await + .context("fetch rolled block height")?; + let rolled_root = Element::from_be_bytes( + rollup_contract + .root_hash() + .await + .context("fetch rolled root hash")? + .to_fixed_bytes(), + ); + let rolled_height = select_rollup_tree_seed_height( + contract_height, + rolled_root, + &config.rollup_root_height_overrides, + ); + if rolled_height != contract_height { + tracing::info!( + contract_height, + rolled_height, + contract_root = ?rolled_root, + "using configured rollup tree height override for contract root hash" + ); + } + let rollup_tree = build_rollup_tree(&http_client, rolled_height, rolled_root).await?; + + let node_client: Arc = Arc::new(http_client.clone()); + let bb_backend = build_backend(&config)?; + let agg_utxo_circuit: Arc = Arc::new(AggUtxoCircuit); + let agg_agg_circuit: Arc = Arc::new(AggAggCircuit); + let agg_final_circuit: Arc = Arc::new(AggFinalCircuit); + let block_prover: Arc = Arc::new(BlockProver::new( + Arc::clone(&node_client), + Arc::clone(&agg_utxo_circuit), + Arc::clone(&agg_agg_circuit), + )); + let base_contract_adapter = ContractsRollupContract::new( + Arc::clone(&rollup_contract), + Duration::from_millis(config.rollup_receipt_timeout_ms), + Duration::from_millis(config.rollup_receipt_poll_interval_ms), + ); + let contract_adapter = Arc::new(RetryableRollupContract::new( + Box::new(base_contract_adapter), + config.rollup_submit_retry_attempts, + Duration::from_millis(config.rollup_submit_retry_delay_ms), + )); + + if config.block_batch_size < 2 { + return Err(Error::InvalidConfig( + "BLOCK_BATCH_SIZE must be at least 2".to_string(), + )); + } + if !config.block_batch_size.is_power_of_two() { + return Err(Error::InvalidConfig( + "BLOCK_BATCH_SIZE must be a power of two".to_string(), + )); + } + if config.gas_per_burn_call == 0 { + return Err(Error::InvalidConfig( + "GAS_PER_BURN_CALL must be greater than zero".to_string(), + )); + } + + let aggregator = Arc::new(Aggregator::new( + node_client, + contract_adapter, + block_prover, + Box::new(rollup_tree), + config.block_batch_size, + config.gas_per_burn_call, + Arc::clone(&agg_agg_circuit), + agg_final_circuit, + )); + + if config.run_once { + run_once(aggregator, bb_backend, config.pipeline_depth).await?; + return Ok(()); + } + + let shutdown = Box::pin(async { + let _ = signal::ctrl_c().await; + }); + + run_loop( + aggregator, + bb_backend, + Duration::from_millis(config.poll_interval_ms), + config.pipeline_depth, + shutdown, + ) + .await +} + +fn build_backend(config: &Config) -> Result> { + let base_backend: Arc = if let Some(ref raw_url) = config.barretenberg_api_url { + let api_url = ApiUrl::parse(raw_url).context("parse barretenberg api url")?; + let client = ClientBackend::with_retry_policy( + api_url, + Duration::from_millis(config.bb_timeout_ms), + Duration::from_millis(config.bb_connect_timeout_ms), + Some(Duration::from_millis(config.bb_permit_timeout_ms)), + Duration::from_millis(config.bb_retry_delay_ms), + Duration::from_millis(config.bb_max_retry_duration_ms), + Duration::from_millis(config.bb_expect_continue_timeout_buffer_ms), + ) + .context("create barretenberg api client")?; + Arc::new(client) + } else { + Arc::new(CliBackend) + }; + + Ok(Arc::new(LimitedBbBackend::new( + base_backend, + config.bb_max_concurrency, + ))) +} + +fn parse_secret_key(raw: &str) -> Result { + let trimmed = raw.trim(); + let hex = trimmed.strip_prefix("0x").unwrap_or(trimmed); + SecretKey::from_str(hex).map_err(|err| Error::InvalidSecretKey(err.to_string())) +} diff --git a/pkg/aggregator-cli/src/runner.rs b/pkg/aggregator-cli/src/runner.rs new file mode 100644 index 0000000..4364dd2 --- /dev/null +++ b/pkg/aggregator-cli/src/runner.rs @@ -0,0 +1,272 @@ +// lint-long-file-override allow-max-lines=300 +use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc, time::Duration}; + +use aggregator::SmirkRollupTree; +use aggregator_interface::{ + Aggregator as AggregatorTrait, PreparationOutcome, PreparedBatch, PrioritizableBbBackend, + ProvenBatch, RollupTree, +}; +use async_fn_stream::try_fn_stream; +use contextful::ResultContextExt; +use element::Element; +use futures::{Stream, StreamExt}; +use node_client_http::NodeClientHttp; +use primitives::block_height::BlockHeight; +use tokio::{time::Instant, time::sleep}; +use tracing::{error, info, warn}; + +use crate::error::{Error, Result}; + +fn prepare_stream( + aggregator: Arc, +) -> impl Stream> { + try_fn_stream(|emitter| async move { + loop { + let outcome = aggregator + .prepare_next_batch() + .await + .context("prepare next batch")?; + + if let PreparationOutcome::Success(batch) = &outcome { + let start_height = batch.prepared_blocks.first().unwrap().height; + let end_height = batch.prepared_blocks.last().unwrap().height; + info!( + start_height, + end_height, + count = batch.prepared_blocks.len(), + "prepared next batch" + ); + } + + emitter.emit(outcome).await; + } + }) +} + +fn proof_stream( + aggregator: Arc, + bb_backend: Arc, + pipeline_depth: usize, + prepare_stream: impl Stream)>, +) -> impl Stream> { + prepare_stream + .map(move |(seq, batch_res)| { + let aggregator = Arc::clone(&aggregator); + let bb_backend = Arc::clone(&bb_backend); + async move { + let batch = batch_res?; + let start_height = batch.prepared_blocks.first().unwrap().height; + let end_height = batch.prepared_blocks.last().unwrap().height; + + let prioritized_backend = bb_backend.with_priority(start_height); + + info!(start_height, end_height, "proving batch"); + let start = Instant::now(); + let res = aggregator + .prove_batch(batch, prioritized_backend) + .await + .context("prove batch") + .map_err(Error::from); + + if res.is_ok() { + info!( + start_height, + end_height, + duration_ms = start.elapsed().as_millis(), + "finished proof for batch" + ); + } else { + error!(start_height, end_height, "failed to prove batch"); + } + res.map(|proven| (seq, proven)) + } + }) + .buffer_unordered(pipeline_depth) +} + +fn reorder_stream( + proof_stream: impl Stream>, +) -> impl Stream> { + try_fn_stream(|emitter| async move { + let mut buffer = BTreeMap::new(); + let mut next_seq = 0; + tokio::pin!(proof_stream); + + while let Some(res) = proof_stream.next().await { + let (seq, proven) = res?; + buffer.insert(seq, proven); + + while let Some(proven) = buffer.remove(&next_seq) { + emitter.emit(proven).await; + next_seq += 1; + } + } + Ok(()) + }) +} + +fn submission_stream( + aggregator: Arc, + proof_stream: impl Stream>, +) -> impl Stream> { + proof_stream.then(move |proven_res| { + let aggregator = Arc::clone(&aggregator); + async move { + let proven = proven_res?; + let start_height = proven.blocks.first().unwrap().block.content.header.height.0; + let end_height = proven.blocks.last().unwrap().block.content.header.height.0; + + info!(start_height, end_height, "submitting batch to contract"); + aggregator + .submit_batch(proven) + .await + .context("failed to submit batch")?; + info!(start_height, end_height, "successfully submitted batch"); + Ok(()) + } + }) +} + +fn build_pipeline( + aggregator: Arc, + bb_backend: Arc, + pipeline_depth: usize, + prepare_stream: impl Stream)>, +) -> impl Stream> { + let proof = proof_stream( + Arc::clone(&aggregator), + bb_backend, + pipeline_depth, + prepare_stream, + ); + let ordered_proof = reorder_stream(proof); + submission_stream(aggregator, ordered_proof) +} + +pub async fn run_once( + aggregator: Arc, + bb_backend: Arc, + pipeline_depth: usize, +) -> Result<()> { + let prepare = prepare_stream(Arc::clone(&aggregator)) + .scan((), |_, res| async move { + match res { + Ok(PreparationOutcome::Success(batch)) => Some(Ok(batch)), + Ok(PreparationOutcome::InsufficientBlocks { + start_height, + available, + required, + }) => { + info!( + start_height, + available, required, "insufficient blocks, stopping" + ); + None + } + Err(err) => Some(Err(err)), + } + }) + .enumerate(); + let submission = build_pipeline(aggregator, bb_backend, pipeline_depth, prepare); + tokio::pin!(submission); + + while let Some(res) = submission.next().await { + res?; + } + + Ok(()) +} + +pub async fn run_loop( + aggregator: Arc, + bb_backend: Arc, + interval: Duration, + pipeline_depth: usize, + mut shutdown: Pin + Send>>, +) -> Result<()> { + info!(pipeline_depth, "starting aggregator loop"); + + let prepare = prepare_stream(Arc::clone(&aggregator)) + .filter_map(|res| async move { + match res { + Ok(PreparationOutcome::Success(batch)) => Some(Ok(batch)), + Ok(PreparationOutcome::InsufficientBlocks { + start_height, + available, + required, + }) => { + info!( + start_height, + available, required, "insufficient blocks, waiting..." + ); + sleep(interval).await; + None + } + Err(err) => { + warn!(?err, "failed to prepare next batch, retrying"); + sleep(interval).await; + None + } + } + }) + .enumerate(); + let submission = build_pipeline(aggregator, bb_backend, pipeline_depth, prepare); + tokio::pin!(submission); + + loop { + tokio::select! { + biased; + res = submission.next() => { + match res { + Some(Ok(())) => {}, + Some(Err(err)) => return Err(err), + None => break, + } + } + _ = &mut shutdown => { + info!("shutting down aggregator"); + return Ok(()); + } + } + } + + Ok(()) +} + +pub async fn build_rollup_tree( + client: &NodeClientHttp, + rolled_height: u64, + expected_root: Element, +) -> Result { + info!( + rolled_height, + ?expected_root, + "building rollup tree from snapshot" + ); + let snapshot = client + .block_tree(BlockHeight(rolled_height)) + .await + .context("fetch block tree snapshot")?; + + if snapshot.root_hash != expected_root { + return Err(Error::RootMismatch { + node_root: snapshot.root_hash, + contract_root: expected_root, + }); + } + + let mut tree = SmirkRollupTree::new(); + let inserts: Vec<_> = snapshot + .elements + .into_iter() + .map(|element| (element, rolled_height)) + .collect(); + + let count = inserts.len(); + tree.insert(&inserts) + .context("seed rollup tree insert batch")?; + tree.set_height(rolled_height); + + info!(count, "successfully seeded rollup tree"); + Ok(tree) +} diff --git a/pkg/aggregator-cli/src/tests.rs b/pkg/aggregator-cli/src/tests.rs new file mode 100644 index 0000000..1bc0ca4 --- /dev/null +++ b/pkg/aggregator-cli/src/tests.rs @@ -0,0 +1,371 @@ +// lint-long-file-override allow-max-lines=400 +use std::{sync::Arc, time::Duration}; + +use aggregator_interface::{ + AggregatorMock, PreparationOutcome, PreparedBatch, PreparedBlock, PreparedChunk, + PrioritizableBbBackend, PrioritizableBbBackendMock, ProvenBatch, +}; +use async_trait::async_trait; +use clap::Parser; +use element::Element; +use futures::FutureExt; +use futures::future::BoxFuture; +use node_interface::{Block, BlockContent, BlockHeader, BlockState, BlockWithInfo}; +use primitives::{block_height::BlockHeight, hash::CryptoHash, sig::Signature}; +use unimock::{MockFn, Unimock, matching, unimock}; +use zk_primitives::{AggFinalProof, UtxoProofBundleWithMerkleProofs}; + +use crate::config::{ + Config, select_rollup_tree_seed_height, validate_rollup_root_height_overrides, +}; +use crate::runner::{run_loop, run_once}; + +#[unimock(api = ProofMock)] +trait AsyncProveBatch { + fn prove_batch( + &self, + batch: PreparedBatch, + ) -> BoxFuture<'static, Result>; +} + +/// Helper struct to allow `unimock` to return a `BoxFuture` for `prove_batch`. +/// This is necessary because `Aggregator` is an `#[async_trait]`, and `unimock`'s default behavior +/// for `answers` expects a return value of `Result`, not allowing us to easily simulate delays +/// by returning a pending Future manually. +struct AggregatorWithAsyncProveBatch(Unimock); + +#[async_trait] +impl aggregator_interface::Aggregator for AggregatorWithAsyncProveBatch { + async fn prepare_next_batch(&self) -> Result { + aggregator_interface::Aggregator::prepare_next_batch(&self.0).await + } + + async fn prove_batch( + &self, + batch: aggregator_interface::PreparedBatch, + _bb_backend: Arc, + ) -> Result { + AsyncProveBatch::prove_batch(&self.0, batch).await + } + + async fn submit_batch(&self, batch: ProvenBatch) -> Result<(), aggregator_interface::Error> { + aggregator_interface::Aggregator::submit_batch(&self.0, batch).await + } +} + +#[tokio::test] +async fn test_out_of_order_proof_completion_reordering() { + let batch1 = create_batch(1, 10, 20); + let batch2 = create_batch(2, 20, 30); + let batch3 = create_batch(3, 30, 40); + let batch4 = create_batch(4, 40, 50); + + let proven1 = create_proven(1, [1; 32]); + let proven2 = create_proven(2, [2; 32]); + let proven3 = create_proven(3, [3; 32]); + let proven4 = create_proven(4, [4; 32]); + + let batch1_root = leak(batch1.new_root); + let batch2_root = leak(batch2.new_root); + let batch3_root = leak(batch3.new_root); + let batch4_root = leak(batch4.new_root); + let proven1_hash = leak(proven1.other_hash); + let proven2_hash = leak(proven2.other_hash); + let proven3_hash = leak(proven3.other_hash); + let proven4_hash = leak(proven4.other_hash); + + let unimock = Unimock::new(( + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::Success(batch1))), + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::Success(batch2))), + ProofMock::prove_batch + .next_call(matching!((batch) if batch.new_root == *batch1_root)) + .answers(leak(move |_, _| { + let proven1 = proven1.clone(); + async move { + tokio::time::sleep(Duration::from_millis(50)).await; + Ok(proven1) + } + .boxed() + })), + ProofMock::prove_batch + .next_call(matching!((batch) if batch.new_root == *batch2_root)) + .answers(leak(move |_, _| { + futures::future::ready(Ok(proven2.clone())).boxed() + })), + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::Success(batch3))), + ProofMock::prove_batch + .next_call(matching!((batch) if batch.new_root == *batch3_root)) + .answers(leak(move |_, _| { + futures::future::ready(Ok(proven3.clone())).boxed() + })), + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::Success(batch4))), + ProofMock::prove_batch + .next_call(matching!((batch) if batch.new_root == *batch4_root)) + .answers(leak(move |_, _| { + futures::future::ready(Ok(proven4.clone())).boxed() + })), + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::InsufficientBlocks { + start_height: 5, + available: 0, + required: 2, + })), + AggregatorMock::submit_batch + .next_call(matching!((proven) if proven.other_hash == *proven1_hash)) + .returns(Ok(())), + AggregatorMock::submit_batch + .next_call(matching!((proven) if proven.other_hash == *proven2_hash)) + .returns(Ok(())), + AggregatorMock::submit_batch + .next_call(matching!((proven) if proven.other_hash == *proven3_hash)) + .returns(Ok(())), + AggregatorMock::submit_batch + .next_call(matching!((proven) if proven.other_hash == *proven4_hash)) + .returns(Ok(())), + )); + + let aggregator = Arc::new(AggregatorWithAsyncProveBatch(unimock)); + let bb_backend = mock_prioritizable_backend(); + + run_once(aggregator, bb_backend, 2).await.unwrap(); +} + +#[tokio::test] +async fn test_run_loop_sequential() { + let batch1 = create_batch(1, 10, 20); + let batch2 = create_batch(2, 20, 30); + + let proven1 = create_proven(1, [1; 32]); + let proven2 = create_proven(2, [2; 32]); + + let batch1_root = leak(batch1.new_root); + let batch2_root = leak(batch2.new_root); + let proven1_hash = leak(proven1.other_hash); + let proven2_hash = leak(proven2.other_hash); + + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let shutdown_tx = Arc::new(std::sync::Mutex::new(Some(shutdown_tx))); + + let aggregator = Arc::new(Unimock::new(( + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::Success(batch1.clone()))), + AggregatorMock::prove_batch + .next_call(matching!((batch, _) if batch.new_root == *batch1_root)) + .returns(Ok(proven1.clone())), + AggregatorMock::submit_batch + .next_call(matching!((proven) if proven.other_hash == *proven1_hash)) + .returns(Ok(())), + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::Success(batch2.clone()))), + AggregatorMock::prove_batch + .next_call(matching!((batch, _) if batch.new_root == *batch2_root)) + .returns(Ok(proven2.clone())), + AggregatorMock::submit_batch + .next_call(matching!((proven) if proven.other_hash == *proven2_hash)) + .answers(leak(move |_, _| { + if let Some(tx) = shutdown_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + Ok(()) + })), + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::InsufficientBlocks { + start_height: 1, + available: 0, + required: 2, + })), + ))); + + let bb_backend = mock_prioritizable_backend(); + + let shutdown_future = Box::pin(async move { + let _ = shutdown_rx.await; + }); + + run_loop( + aggregator, + bb_backend, + Duration::from_millis(1), + 1, + shutdown_future, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn test_run_loop_submission_failure() { + let batch1 = create_batch(1, 10, 20); + let proven1 = create_proven(1, [1; 32]); + let batch1_root = leak(batch1.new_root); + + let aggregator = Arc::new(Unimock::new(( + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::Success(batch1.clone()))), + AggregatorMock::prove_batch + .next_call(matching!((batch, _) if batch.new_root == *batch1_root)) + .returns(Ok(proven1.clone())), + AggregatorMock::submit_batch + .next_call(matching!(_)) + .returns(Err(aggregator_interface::Error::MissingApprovalBlock)), + ))); + + let bb_backend = mock_prioritizable_backend(); + + let shutdown_future = Box::pin(async move { + tokio::time::sleep(Duration::from_millis(100)).await; + }); + + let res = run_loop( + aggregator, + bb_backend, + Duration::from_millis(1), + 1, + shutdown_future, + ) + .await; + + assert!(res.is_err()); + assert!(format!("{:?}", res.unwrap_err()).contains("failed to submit batch")); +} + +#[tokio::test] +async fn test_run_once_insufficient_blocks() { + let aggregator = Arc::new(Unimock::new( + AggregatorMock::prepare_next_batch + .next_call(matching!(())) + .returns(Ok(PreparationOutcome::InsufficientBlocks { + start_height: 1, + available: 1, + required: 2, + })), + )); + + let bb_backend = Arc::new(Unimock::new(())); + + run_once(aggregator, bb_backend, 2).await.unwrap(); +} + +#[test] +fn test_rollup_root_height_overrides_parse_and_select_height() { + let config = Config::try_parse_from([ + "aggregator-cli", + "--rollup-root-height-overrides", + "0x2a=123,0x2b=456", + ]) + .unwrap(); + + assert_eq!(config.rollup_root_height_overrides.len(), 2); + assert_eq!( + select_rollup_tree_seed_height(0, Element::new(42), &config.rollup_root_height_overrides,), + 123 + ); + assert_eq!( + select_rollup_tree_seed_height( + 99, + Element::new(1000), + &config.rollup_root_height_overrides, + ), + 99 + ); +} + +#[test] +fn test_rollup_root_height_overrides_invalid_value() { + let err = Config::try_parse_from([ + "aggregator-cli", + "--rollup-root-height-overrides", + "not-a-valid-override", + ]) + .unwrap_err() + .to_string(); + + assert!(err.contains("ROOT_HASH=HEIGHT")); +} + +#[test] +fn test_rollup_root_height_overrides_conflicting_heights_rejected() { + let config = Config::try_parse_from([ + "aggregator-cli", + "--rollup-root-height-overrides", + "0x1=7,0x1=9", + ]) + .unwrap(); + + let err = + validate_rollup_root_height_overrides(&config.rollup_root_height_overrides).unwrap_err(); + assert!(err.contains("conflicting heights for root hash")); +} + +fn create_batch(height: u64, old_root: u64, new_root: u64) -> PreparedBatch { + PreparedBatch { + blocks: vec![], + prepared_blocks: vec![PreparedBlock { + height, + chunks: [dummy_chunk(), dummy_chunk()], + }], + old_root: Element::new(old_root), + new_root: Element::new(new_root), + } +} + +fn create_proven(height: u64, hash: [u8; 32]) -> ProvenBatch { + let block = BlockWithInfo { + block: Block { + content: BlockContent { + header: BlockHeader { + height: BlockHeight(height), + last_block_hash: CryptoHash::default(), + epoch_id: 0, + last_final_block_hash: CryptoHash::default(), + approvals: vec![], + }, + state: BlockState { + root_hash: Element::ZERO, + txns: vec![], + }, + }, + signature: Signature::default(), + }, + hash: CryptoHash::default(), + time: 0, + }; + ProvenBatch { + final_proof: AggFinalProof::default(), + blocks: vec![block], + other_hash: hash, + } +} + +fn dummy_chunk() -> PreparedChunk { + PreparedChunk { + old_root: Element::ZERO, + new_root: Element::ZERO, + bundles: std::array::from_fn(|_| UtxoProofBundleWithMerkleProofs::default()), + } +} + +fn mock_prioritizable_backend() -> Arc { + Arc::new(Unimock::new( + PrioritizableBbBackendMock::with_priority + .some_call(matching!(_)) + .answers(leak(|unimock: &Unimock, _| Arc::new(unimock.clone()))), + )) +} + +fn leak(value: T) -> &'static T { + Box::leak(Box::new(value)) +}