Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions crypto/math-cuda/kernels/keccak.cu
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,37 @@ extern "C" __global__ void keccak_fri_leaves_ext3(
// children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs]
// parents: nodes[parent_begin .. parent_begin + n_pairs]
//
// Each thread hashes one child pair → one parent. Keccak-256 of the
// concatenation of two 32-byte siblings, identical to
// `FieldElementVectorBackend::hash_new_parent` on host.
// ---------------------------------------------------------------------------
// Builds one Merkle level in place: thread `t` hashes the sibling pair stored
// at node indices `parent_begin + n_pairs + 2*t` and `+ 2*t + 1`, and writes
// the 32-byte digest to node index `parent_begin + t`.
//
//   nodes        - flat node buffer, addressed in 32-byte nodes
//   parent_begin - node index of the first parent of this level
//   n_pairs      - number of child pairs (= number of parents) in this level
extern "C" __global__ void keccak_merkle_level(
    uint8_t *nodes,
    uint64_t parent_begin, // node index (counted in 32-byte nodes)
    uint64_t n_pairs) {
    const uint64_t pair_idx =
        static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    // Surplus threads in the last block have no pair to hash.
    if (pair_idx >= n_pairs) return;

    // Fresh all-zero sponge state.
    uint64_t state[25] = {};
    uint32_t rate_pos = 0;

    // The left and right siblings are adjacent nodes, so together they form one
    // contiguous 64-byte span: lanes 0..3 are the left child, lanes 4..7 the
    // right child. `nodes` comes from cuMemAlloc (256-byte aligned) and every
    // node sits at a 32-byte-aligned offset, so reading it as u64 lanes is safe.
    const uint64_t first_child = parent_begin + n_pairs + 2 * pair_idx;
    const uint64_t *child_lanes =
        reinterpret_cast<const uint64_t *>(nodes + first_child * 32);
    #pragma unroll
    for (int lane = 0; lane < 8; ++lane) {
        absorb_lane(state, rate_pos, child_lanes[lane]);
    }

    // Write the digest into the parent slot.
    finalize_keccak256(state, rate_pos, nodes + (parent_begin + pair_idx) * 32);
}

// ---------------------------------------------------------------------------
// Row-major base leaf hashing.
//
Expand Down Expand Up @@ -350,34 +381,3 @@ extern "C" __global__ void keccak256_leaves_base_row_major(
}
finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32);
}

