Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 1 addition & 256 deletions crypto/math-cuda/src/lde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg};

use crate::Result;
use crate::device::{Backend, backend};
use crate::merkle::{
keccak_launch_cfg, launch_keccak_base, launch_keccak_base_row_pair, launch_keccak_ext3,
launch_keccak_ext3_row_pair,
};
use crate::merkle::{keccak_launch_cfg, launch_keccak_base, launch_keccak_base_row_pair};
use crate::ntt::run_ntt_body;

/// Goldilocks `TWO_ADICITY = 32` puts the theoretical domain ceiling at
Expand Down Expand Up @@ -969,35 +966,6 @@ pub fn coset_lde_batch_base_into_with_leaf_hash(
.map(|_| ())
}

/// Like `coset_lde_batch_base_into_with_leaf_hash`, but also builds the full
/// row-pair Merkle tree on device and returns the `2*(lde_size/2) - 1` node
/// buffer back to the caller in `merkle_nodes_out` (byte length
/// `(2*(lde_size/2) - 1) * 32`).
///
/// The leaf hashes are never exposed to the caller — they stay on device and
/// feed straight into the pair-hash tree kernel, avoiding the
/// pinned→pageable→pinned round-trip that the separate-step GPU tree build
/// would pay.
pub fn coset_lde_batch_base_into_with_merkle_tree(
columns: &[&[u64]],
blowup_factor: usize,
weights: &[u64],
outputs: &mut [&mut [u64]],
merkle_nodes_out: &mut [u8],
) -> Result<()> {
coset_lde_batch_base_into_with_merkle_tree_inner(
columns,
blowup_factor,
weights,
outputs,
merkle_nodes_out,
KeccakCommit::FullTree,
false,
2,
)
.map(|_| ())
}

#[allow(clippy::too_many_arguments)]
fn coset_lde_batch_base_into_with_merkle_tree_inner(
columns: &[&[u64]],
Expand Down Expand Up @@ -1180,229 +1148,6 @@ fn coset_lde_batch_base_into_with_merkle_tree_inner(
}
}

/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: fused LDE +
/// row-pair Keccak-256 leaf hashing over ext3 columns. Thin wrapper over
/// `coset_lde_batch_ext3_into_with_merkle_tree_inner` with `LeavesOnly`.
pub fn coset_lde_batch_ext3_into_with_leaf_hash(
columns: &[&[u64]],
n: usize,
blowup_factor: usize,
weights: &[u64],
outputs: &mut [&mut [u64]],
hashed_leaves_out: &mut [u8],
) -> Result<()> {
coset_lde_batch_ext3_into_with_merkle_tree_inner(
columns,
n,
blowup_factor,
weights,
outputs,
hashed_leaves_out,
KeccakCommit::LeavesOnly,
false,
2,
)
.map(|_| ())
}

/// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`.
/// LDE + leaf hashing + inner-tree build, all on device; D2Hs only the LDE
/// evaluations and the full `2*(lde_size/2) - 1` row-pair node buffer.
pub fn coset_lde_batch_ext3_into_with_merkle_tree(
columns: &[&[u64]],
n: usize,
blowup_factor: usize,
weights: &[u64],
outputs: &mut [&mut [u64]],
merkle_nodes_out: &mut [u8],
) -> Result<()> {
coset_lde_batch_ext3_into_with_merkle_tree_inner(
columns,
n,
blowup_factor,
weights,
outputs,
merkle_nodes_out,
KeccakCommit::FullTree,
false,
2,
)
.map(|_| ())
}

