From a67345190f024951b37070990fc2395badd88ef5 Mon Sep 17 00:00:00 2001 From: B Harsha Kashyap Date: Fri, 24 Apr 2026 07:16:58 +0000 Subject: [PATCH 1/6] Add v4 distance kernels --- diskann-quantization/src/__codegen/x86_64.rs | 10 + diskann-quantization/src/bits/distances.rs | 271 +++++++++++++++++- .../src/spherical/__codegen/x86_64.rs | 84 ++++++ diskann-quantization/src/spherical/iface.rs | 4 +- 4 files changed, 365 insertions(+), 4 deletions(-) diff --git a/diskann-quantization/src/__codegen/x86_64.rs b/diskann-quantization/src/__codegen/x86_64.rs index 2ada1be64..91912811d 100644 --- a/diskann-quantization/src/__codegen/x86_64.rs +++ b/diskann-quantization/src/__codegen/x86_64.rs @@ -89,6 +89,16 @@ pub fn bits_v4_ip_bu2_bu2(arch: V4, x: USlice<'_, 2>, y: USlice<'_, 2>) -> MR, y: USlice<'_, 4>) -> MR { + arch.run2_inline(distances::SquaredL2, x, y) +} + +#[inline(never)] +pub fn bits_v4_ip_bu4_bu4(arch: V4, x: USlice<'_, 4>, y: USlice<'_, 4>) -> MR { + arch.run2_inline(distances::InnerProduct, x, y) +} + //------------// // Transposed // //------------// diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index e8e454dbf..700cb877a 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -558,6 +558,154 @@ impl Target2, USlice<'_, } } +/// Compute the squared L2 distance between `x` and `y`. +/// +/// Returns an error if the arguments have different lengths. +/// +/// # Implementation Notes +/// +/// This implementation is optimized for x86 with the AVX-512 vector extension. +/// It scales the V3 approach to 512-bit registers: we load data as `u32x16`, shift and +/// mask to extract 4-bit nibbles at 16-bit granularity (`0x000f000f` mask), reinterpret +/// as `i16x32`, compute differences, and use `_mm512_madd_epi16` via `dot_simd` to +/// accumulate squared differences into `i32x16`. +/// +/// AVX-512 does not have 16-bit integer bit-shift instructions, so we use 32-bit integer +/// shifts and then bit-cast to 16-bit intrinsics, which works because we apply the same +/// shift to all lanes. +#[cfg(target_arch = "x86_64")] +impl Target2, USlice<'_, 4>, USlice<'_, 4>> + for SquaredL2 +{ + #[expect(non_camel_case_types)] + #[inline(always)] + fn run( + self, + arch: diskann_wide::arch::x86_64::V4, + x: USlice<'_, 4>, + y: USlice<'_, 4>, + ) -> MathematicalResult { + let len = check_lengths!(x, y)?; + + type i32s = ::i32x16; + type u32s = ::u32x16; + type i16s = ::i16x32; + + let px_u32: *const u32 = x.as_ptr().cast(); + let py_u32: *const u32 = y.as_ptr().cast(); + + let mut i = 0; + let mut s: u32 = 0; + + // The number of 32-bit blocks over the underlying slice. + let blocks = len / 8; + if i < blocks { + let mut s0 = i32s::default(arch); + let mut s1 = i32s::default(arch); + let mut s2 = i32s::default(arch); + let mut s3 = i32s::default(arch); + let mask = u32s::splat(arch, 0x000f000f); + while i + 16 < blocks { + // SAFETY: We have checked that `i + 16 < blocks` which means the address + // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::())` is + // valid. + // + // The load has no alignment requirements. + let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) }; + + // SAFETY: The same logic applies to `y` because: + // 1. It has the same type as `x`. + // 2. We've verified that it has the same length as `x`. 
+ let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) }; + + let wx: i16s = (vx & mask).reinterpret_simd(); + let wy: i16s = (vy & mask).reinterpret_simd(); + let d = wx - wy; + s0 = s0.dot_simd(d, d); + + let wx: i16s = (vx >> 4 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 4 & mask).reinterpret_simd(); + let d = wx - wy; + s1 = s1.dot_simd(d, d); + + let wx: i16s = (vx >> 8 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 8 & mask).reinterpret_simd(); + let d = wx - wy; + s2 = s2.dot_simd(d, d); + + let wx: i16s = (vx >> 12 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 12 & mask).reinterpret_simd(); + let d = wx - wy; + s3 = s3.dot_simd(d, d); + + i += 16; + } + + let remainder = blocks - i; + + // SAFETY: At least one value of type `u32` is valid for an unaligned starting + // at offset `i`. The exact number is computed as `remainder`. + // + // The predicated load is guaranteed not to access memory after `remainder` and + // has no alignment requirements. + let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) }; + + // SAFETY: The same logic applies to `y` because: + // 1. It has the same type as `x`. + // 2. We've verified that it has the same length as `x`. + let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) }; + + let wx: i16s = (vx & mask).reinterpret_simd(); + let wy: i16s = (vy & mask).reinterpret_simd(); + let d = wx - wy; + s0 = s0.dot_simd(d, d); + + let wx: i16s = (vx >> 4 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 4 & mask).reinterpret_simd(); + let d = wx - wy; + s1 = s1.dot_simd(d, d); + + let wx: i16s = (vx >> 8 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 8 & mask).reinterpret_simd(); + let d = wx - wy; + s2 = s2.dot_simd(d, d); + + let wx: i16s = (vx >> 12 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 12 & mask).reinterpret_simd(); + let d = wx - wy; + s3 = s3.dot_simd(d, d); + + i += remainder; + + s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32; + } + + // Convert blocks to indexes. + i *= 8; + + // Deal with the remainder the slow way. + if i != len { + // Outline the fallback routine to keep code-generation at this level cleaner. + #[inline(never)] + fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 { + let mut s: i32 = 0; + for i in from..x.len() { + // SAFETY: `i` is guaranteed to be less than `x.len()`. + let ix = unsafe { x.get_unchecked(i) } as i32; + // SAFETY: `i` is guaranteed to be less than `y.len()`. + let iy = unsafe { y.get_unchecked(i) } as i32; + let d = ix - iy; + s += d * d; + } + s as u32 + } + s += fallback(x, y, i); + } + + Ok(MV::new(s)) + } +} + /// Compute the squared L2 distance between `x` and `y`. /// /// Returns an error if the arguments have different lengths. @@ -797,7 +945,7 @@ impl_fallback_l2!(7, 6, 5, 4, 3, 2); retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3); #[cfg(target_arch = "x86_64")] -retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2); +retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 3, 2); dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8); #[cfg(target_arch = "aarch64")] @@ -1115,6 +1263,126 @@ impl Target2, USlice<'_, } } +/// Compute the inner product between `x` and `y`. +/// +/// Returns an error if the arguments have different lengths. 
+/// +/// # Implementation Notes +/// +/// This is optimized around the `__mm512_dpbusd_epi32` VNNI instruction, which computes the +/// pairwise dot product between vectors of 8-bit integers and accumulates groups of 4 with +/// an `i32` accumulation vector. +/// +/// For 4-bit values, each byte holds 2 nibbles. We load data as `u32x16`, mask with +/// `0x0f0f0f0f` to extract the low nibbles as bytes, and shift right by 4 then mask to +/// extract the high nibbles. This gives us `u8x64` / `i8x64` operands for VNNI, requiring +/// only 2 shift positions instead of 4 for the V3 `madd_epi16` approach. +/// +/// Since AVX-512 does not have an 8-bit shift instruction, we load data as `u32x16` +/// (which has a native shift) and bit-cast to `u8x64` as needed. +#[cfg(target_arch = "x86_64")] +impl Target2, USlice<'_, 4>, USlice<'_, 4>> + for InnerProduct +{ + #[expect(non_camel_case_types)] + #[inline(always)] + fn run( + self, + arch: diskann_wide::arch::x86_64::V4, + x: USlice<'_, 4>, + y: USlice<'_, 4>, + ) -> MathematicalResult { + let len = check_lengths!(x, y)?; + + type i32s = ::i32x16; + type u32s = ::u32x16; + type u8s = ::u8x64; + type i8s = ::i8x64; + + let px_u32: *const u32 = x.as_ptr().cast(); + let py_u32: *const u32 = y.as_ptr().cast(); + + let mut i = 0; + let mut s: u32 = 0; + + // The number of 32-bit blocks over the underlying slice. + // Each u32 holds 8 nibbles = 8 four-bit values. + let blocks = len.div_ceil(8); + if i < blocks { + let mut s0 = i32s::default(arch); + let mut s1 = i32s::default(arch); + let mask = u32s::splat(arch, 0x0f0f0f0f); + while i + 16 < blocks { + // SAFETY: We have checked that `i + 16 < blocks` which means the address + // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::())` is + // valid. + // + // The load has no alignment requirements. + let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) }; + + // SAFETY: The same logic applies to `y` because: + // 1. It has the same type as `x`. + // 2. We've verified that it has the same length as `x`. + let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) }; + + let wx: u8s = (vx & mask).reinterpret_simd(); + let wy: i8s = (vy & mask).reinterpret_simd(); + s0 = s0.dot_simd(wx, wy); + + let wx: u8s = ((vx >> 4) & mask).reinterpret_simd(); + let wy: i8s = ((vy >> 4) & mask).reinterpret_simd(); + s1 = s1.dot_simd(wx, wy); + + i += 16; + } + + // Here + // * `len / 2` gives the number of full bytes (2 nibbles per byte) + // * `4 * i` gives the number of bytes processed (4 bytes per u32 × i u32s). + let remainder = len / 2 - 4 * i; + + // SAFETY: At least `remainder` bytes are valid starting at an offset of `i`. + // + // The predicated load is guaranteed not to access memory after `remainder` and + // has no alignment requirements. + let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::(), remainder) }; + let vx: u32s = vx.reinterpret_simd(); + + // SAFETY: The same logic applies to `y` because: + // 1. It has the same type as `x`. + // 2. We've verified that it has the same length as `x`. + let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::(), remainder) }; + let vy: u32s = vy.reinterpret_simd(); + + let wx: u8s = (vx & mask).reinterpret_simd(); + let wy: i8s = (vy & mask).reinterpret_simd(); + s0 = s0.dot_simd(wx, wy); + + let wx: u8s = ((vx >> 4) & mask).reinterpret_simd(); + let wy: i8s = ((vy >> 4) & mask).reinterpret_simd(); + s1 = s1.dot_simd(wx, wy); + + s = (s0 + s1).sum_tree() as u32; + i = (4 * i) + remainder; + } + + // Convert bytes to nibble indexes. 
+ i *= 2; + + // Deal with the remainder the slow way (at most 1 element). + debug_assert!(len - i <= 1); + if i != len { + // SAFETY: `i` is guaranteed to be less than `x.len()`. + let ix = unsafe { x.get_unchecked(i) } as u32; + // SAFETY: `i` is guaranteed to be less than `y.len()`. + let iy = unsafe { y.get_unchecked(i) } as u32; + s += ix * iy; + } + + Ok(MV::new(s)) + } +} + /// Compute the inner product between `x` and `y`. /// /// Returns an error if the arguments have different lengths. @@ -1346,7 +1614,6 @@ retarget!( 7, 6, 5, - 4, 3, (8, 4), (8, 2), diff --git a/diskann-quantization/src/spherical/__codegen/x86_64.rs b/diskann-quantization/src/spherical/__codegen/x86_64.rs index 0c920fc58..d89c16f49 100644 --- a/diskann-quantization/src/spherical/__codegen/x86_64.rs +++ b/diskann-quantization/src/spherical/__codegen/x86_64.rs @@ -335,6 +335,90 @@ pub fn twobit_v3_cosine_full_data( // 4-bit // /////////// +//----// +// V4 // +//----// + +#[inline(never)] +pub fn fourbit_v4_l2_data_data(arch: V4, dim: usize) -> Result { + let reify = Reify::<_, _, AsData<4>, AsData<4>>::new( + vectors::CompensatedSquaredL2::new(dim), + dim, + arch, + ); + DistanceComputer::new(reify, GlobalAllocator) +} + +#[inline(never)] +pub fn fourbit_v4_ip_data_data( + arch: V4, + shift: &[f32], + dim: usize, +) -> Result { + let reify = Reify::<_, _, AsData<4>, AsData<4>>::new( + vectors::CompensatedIP::new(shift, dim), + dim, + arch, + ); + DistanceComputer::new(reify, GlobalAllocator) +} + +#[inline(never)] +pub fn fourbit_v4_cosine_data_data( + arch: V4, + shift: &[f32], + dim: usize, +) -> Result { + let reify = Reify::<_, _, AsData<4>, AsData<4>>::new( + vectors::CompensatedCosine::new(vectors::CompensatedIP::new(shift, dim)), + dim, + arch, + ); + DistanceComputer::new(reify, GlobalAllocator) +} + +#[inline(never)] +pub fn fourbit_v4_l2_query_data(arch: V4, dim: usize) -> Result { + let reify = Reify::<_, _, AsQuery<4, Dense>, AsData<4>>::new( + vectors::CompensatedSquaredL2::new(dim), + dim, + arch, + ); + DistanceComputer::new(reify, GlobalAllocator) +} + +#[inline(never)] +pub fn fourbit_v4_ip_query_data( + arch: V4, + shift: &[f32], + dim: usize, +) -> Result { + let reify = Reify::<_, _, AsQuery<4, Dense>, AsData<4>>::new( + vectors::CompensatedIP::new(shift, dim), + dim, + arch, + ); + DistanceComputer::new(reify, GlobalAllocator) +} + +#[inline(never)] +pub fn fourbit_v4_cosine_query_data( + arch: V4, + shift: &[f32], + dim: usize, +) -> Result { + let reify = Reify::<_, _, AsQuery<4, Dense>, AsData<4>>::new( + vectors::CompensatedCosine::new(vectors::CompensatedIP::new(shift, dim)), + dim, + arch, + ); + DistanceComputer::new(reify, GlobalAllocator) +} + +//----// +// V3 // +//----// + #[inline(never)] pub fn fourbit_v3_l2_data_data(arch: V3, dim: usize) -> Result { let reify = Reify::<_, _, AsData<4>, AsData<4>>::new( diff --git a/diskann-quantization/src/spherical/iface.rs b/diskann-quantization/src/spherical/iface.rs index d0f3d4fa4..e16aa827c 100644 --- a/diskann-quantization/src/spherical/iface.rs +++ b/diskann-quantization/src/spherical/iface.rs @@ -1378,12 +1378,12 @@ cfg_if::cfg_if! 
{ dispatch_map!(1, AsData<1>, V4, downcast_to_v3); dispatch_map!(2, AsData<2>, V4); // specialized - dispatch_map!(4, AsData<4>, V4, downcast_to_v3); + dispatch_map!(4, AsData<4>, V4); // specialized dispatch_map!(8, AsData<8>, V4, downcast_to_v3); dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3); dispatch_map!(2, AsQuery<2>, V4); // specialized - dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3); + dispatch_map!(4, AsQuery<4>, V4); // specialized dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3); } else if #[cfg(target_arch = "aarch64")] { fn downcast(arch: Neon) -> Scalar { From cc4074a688eda223fe82fb489407e056de8fca1b Mon Sep 17 00:00:00 2001 From: Mehmet Yilmaz Date: Fri, 8 May 2026 10:41:32 +0300 Subject: [PATCH 2/6] quantization: clean up V4 4-bit distance kernels - Use diskann_wide::alias! for type aliases. - Add V3<->V4 cross-references and minor doc fixes. - Add a debug_assert on the trailing scalar fallback bound. - Gate the predicated tail behind 'if remainder > 0'. - Note (vs the initial cleanup): keep the V4 4-bit InnerProduct main loop bound as 'i + 16 < blocks'. Because IP uses 'blocks = len.div_ceil(8)' (so a partial trailing byte can be mopped up by the predicated u8x64 tail), relaxing this to '<=' would cause the SIMD loop to read 4 bytes past the buffer end whenever 'blocks' is a multiple of 16 but 'len' is not a multiple of 8 (e.g. len = 121). The L2 kernel is unaffected because it uses 'blocks = len / 8' (rounding down). Co-authored-by: B Harsha Kashyap Co-authored-by: Krishnakumar Ravi (KK) --- diskann-quantization/src/bits/distances.rs | 156 ++++++++++++--------- 1 file changed, 89 insertions(+), 67 deletions(-) diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index 700cb877a..b7313c322 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -565,10 +565,13 @@ impl Target2, USlice<'_, /// # Implementation Notes /// /// This implementation is optimized for x86 with the AVX-512 vector extension. -/// It scales the V3 approach to 512-bit registers: we load data as `u32x16`, shift and -/// mask to extract 4-bit nibbles at 16-bit granularity (`0x000f000f` mask), reinterpret -/// as `i16x32`, compute differences, and use `_mm512_madd_epi16` via `dot_simd` to -/// accumulate squared differences into `i32x16`. +/// It is structurally identical to the V3 4-bit `SquaredL2` impl above, scaled to 512-bit +/// registers (16 u32 lanes × 4 nibble positions). **If you fix a correctness bug here, +/// fix it in the V3 impl as well.** +/// +/// We load data as `u32x16`, shift and mask to extract 4-bit nibbles at 16-bit granularity +/// (`0x000f000f` mask), reinterpret as `i16x32`, compute differences, and use +/// `_mm512_madd_epi16` via `dot_simd` to accumulate squared differences into `i32x16`. 
/// /// AVX-512 does not have 16-bit integer bit-shift instructions, so we use 32-bit integer /// shifts and then bit-cast to 16-bit intrinsics, which works because we apply the same @@ -577,7 +580,6 @@ impl Target2, USlice<'_, impl Target2, USlice<'_, 4>, USlice<'_, 4>> for SquaredL2 { - #[expect(non_camel_case_types)] #[inline(always)] fn run( self, @@ -587,9 +589,9 @@ impl Target2, USlice<'_, ) -> MathematicalResult { let len = check_lengths!(x, y)?; - type i32s = ::i32x16; - type u32s = ::u32x16; - type i16s = ::i16x32; + diskann_wide::alias!(i32s = ::i32x16); + diskann_wide::alias!(u32s = ::u32x16); + diskann_wide::alias!(i16s = ::i16x32); let px_u32: *const u32 = x.as_ptr().cast(); let py_u32: *const u32 = y.as_ptr().cast(); @@ -605,8 +607,8 @@ impl Target2, USlice<'_, let mut s2 = i32s::default(arch); let mut s3 = i32s::default(arch); let mask = u32s::splat(arch, 0x000f000f); - while i + 16 < blocks { - // SAFETY: We have checked that `i + 16 < blocks` which means the address + while i + 16 <= blocks { + // SAFETY: We have checked that `i + 16 <= blocks` which means the address // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::())` is // valid. // @@ -643,39 +645,41 @@ impl Target2, USlice<'_, let remainder = blocks - i; - // SAFETY: At least one value of type `u32` is valid for an unaligned starting - // at offset `i`. The exact number is computed as `remainder`. - // - // The predicated load is guaranteed not to access memory after `remainder` and - // has no alignment requirements. - let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) }; + if remainder > 0 { + // SAFETY: At least one value of type `u32` is valid for an unaligned + // starting at offset `i`. The exact number is computed as `remainder`. + // + // The predicated load is guaranteed not to access memory after `remainder` + // and has no alignment requirements. + let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) }; - // SAFETY: The same logic applies to `y` because: - // 1. It has the same type as `x`. - // 2. We've verified that it has the same length as `x`. - let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) }; + // SAFETY: The same logic applies to `y` because: + // 1. It has the same type as `x`. + // 2. We've verified that it has the same length as `x`. 
+ let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) }; - let wx: i16s = (vx & mask).reinterpret_simd(); - let wy: i16s = (vy & mask).reinterpret_simd(); - let d = wx - wy; - s0 = s0.dot_simd(d, d); + let wx: i16s = (vx & mask).reinterpret_simd(); + let wy: i16s = (vy & mask).reinterpret_simd(); + let d = wx - wy; + s0 = s0.dot_simd(d, d); - let wx: i16s = (vx >> 4 & mask).reinterpret_simd(); - let wy: i16s = (vy >> 4 & mask).reinterpret_simd(); - let d = wx - wy; - s1 = s1.dot_simd(d, d); + let wx: i16s = (vx >> 4 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 4 & mask).reinterpret_simd(); + let d = wx - wy; + s1 = s1.dot_simd(d, d); - let wx: i16s = (vx >> 8 & mask).reinterpret_simd(); - let wy: i16s = (vy >> 8 & mask).reinterpret_simd(); - let d = wx - wy; - s2 = s2.dot_simd(d, d); + let wx: i16s = (vx >> 8 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 8 & mask).reinterpret_simd(); + let d = wx - wy; + s2 = s2.dot_simd(d, d); - let wx: i16s = (vx >> 12 & mask).reinterpret_simd(); - let wy: i16s = (vy >> 12 & mask).reinterpret_simd(); - let d = wx - wy; - s3 = s3.dot_simd(d, d); + let wx: i16s = (vx >> 12 & mask).reinterpret_simd(); + let wy: i16s = (vy >> 12 & mask).reinterpret_simd(); + let d = wx - wy; + s3 = s3.dot_simd(d, d); - i += remainder; + i += remainder; + } s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32; } @@ -683,6 +687,9 @@ impl Target2, USlice<'_, // Convert blocks to indexes. i *= 8; + // At most 7 nibbles can dangle past the last full u32 block. + debug_assert!(len - i < 8); + // Deal with the remainder the slow way. if i != len { // Outline the fallback routine to keep code-generation at this level cleaner. @@ -1269,9 +1276,10 @@ impl Target2, USlice<'_, /// /// # Implementation Notes /// -/// This is optimized around the `__mm512_dpbusd_epi32` VNNI instruction, which computes the -/// pairwise dot product between vectors of 8-bit integers and accumulates groups of 4 with -/// an `i32` accumulation vector. +/// Unlike the V3 4-bit `InnerProduct` impl (which uses `_mm256_madd_epi16` over `i16` lanes +/// and 4 shift positions), this is optimized around the `_mm512_dpbusd_epi32` VNNI +/// instruction, which computes the pairwise dot product between vectors of 8-bit integers +/// and accumulates groups of 4 with an `i32` accumulation vector. /// /// For 4-bit values, each byte holds 2 nibbles. We load data as `u32x16`, mask with /// `0x0f0f0f0f` to extract the low nibbles as bytes, and shift right by 4 then mask to @@ -1284,7 +1292,6 @@ impl Target2, USlice<'_, impl Target2, USlice<'_, 4>, USlice<'_, 4>> for InnerProduct { - #[expect(non_camel_case_types)] #[inline(always)] fn run( self, @@ -1294,10 +1301,10 @@ impl Target2, USlice<'_, ) -> MathematicalResult { let len = check_lengths!(x, y)?; - type i32s = ::i32x16; - type u32s = ::u32x16; - type u8s = ::u8x64; - type i8s = ::i8x64; + diskann_wide::alias!(i32s = ::i32x16); + diskann_wide::alias!(u32s = ::u32x16); + diskann_wide::alias!(u8s = ::u8x64); + diskann_wide::alias!(i8s = ::i8x64); let px_u32: *const u32 = x.as_ptr().cast(); let py_u32: *const u32 = y.as_ptr().cast(); @@ -1305,8 +1312,13 @@ impl Target2, USlice<'_, let mut i = 0; let mut s: u32 = 0; - // The number of 32-bit blocks over the underlying slice. - // Each u32 holds 8 nibbles = 8 four-bit values. + // Number of u32 blocks (rounded up). Each u32 holds 8 nibbles = 8 four-bit values. 
+ // We use `div_ceil` so a partial trailing byte (1..=7 stray nibbles) is still + // covered by the predicated `u8` load below; the scalar fallback only ever needs to + // handle at most a single dangling nibble. + // + // (The L2 kernel uses `len / 8` because its scalar fallback can handle up to 7 + // dangling nibbles itself.) let blocks = len.div_ceil(8); if i < blocks { let mut s0 = i32s::default(arch); @@ -1325,6 +1337,9 @@ impl Target2, USlice<'_, // 2. We've verified that it has the same length as `x`. let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) }; + // VNNI `vpdpbusd` requires (unsigned, signed) operands in this order; + // `dot_simd(u8s, i8s)` lowers to that intrinsic on V4. 4-bit values fit in + // the positive half of `i8`, so the unsigned->signed reinterpret is safe. let wx: u8s = (vx & mask).reinterpret_simd(); let wy: i8s = (vy & mask).reinterpret_simd(); s0 = s0.dot_simd(wx, wy); @@ -1336,31 +1351,38 @@ impl Target2, USlice<'_, i += 16; } - // Here - // * `len / 2` gives the number of full bytes (2 nibbles per byte) - // * `4 * i` gives the number of bytes processed (4 bytes per u32 × i u32s). + // Compute the number of bytes still to process: + // * `len / 2` — total full bytes in the input (2 nibbles/byte). + // * `4 * i` — bytes already consumed (4 bytes per u32 × `i` u32s). + // + // The loop invariant `i + 16 < blocks` (with `blocks = len.div_ceil(8)`) + // guarantees `4 * i < len / 2 + small`, so this subtraction is safe. let remainder = len / 2 - 4 * i; - // SAFETY: At least `remainder` bytes are valid starting at an offset of `i`. - // - // The predicated load is guaranteed not to access memory after `remainder` and - // has no alignment requirements. - let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::(), remainder) }; - let vx: u32s = vx.reinterpret_simd(); + if remainder > 0 { + // SAFETY: At least `remainder` bytes are valid starting at an offset of `i`. + // + // The predicated load is guaranteed not to access memory after `remainder` + // and has no alignment requirements. + let vx = + unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::(), remainder) }; + let vx: u32s = vx.reinterpret_simd(); - // SAFETY: The same logic applies to `y` because: - // 1. It has the same type as `x`. - // 2. We've verified that it has the same length as `x`. - let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::(), remainder) }; - let vy: u32s = vy.reinterpret_simd(); + // SAFETY: The same logic applies to `y` because: + // 1. It has the same type as `x`. + // 2. We've verified that it has the same length as `x`. 
+ let vy = + unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::(), remainder) }; + let vy: u32s = vy.reinterpret_simd(); - let wx: u8s = (vx & mask).reinterpret_simd(); - let wy: i8s = (vy & mask).reinterpret_simd(); - s0 = s0.dot_simd(wx, wy); + let wx: u8s = (vx & mask).reinterpret_simd(); + let wy: i8s = (vy & mask).reinterpret_simd(); + s0 = s0.dot_simd(wx, wy); - let wx: u8s = ((vx >> 4) & mask).reinterpret_simd(); - let wy: i8s = ((vy >> 4) & mask).reinterpret_simd(); - s1 = s1.dot_simd(wx, wy); + let wx: u8s = ((vx >> 4) & mask).reinterpret_simd(); + let wy: i8s = ((vy >> 4) & mask).reinterpret_simd(); + s1 = s1.dot_simd(wx, wy); + } s = (s0 + s1).sum_tree() as u32; i = (4 * i) + remainder; From 1019ec613bec4a7b49b9789654d6587ae4ef4ddf Mon Sep 17 00:00:00 2001 From: Mehmet Yilmaz Date: Fri, 8 May 2026 13:59:28 +0300 Subject: [PATCH 3/6] wide: add Emulated => Emulated reinterpret for V4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The new V4 4-bit SquaredL2 kernel in diskann-quantization reinterprets 'u32x16' loads as 'i16x32' to feed '_mm512_madd_epi16'. The intrinsic- backed conversion 'u32x16: SIMDReinterpret' was already provided in 'arch/x86_64/v4/conversion.rs', but the parallel impl on the emulated representation 'Emulated: SIMDReinterpret>' was missing. This impl is exercised by the coverage build (--cfg=miri + instrument-coverage) which runs the SquaredL2 V4 kernel against the Emulated<…, V4> types. CI failed with E0277 'unsatisfied trait bound' at 8 call sites in diskann-quantization/src/bits/distances.rs. The new entry mirrors the existing ' => ' impl used by V3. Co-authored-by: B Harsha Kashyap Co-authored-by: Krishnakumar Ravi (KK) --- diskann-wide/src/emulated.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/diskann-wide/src/emulated.rs b/diskann-wide/src/emulated.rs index 7477207ec..a91d5d767 100644 --- a/diskann-wide/src/emulated.rs +++ b/diskann-wide/src/emulated.rs @@ -610,6 +610,7 @@ macro_rules! impl_little_endian_transmute_cast { } impl_little_endian_transmute_cast!( => ); +impl_little_endian_transmute_cast!( => ); impl_little_endian_transmute_cast!( => ); impl_little_endian_transmute_cast!( => ); From e26e4a9194587422f1ac98513f8e92586ddb7507 Mon Sep 17 00:00:00 2001 From: Mehmet Yilmaz Date: Fri, 8 May 2026 14:41:59 +0300 Subject: [PATCH 4/6] quantization: address review comments on V4 4-bit kernel docs Three Copilot review nits from PR #1045: - L2 doc (~line 578): the claim 'AVX-512 does not have 16-bit integer bit-shift instructions' is incorrect; V4 enables avx512bw which provides 'vpsrlw'. Reword to give the actual reason: the project's 'diskann_wide' type abstractions provide a native shift on 'u32x16' but not on 'i16x32', so we shift in u32 view and reinterpret. - L2 SAFETY comment (~line 612) and IP SAFETY comment (~line 1330): these mixed pointer-element arithmetic with byte sizes ('px_u32 + i + 16 * size_of::()' on a *const u32). Rewrite consistently in u32-element units: 'px_u32.add(i)..px_u32.add(i + 16) (in u32 units)'. 
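
Illustration only (not a quote of the new comments): '<*const u32>::add' steps
in whole u32 elements, so a 16-element range is exactly one 64-byte u32x16
load. A self-contained sketch of the identity, with names invented for the
sketch:

    let buf = [0u32; 32];
    let base: *const u32 = buf.as_ptr();
    // `add`/`wrapping_add` count in elements, i.e. 4 bytes per step here.
    assert_eq!(
        base.wrapping_add(16) as usize - base as usize,
        16 * std::mem::size_of::<u32>(), // 64 bytes == one unaligned u32x16 load
    );
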
Co-authored-by: B Harsha Kashyap Co-authored-by: Krishnakumar Ravi (KK) --- diskann-quantization/src/bits/distances.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index b7313c322..29e632c6b 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -573,9 +573,10 @@ impl Target2, USlice<'_, /// (`0x000f000f` mask), reinterpret as `i16x32`, compute differences, and use /// `_mm512_madd_epi16` via `dot_simd` to accumulate squared differences into `i32x16`. /// -/// AVX-512 does not have 16-bit integer bit-shift instructions, so we use 32-bit integer -/// shifts and then bit-cast to 16-bit intrinsics, which works because we apply the same -/// shift to all lanes. +/// We perform shifts on the `u32x16` view rather than the `i16x32` view because the +/// `diskann_wide` type abstractions provide a native shift on `u32x16` but not on +/// `i16x32`. The reinterpret-after-shift is well-defined because the same shift amount +/// is applied uniformly to all lanes. #[cfg(target_arch = "x86_64")] impl Target2, USlice<'_, 4>, USlice<'_, 4>> for SquaredL2 @@ -608,9 +609,9 @@ impl Target2, USlice<'_, let mut s3 = i32s::default(arch); let mask = u32s::splat(arch, 0x000f000f); while i + 16 <= blocks { - // SAFETY: We have checked that `i + 16 <= blocks` which means the address - // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::())` is - // valid. + // SAFETY: We have checked that `i + 16 <= blocks` which means the + // 16-element range `px_u32.add(i)..px_u32.add(i + 16)` (in `u32` units) + // is dereferenceable. // // The load has no alignment requirements. let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) }; @@ -1325,9 +1326,9 @@ impl Target2, USlice<'_, let mut s1 = i32s::default(arch); let mask = u32s::splat(arch, 0x0f0f0f0f); while i + 16 < blocks { - // SAFETY: We have checked that `i + 16 < blocks` which means the address - // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::())` is - // valid. + // SAFETY: We have checked that `i + 16 < blocks` which means the + // 16-element range `px_u32.add(i)..px_u32.add(i + 16)` (in `u32` units) + // is dereferenceable. // // The load has no alignment requirements. let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) }; From 45f5b7fffd63b70737cb936605dbbe782c0bc7fd Mon Sep 17 00:00:00 2001 From: Mehmet Yilmaz Date: Fri, 8 May 2026 15:56:50 +0300 Subject: [PATCH 5/6] quantization: address review comments on V4 4-bit predicated tail SAFETY Two Copilot review nits from PR #1045: - L2 predicated tail (~line 651): grammar fix ('an unaligned starting' -> rewritten to refer concretely to 'remainder values of type u32 ... starting at px_u32.add(i)'). - IP predicated tail (~line 1369): the SAFETY comment said the load starts 'at an offset of i', but the pointer is 'px_u32.add(i).cast::()' so the byte offset from px_u32 is 4*i. Rewrite the comment in consistent units (refer directly to the byte pointer and call out the 4*i byte offset). 
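
Illustration only: the '4 * i' wording is the usual element-to-byte pointer
identity, sketched here with invented local names:

    let buf = [0u32; 32];
    let p: *const u32 = buf.as_ptr();
    let i = 5usize;
    // Advancing `i` u32 elements and then viewing the pointer as bytes lands
    // on the same address as advancing `4 * i` bytes from the start.
    assert_eq!(
        p.wrapping_add(i).cast::<u8>(),
        p.cast::<u8>().wrapping_add(4 * i),
    );
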
Co-authored-by: B Harsha Kashyap Co-authored-by: Krishnakumar Ravi (KK) --- diskann-quantization/src/bits/distances.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index 29e632c6b..95d57cce3 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -647,11 +647,11 @@ impl Target2, USlice<'_, let remainder = blocks - i; if remainder > 0 { - // SAFETY: At least one value of type `u32` is valid for an unaligned - // starting at offset `i`. The exact number is computed as `remainder`. + // SAFETY: `remainder` values of type `u32` are dereferenceable starting + // at `px_u32.add(i)` (i.e. element offset `i` in `u32` units). // - // The predicated load is guaranteed not to access memory after `remainder` - // and has no alignment requirements. + // The predicated load is guaranteed not to access memory beyond the + // first `remainder` lanes and has no alignment requirements. let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) }; // SAFETY: The same logic applies to `y` because: @@ -1361,10 +1361,11 @@ impl Target2, USlice<'_, let remainder = len / 2 - 4 * i; if remainder > 0 { - // SAFETY: At least `remainder` bytes are valid starting at an offset of `i`. + // SAFETY: `remainder` bytes are dereferenceable starting at + // `px_u32.add(i).cast::()` (i.e. byte offset `4 * i` from `px_u32`). // - // The predicated load is guaranteed not to access memory after `remainder` - // and has no alignment requirements. + // The predicated load is guaranteed not to access memory beyond the + // first `remainder` bytes and has no alignment requirements. let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::(), remainder) }; let vx: u32s = vx.reinterpret_simd(); From 08beb308942838400107a7ac21f849249d7dd01c Mon Sep 17 00:00:00 2001 From: Mehmet Yilmaz Date: Fri, 8 May 2026 21:12:03 +0300 Subject: [PATCH 6/6] quantization: address V4 4-bit review feedback - L2 V4: drop V3-contrast / misleading shift docs; reword to focus on what this kernel does (per @hildebrandmw's review). - L2 V4: switch the trailing tail to a predicated u8 load (matching the IP V4 kernel), so the scalar fallback only handles at most one dangling nibble instead of up to 7. blocks now uses div_ceil(8) and the main loop uses '< blocks', mirroring the IP layout. - IP V4: drop V3-contrast docs / parenthetical, focus on what this kernel does. - Bump BITSLICE_TEST_BOUNDS for (4, X86_64_V4) from (256, 150) to (512, 300) so miri exercises the doubled main-loop width. Co-authored-by: B Harsha Kashyap Co-authored-by: Krishnakumar Ravi (KK) --- diskann-quantization/src/bits/distances.rs | 94 ++++++++++------------ 1 file changed, 43 insertions(+), 51 deletions(-) diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index 95d57cce3..aedd2e200 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -565,18 +565,12 @@ impl Target2, USlice<'_, /// # Implementation Notes /// /// This implementation is optimized for x86 with the AVX-512 vector extension. -/// It is structurally identical to the V3 4-bit `SquaredL2` impl above, scaled to 512-bit -/// registers (16 u32 lanes × 4 nibble positions). 
**If you fix a correctness bug here, -/// fix it in the V3 impl as well.** /// /// We load data as `u32x16`, shift and mask to extract 4-bit nibbles at 16-bit granularity /// (`0x000f000f` mask), reinterpret as `i16x32`, compute differences, and use -/// `_mm512_madd_epi16` via `dot_simd` to accumulate squared differences into `i32x16`. -/// -/// We perform shifts on the `u32x16` view rather than the `i16x32` view because the -/// `diskann_wide` type abstractions provide a native shift on `u32x16` but not on -/// `i16x32`. The reinterpret-after-shift is well-defined because the same shift amount -/// is applied uniformly to all lanes. +/// `_mm512_madd_epi16` via `dot_simd` to accumulate squared differences into 4 independent +/// `i32x16` accumulators. Reinterpreting after a shift is well-defined because the same +/// shift amount is applied uniformly to all lanes. #[cfg(target_arch = "x86_64")] impl Target2, USlice<'_, 4>, USlice<'_, 4>> for SquaredL2 @@ -592,6 +586,7 @@ impl Target2, USlice<'_, diskann_wide::alias!(i32s = ::i32x16); diskann_wide::alias!(u32s = ::u32x16); + diskann_wide::alias!(u8s = ::u8x64); diskann_wide::alias!(i16s = ::i16x32); let px_u32: *const u32 = x.as_ptr().cast(); @@ -600,16 +595,19 @@ impl Target2, USlice<'_, let mut i = 0; let mut s: u32 = 0; - // The number of 32-bit blocks over the underlying slice. - let blocks = len / 8; + // Number of u32 blocks (rounded up). Each u32 holds 8 nibbles = 8 four-bit values. + // We use `div_ceil` so a partial trailing byte (1..=7 stray nibbles) is still + // covered by the predicated `u8` load below; the scalar fallback only ever needs to + // handle at most a single dangling nibble. + let blocks = len.div_ceil(8); if i < blocks { let mut s0 = i32s::default(arch); let mut s1 = i32s::default(arch); let mut s2 = i32s::default(arch); let mut s3 = i32s::default(arch); let mask = u32s::splat(arch, 0x000f000f); - while i + 16 <= blocks { - // SAFETY: We have checked that `i + 16 <= blocks` which means the + while i + 16 < blocks { + // SAFETY: We have checked that `i + 16 < blocks` which means the // 16-element range `px_u32.add(i)..px_u32.add(i + 16)` (in `u32` units) // is dereferenceable. // @@ -644,20 +642,30 @@ impl Target2, USlice<'_, i += 16; } - let remainder = blocks - i; + // Compute the number of bytes still to process: + // * `len / 2` — total full bytes in the input (2 nibbles/byte). + // * `4 * i` — bytes already consumed (4 bytes per u32 × `i` u32s). + // + // The loop invariant `i + 16 < blocks` (with `blocks = len.div_ceil(8)`) + // guarantees `4 * i < len / 2 + small`, so this subtraction is safe. + let remainder = len / 2 - 4 * i; if remainder > 0 { - // SAFETY: `remainder` values of type `u32` are dereferenceable starting - // at `px_u32.add(i)` (i.e. element offset `i` in `u32` units). + // SAFETY: `remainder` bytes are dereferenceable starting at + // `px_u32.add(i).cast::()` (i.e. byte offset `4 * i` from `px_u32`). // // The predicated load is guaranteed not to access memory beyond the - // first `remainder` lanes and has no alignment requirements. - let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) }; + // first `remainder` bytes and has no alignment requirements. + let vx = + unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::(), remainder) }; + let vx: u32s = vx.reinterpret_simd(); // SAFETY: The same logic applies to `y` because: // 1. It has the same type as `x`. // 2. We've verified that it has the same length as `x`. 
- let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) }; + let vy = + unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::(), remainder) }; + let vy: u32s = vy.reinterpret_simd(); let wx: i16s = (vx & mask).reinterpret_simd(); let wy: i16s = (vy & mask).reinterpret_simd(); @@ -678,36 +686,24 @@ impl Target2, USlice<'_, let wy: i16s = (vy >> 12 & mask).reinterpret_simd(); let d = wx - wy; s3 = s3.dot_simd(d, d); - - i += remainder; } s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32; + i = (4 * i) + remainder; } - // Convert blocks to indexes. - i *= 8; - - // At most 7 nibbles can dangle past the last full u32 block. - debug_assert!(len - i < 8); + // Convert bytes to nibble indexes. + i *= 2; - // Deal with the remainder the slow way. + // Deal with the remainder the slow way (at most 1 element). + debug_assert!(len - i <= 1); if i != len { - // Outline the fallback routine to keep code-generation at this level cleaner. - #[inline(never)] - fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 { - let mut s: i32 = 0; - for i in from..x.len() { - // SAFETY: `i` is guaranteed to be less than `x.len()`. - let ix = unsafe { x.get_unchecked(i) } as i32; - // SAFETY: `i` is guaranteed to be less than `y.len()`. - let iy = unsafe { y.get_unchecked(i) } as i32; - let d = ix - iy; - s += d * d; - } - s as u32 - } - s += fallback(x, y, i); + // SAFETY: `i` is guaranteed to be less than `x.len()`. + let ix = unsafe { x.get_unchecked(i) } as i32; + // SAFETY: `i` is guaranteed to be less than `y.len()`. + let iy = unsafe { y.get_unchecked(i) } as i32; + let d = ix - iy; + s += (d * d) as u32; } Ok(MV::new(s)) @@ -1277,15 +1273,14 @@ impl Target2, USlice<'_, /// /// # Implementation Notes /// -/// Unlike the V3 4-bit `InnerProduct` impl (which uses `_mm256_madd_epi16` over `i16` lanes -/// and 4 shift positions), this is optimized around the `_mm512_dpbusd_epi32` VNNI +/// This implementation is optimized around the AVX-512 VNNI `_mm512_dpbusd_epi32` /// instruction, which computes the pairwise dot product between vectors of 8-bit integers -/// and accumulates groups of 4 with an `i32` accumulation vector. +/// and accumulates groups of 4 into an `i32` accumulator. /// /// For 4-bit values, each byte holds 2 nibbles. We load data as `u32x16`, mask with /// `0x0f0f0f0f` to extract the low nibbles as bytes, and shift right by 4 then mask to -/// extract the high nibbles. This gives us `u8x64` / `i8x64` operands for VNNI, requiring -/// only 2 shift positions instead of 4 for the V3 `madd_epi16` approach. +/// extract the high nibbles. This yields `u8x64` / `i8x64` operands for VNNI, requiring +/// only 2 shift positions. /// /// Since AVX-512 does not have an 8-bit shift instruction, we load data as `u32x16` /// (which has a native shift) and bit-cast to `u8x64` as needed. @@ -1317,9 +1312,6 @@ impl Target2, USlice<'_, // We use `div_ceil` so a partial trailing byte (1..=7 stray nibbles) is still // covered by the predicated `u8` load below; the scalar fallback only ever needs to // handle at most a single dangling nibble. - // - // (The L2 kernel uses `len / 8` because its scalar fallback can handle up to 7 - // dangling nibbles itself.) 
let blocks = len.div_ceil(8); if i < blocks { let mut s0 = i32s::default(arch); @@ -2950,7 +2942,7 @@ mod tests { (Key::new(4, Scalar), Bounds::new(64, 64)), // Need a higher miri-amount due to the larget block size (Key::new(4, X86_64_V3), Bounds::new(256, 150)), - (Key::new(4, X86_64_V4), Bounds::new(256, 150)), + (Key::new(4, X86_64_V4), Bounds::new(512, 300)), (Key::new(4, Neon), Bounds::new(64, 64)), (Key::new(5, Scalar), Bounds::new(64, 64)), (Key::new(5, X86_64_V3), Bounds::new(256, 96)),