// Each thread hashes one child pair → one parent. Keccak-256 of the
// concatenation of two 32-byte siblings, identical to
// `FieldElementVectorBackend::hash_new_parent` on host.
// ---------------------------------------------------------------------------
// Builds one Merkle level in place: thread `t` hashes the sibling pair stored
// at node indices `parent_begin + n_pairs + 2*t` and `+ 2*t + 1`, and writes
// the 32-byte digest to node index `parent_begin + t`.
//
//   nodes        - flat node buffer, addressed in 32-byte nodes
//   parent_begin - node index of the first parent of this level
//   n_pairs      - number of child pairs (= number of parents) in this level
extern "C" __global__ void keccak_merkle_level(
    uint8_t *nodes,
    uint64_t parent_begin, // node index (counted in 32-byte nodes)
    uint64_t n_pairs) {
    const uint64_t pair_idx =
        static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    // Surplus threads in the last block have no pair to hash.
    if (pair_idx >= n_pairs) return;

    // Fresh all-zero sponge state.
    uint64_t state[25] = {};
    uint32_t rate_pos = 0;

    // The left and right siblings are adjacent nodes, so together they form one
    // contiguous 64-byte span: lanes 0..3 are the left child, lanes 4..7 the
    // right child. `nodes` comes from cuMemAlloc (256-byte aligned) and every
    // node sits at a 32-byte-aligned offset, so reading it as u64 lanes is safe.
    const uint64_t first_child = parent_begin + n_pairs + 2 * pair_idx;
    const uint64_t *child_lanes =
        reinterpret_cast<const uint64_t *>(nodes + first_child * 32);
    #pragma unroll
    for (int lane = 0; lane < 8; ++lane) {
        absorb_lane(state, rate_pos, child_lanes[lane]);
    }

    // Write the digest into the parent slot.
    finalize_keccak256(state, rate_pos, nodes + (parent_begin + pair_idx) * 32);
}
181 changes: 69 additions & 112 deletions crypto/math-cuda/src/lde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ fn launch_keccak_base_row_major(
// threads/block, which exceeds the per-block register budget and fails the
// launch with CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES — silently dropping the whole
// R1 GPU path to the CPU fallback (no device handle for rounds 2-4).
//
// The kernel derives the bit-reversed row as `__brevll(tid) >> (64 - log_num_rows)`;
// shifting a 64-bit value by 64 is UB, so `num_rows < 2` (`log_num_rows == 0`) is
// rejected in debug builds, matching the `debug_assert!` guard in `launch_keccak_base`.
debug_assert!(num_rows >= 2, "row-major keccak requires num_rows >= 2");
let cfg = keccak_launch_cfg(num_rows);
unsafe {
stream
Expand Down Expand Up @@ -368,74 +373,79 @@ fn launch_row_to_col_major(
Ok(dst)
}

/// Row-major LDE + Keccak + Merkle, all on-device.
/// Shared row-major LDE + Keccak + Merkle pipeline for the base and ext3 paths.
///
/// `total_cols` is the number of base-field columns in the row-major layout:
/// `m` for base, `m * 3` for ext3. Because `Fp3 = [u64; 3]`, the three ext3
/// components are just three adjacent base-field columns, so the same row-major
/// NTT and Keccak kernels process all of them simultaneously — no de-interleave.
///
/// Input: `row_major` is a flat `n * m` slice in row-major order.
/// Returns (merkle_nodes, GpuLdeBase handle, row-major LDE Vec).
/// Single H2D, row-major NTT, single D2H — no CPU-side extract or transpose.
/// The returned handle is column-major (as required by downstream GPU kernels):
/// after D2H, `buf` is transposed on-device to column-major for the handle.
pub fn coset_lde_row_major_with_merkle_tree_keep(
/// Returns (merkle_nodes, column-major device buffer, row-major LDE Vec). The
/// buffer is transposed to column-major (as required by the downstream DEEP and
/// barycentric GPU kernels); callers wrap it in the appropriate LDE handle.
fn coset_lde_row_major_inner(
row_major: &[u64],
n: usize,
m: usize,
total_cols: usize,
blowup_factor: usize,
weights: &[u64],
) -> Result<(Vec<u8>, GpuLdeBase, Vec<u64>)> {
assert_eq!(row_major.len(), n * m);
what: &str,
) -> Result<(Vec<u8>, CudaSlice<u64>, Vec<u64>)> {
assert_eq!(row_major.len(), n * total_cols);
assert!(n.is_power_of_two());
assert_eq!(weights.len(), n);
assert!(blowup_factor.is_power_of_two());
let lde_size = n * blowup_factor;
assert_u32_domain(lde_size, "coset_lde_row_major lde_size");
assert_u32_domain(lde_size, what);

let nodes_bytes = KeccakCommit::FullTree.total_nodes_bytes(lde_size);
let log_n = n.trailing_zeros() as u64;
let log_lde = lde_size.trailing_zeros() as u64;
let n_u64 = n as u64;
let lde_u64 = lde_size as u64;
let m_u64 = m as u64;
let cols_u64 = total_cols as u64;

let be = backend()?;
let stream = be.next_stream();

// H2D into a zeroed lde_size*m buffer; only the first n*m rows carry data,
// the remainder are already zero (zero-padding for LDE expansion).
let mut buf = stream.alloc_zeros::<u64>(lde_size * m)?;
stream.memcpy_htod(row_major, &mut buf.slice_mut(0..n * m))?;
// H2D into a zeroed lde_size*total_cols buffer; only the first n*total_cols
// elements (the first n rows) carry data, the rest stay zero (LDE zero-padding).
let mut buf = stream.alloc_zeros::<u64>(lde_size * total_cols)?;
stream.memcpy_htod(row_major, &mut buf.slice_mut(0..n * total_cols))?;

let inv_tw = be.inv_twiddles_for(log_n)?;
let fwd_tw = be.fwd_twiddles_for(log_lde)?;
let weights_dev = stream.clone_htod(weights)?;

// iNTT: bit-reverse rows → per-level DIT.
launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, n_u64, log_n, m_u64)?;
launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, n_u64, log_n, cols_u64)?;
run_row_major_ntt_body(
stream.as_ref(),
be,
&mut buf,
inv_tw.as_ref(),
n_u64,
log_n,
m_u64,
cols_u64,
)?;

// Coset weights: one weight per row, broadcast across all m columns.
launch_pointwise_mul_row_major(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, m_u64)?;
// Coset weights: one weight per row, broadcast across all columns.
launch_pointwise_mul_row_major(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, cols_u64)?;

// Forward NTT at lde_size.
launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, lde_u64, log_lde, m_u64)?;
launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, lde_u64, log_lde, cols_u64)?;
run_row_major_ntt_body(
stream.as_ref(),
be,
&mut buf,
fwd_tw.as_ref(),
lde_u64,
log_lde,
m_u64,
cols_u64,
)?;