#[allow(clippy::too_many_arguments)]
fn coset_lde_batch_ext3_into_with_merkle_tree_inner(
columns: &[&[u64]],
n: usize,
blowup_factor: usize,
weights: &[u64],
outputs: &mut [&mut [u64]],
nodes_out: &mut [u8],
commit: KeccakCommit,
keep_device_buf: bool,
// 1 = one leaf per bit-reversed row; 2 = one leaf per row pair (2i, 2i+1),
// matching the CPU `commit_bit_reversed(.., 2)` used for the trace commit.
rows_per_leaf: usize,
) -> Result<Option<GpuLdeExt3>> {
if columns.is_empty() {
assert_eq!(outputs.len(), 0);
return Ok(None);
}
// (is_power_of_two returns false for 0).
if n == 0 {
return Ok(None);
}
let m = columns.len();
assert_eq!(outputs.len(), m);
assert!(n.is_power_of_two());
assert_eq!(weights.len(), n);
assert!(blowup_factor.is_power_of_two());
for c in columns.iter() {
assert_eq!(c.len(), 3 * n);
}
let lde_size = n * blowup_factor;
assert_u32_domain(
lde_size,
"coset_lde_batch_ext3_into_with_merkle_tree lde_size",
);
for o in outputs.iter() {
assert_eq!(o.len(), 3 * lde_size);
}
assert!(
rows_per_leaf == 1 || rows_per_leaf == 2,
"rows_per_leaf must be 1 or 2"
);
assert_eq!(lde_size % rows_per_leaf, 0);
let num_leaves = lde_size / rows_per_leaf;
let nodes_dev_bytes = commit.total_nodes_bytes(num_leaves);
assert_eq!(nodes_out.len(), nodes_dev_bytes);
let log_n = n.trailing_zeros() as u64;
let log_lde = lde_size.trailing_zeros() as u64;

let mb = 3 * m;
let be = backend()?;
let stream = be.next_stream();
let staging_slot = be.pinned_staging();

let mut staging = staging_slot.lock().unwrap();
staging.ensure_capacity(mb * lde_size, &be.ctx)?;
let pinned = unsafe { staging.as_mut_slice(mb * lde_size) };

pack_ext3_to_pinned_slabs(columns, pinned, n);

let mut buf = stream.alloc_zeros::<u64>(mb * lde_size)?;
for s in 0..mb {
let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n);
stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?;
}

let inv_tw = be.inv_twiddles_for(log_n)?;
let fwd_tw = be.fwd_twiddles_for(log_lde)?;
let weights_dev = stream.clone_htod(weights)?;

let n_u64 = n as u64;
let lde_u64 = lde_size as u64;
let col_stride_u64 = lde_size as u64;
let mb_u32 = mb as u32;

launch_bit_reverse_batched(
stream.as_ref(),
be,
&mut buf,
n_u64,
log_n,
col_stride_u64,
mb_u32,
)?;
run_batched_ntt_body(
stream.as_ref(),
&mut buf,
inv_tw.as_ref(),
n_u64,
log_n,
col_stride_u64,
mb_u32,
)?;
launch_pointwise_mul_batched(
stream.as_ref(),
be,
&mut buf,
&weights_dev,
n_u64,
col_stride_u64,
mb_u32,
)?;
launch_bit_reverse_batched(
stream.as_ref(),
be,
&mut buf,
lde_u64,
log_lde,
col_stride_u64,
mb_u32,
)?;
run_batched_ntt_body(
stream.as_ref(),
&mut buf,
fwd_tw.as_ref(),
lde_u64,
log_lde,
col_stride_u64,
mb_u32,
)?;

// Allocate device output buffer (LeavesOnly -> num_leaves*32; FullTree ->
// (2*num_leaves - 1)*32). Leaf kernel writes to the leaves slab; the
// inner-tree pass (when present) fills the head.
let mut nodes_dev = unsafe { stream.alloc::<u8>(nodes_dev_bytes) }?;
let leaves_offset_bytes = commit.leaves_offset_bytes(num_leaves);
{
let mut leaves_view =
nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32);
if rows_per_leaf == 2 {
launch_keccak_ext3_row_pair(
stream.as_ref(),
&buf,
col_stride_u64,
m as u64,
lde_u64,
&mut leaves_view,
)?;
} else {
launch_keccak_ext3(
stream.as_ref(),
&buf,
col_stride_u64,
m as u64,
lde_u64,
&mut leaves_view,
)?;
}
}

if commit == KeccakCommit::FullTree {
crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, num_leaves)?;
}

// D2H LDE (mb * lde_size u64) and tree/leaves nodes.
stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?;
d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, nodes_out)?;

unpack_pinned_slabs_to_ext3(pinned, outputs, lde_size);
drop(staging);

if keep_device_buf {
Ok(Some(GpuLdeExt3 {
buf: Arc::new(buf),
m,
lde_size,
}))
} else {
drop(buf);
Ok(None)
}
}

/// Batched ext3 polynomial → coset evaluation.
///
/// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64).
Expand Down
Loading