// Keccak + Merkle on-device.
// Keccak + Merkle on-device. Each leaf reads `total_cols` consecutive u64s.
let mut nodes_dev = unsafe { stream.alloc::<u8>(nodes_bytes) }?;
let leaves_offset = KeccakCommit::FullTree.leaves_offset_bytes(lde_size);
{
Expand All @@ -444,7 +454,7 @@ pub fn coset_lde_row_major_with_merkle_tree_keep(
stream.as_ref(),
be,
&buf,
m_u64,
cols_u64,
lde_u64,
log_lde,
&mut leaves_view,
Expand All @@ -457,11 +467,11 @@ pub fn coset_lde_row_major_with_merkle_tree_keep(
let lde_out = {
let staging_slot = be.pinned_staging();
let mut staging = staging_slot.lock().unwrap();
staging.ensure_capacity(lde_size * m, &be.ctx)?;
let pinned = unsafe { staging.as_mut_slice(lde_size * m) };
staging.ensure_capacity(lde_size * total_cols, &be.ctx)?;
let pinned = unsafe { staging.as_mut_slice(lde_size * total_cols) };
stream.memcpy_dtoh(&buf, pinned)?;
stream.synchronize()?;
let out = pinned[..lde_size * m].to_vec();
let out = pinned[..lde_size * total_cols].to_vec();
drop(staging);
out
};
Expand All @@ -471,16 +481,39 @@ pub fn coset_lde_row_major_with_merkle_tree_keep(

// Transpose row-major buf → column-major for the handle. Downstream kernels
// (DEEP, barycentric) expect buf[c * lde_size + r] (column-major).
let col_major_dev = launch_row_to_col_major(&stream, be, &buf, lde_size, m, lde_u64)?;
let col_major_dev = launch_row_to_col_major(&stream, be, &buf, lde_size, total_cols, lde_u64)?;
// Synchronize before returning: the handle crosses stream boundaries — downstream
// consumers call be.next_stream() and read handle.buf on a different stream.
// Without this, a barycentric or DEEP kernel can start before the transpose finishes.
stream.synchronize()?;

Ok((nodes_out, col_major_dev, lde_out))
}

/// Row-major LDE + Keccak + Merkle, all on-device.
///
/// Input: `row_major` is a flat `n * m` slice in row-major order.
/// Returns (merkle_nodes, GpuLdeBase handle, row-major LDE Vec).
/// The returned handle is column-major (as required by downstream GPU kernels).
pub fn coset_lde_row_major_with_merkle_tree_keep(
row_major: &[u64],
n: usize,
m: usize,
blowup_factor: usize,
weights: &[u64],
) -> Result<(Vec<u8>, GpuLdeBase, Vec<u64>)> {
let (nodes_out, col_major_dev, lde_out) = coset_lde_row_major_inner(
row_major,
n,
m,
blowup_factor,
weights,
"coset_lde_row_major lde_size",
)?;
let handle = GpuLdeBase {
buf: Arc::new(col_major_dev),
m,
lde_size,
lde_size: n * blowup_factor,
};
Ok((nodes_out, handle, lde_out))
}
Expand All @@ -502,94 +535,18 @@ pub fn coset_lde_ext3_row_major_with_merkle_tree_keep(
blowup_factor: usize,
weights: &[u64],
) -> Result<(Vec<u8>, GpuLdeExt3, Vec<u64>)> {
let m3 = m * 3;
assert_eq!(row_major.len(), n * m3);
assert!(n.is_power_of_two());
assert_eq!(weights.len(), n);
assert!(blowup_factor.is_power_of_two());
let lde_size = n * blowup_factor;
assert_u32_domain(lde_size, "coset_lde_ext3_row_major lde_size");

let nodes_bytes = KeccakCommit::FullTree.total_nodes_bytes(lde_size);
let log_n = n.trailing_zeros() as u64;
let log_lde = lde_size.trailing_zeros() as u64;
let n_u64 = n as u64;
let lde_u64 = lde_size as u64;
let m3_u64 = m3 as u64;

let be = backend()?;
let stream = be.next_stream();

let mut buf = stream.alloc_zeros::<u64>(lde_size * m3)?;
stream.memcpy_htod(row_major, &mut buf.slice_mut(0..n * m3))?;

let inv_tw = be.inv_twiddles_for(log_n)?;
let fwd_tw = be.fwd_twiddles_for(log_lde)?;
let weights_dev = stream.clone_htod(weights)?;

// iNTT + coset weights + forward NTT — same row-major kernels as base-field
// but with m3 = m*3 (all 3 components processed simultaneously).
launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, n_u64, log_n, m3_u64)?;
run_row_major_ntt_body(
stream.as_ref(),
be,
&mut buf,
inv_tw.as_ref(),
n_u64,
log_n,
m3_u64,
)?;
launch_pointwise_mul_row_major(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, m3_u64)?;
launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, lde_u64, log_lde, m3_u64)?;
run_row_major_ntt_body(
stream.as_ref(),
be,
&mut buf,
fwd_tw.as_ref(),
lde_u64,
log_lde,
m3_u64,
let (nodes_out, col_major_dev, lde_out) = coset_lde_row_major_inner(
row_major,
n,
m * 3,
blowup_factor,
weights,
"coset_lde_ext3_row_major lde_size",
)?;

// Keccak: same row-major kernel — each leaf reads m3 consecutive u64s (= m ext3 elements).
let mut nodes_dev = unsafe { stream.alloc::<u8>(nodes_bytes) }?;
let leaves_offset = KeccakCommit::FullTree.leaves_offset_bytes(lde_size);
{
let mut leaves_view = nodes_dev.slice_mut(leaves_offset..leaves_offset + lde_size * 32);
launch_keccak_base_row_major(
stream.as_ref(),
be,
&buf,
m3_u64,
lde_u64,
log_lde,
&mut leaves_view,
)?;
}
crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?;

let lde_out = {
let staging_slot = be.pinned_staging();
let mut staging = staging_slot.lock().unwrap();
staging.ensure_capacity(lde_size * m3, &be.ctx)?;
let pinned = unsafe { staging.as_mut_slice(lde_size * m3) };
stream.memcpy_dtoh(&buf, pinned)?;
stream.synchronize()?;
let out = pinned[..lde_size * m3].to_vec();
drop(staging);
out
};

let mut nodes_out = vec![0u8; nodes_bytes];
d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, &mut nodes_out)?;

let col_major_dev = launch_row_to_col_major(&stream, be, &buf, lde_size, m3, lde_u64)?;
stream.synchronize()?;

let handle = GpuLdeExt3 {
buf: Arc::new(col_major_dev),
m,
lde_size,
lde_size: n * blowup_factor,
};
Ok((nodes_out, handle, lde_out))
}
Expand Down
18 changes: 3 additions & 15 deletions crypto/stark/src/gpu_lde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,13 @@ use crate::trace::LDETraceTable;
const DEFAULT_GPU_LDE_THRESHOLD: usize = 1 << 19;

fn gpu_lde_threshold() -> usize {
// In test builds re-read the env var on every call so tests can switch
// between GPU and CPU paths in the same process (OnceLock can't be reset).
#[cfg(test)]
{
static CACHED: OnceLock<usize> = OnceLock::new();
*CACHED.get_or_init(|| {
std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_GPU_LDE_THRESHOLD)
}
#[cfg(not(test))]
{
static CACHED: OnceLock<usize> = OnceLock::new();
*CACHED.get_or_init(|| {
std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_GPU_LDE_THRESHOLD)
})
}
})
}

/// Incremented by the `try_expand_*` functions per base-field column handed to
Expand Down
Loading
Loading