diff --git a/diskann-quantization/src/__codegen/x86_64.rs b/diskann-quantization/src/__codegen/x86_64.rs
index 2ada1be64..91912811d 100644
--- a/diskann-quantization/src/__codegen/x86_64.rs
+++ b/diskann-quantization/src/__codegen/x86_64.rs
@@ -89,6 +89,16 @@ pub fn bits_v4_ip_bu2_bu2(arch: V4, x: USlice<'_, 2>, y: USlice<'_, 2>) -> MR {
+#[inline(never)]
+pub fn bits_v4_l2_bu4_bu4(arch: V4, x: USlice<'_, 4>, y: USlice<'_, 4>) -> MR {
+    arch.run2_inline(distances::SquaredL2, x, y)
+}
+
+#[inline(never)]
+pub fn bits_v4_ip_bu4_bu4(arch: V4, x: USlice<'_, 4>, y: USlice<'_, 4>) -> MR {
+    arch.run2_inline(distances::InnerProduct, x, y)
+}
+
 //------------//
 // Transposed //
 //------------//
diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs
index e8e454dbf..aedd2e200 100644
--- a/diskann-quantization/src/bits/distances.rs
+++ b/diskann-quantization/src/bits/distances.rs
@@ -558,6 +558,158 @@ impl Target2, USlice<'_,
     }
 }
 
+/// Compute the squared L2 distance between `x` and `y`.
+///
+/// Returns an error if the arguments have different lengths.
+///
+/// # Implementation Notes
+///
+/// This implementation is optimized for x86 with the AVX-512 vector extension.
+///
+/// We load data as `u32x16`, shift and mask to extract 4-bit nibbles at 16-bit granularity
+/// (`0x000f000f` mask), reinterpret as `i16x32`, compute differences, and use
+/// `_mm512_madd_epi16` via `dot_simd` to accumulate squared differences into 4 independent
+/// `i32x16` accumulators. Reinterpreting after a shift is well-defined because the same
+/// shift amount is applied uniformly to all lanes.
+#[cfg(target_arch = "x86_64")]
+impl Target2<diskann_wide::arch::x86_64::V4, USlice<'_, 4>, USlice<'_, 4>>
+    for SquaredL2
+{
+    #[inline(always)]
+    fn run(
+        self,
+        arch: diskann_wide::arch::x86_64::V4,
+        x: USlice<'_, 4>,
+        y: USlice<'_, 4>,
+    ) -> MathematicalResult {
+        let len = check_lengths!(x, y)?;
+
+        diskann_wide::alias!(i32s = ::i32x16);
+        diskann_wide::alias!(u32s = ::u32x16);
+        diskann_wide::alias!(u8s = ::u8x64);
+        diskann_wide::alias!(i16s = ::i16x32);
+
+        let px_u32: *const u32 = x.as_ptr().cast();
+        let py_u32: *const u32 = y.as_ptr().cast();
+
+        let mut i = 0;
+        let mut s: u32 = 0;
+
+        // Number of u32 blocks (rounded up). Each u32 holds 8 nibbles = 8 four-bit values.
+        // We use `div_ceil` so a partial trailing u32 block (1..=7 stray nibbles) is still
+        // covered by the predicated `u8` load below; the scalar fallback only ever needs to
+        // handle at most a single dangling nibble.
+        let blocks = len.div_ceil(8);
+        if i < blocks {
+            let mut s0 = i32s::default(arch);
+            let mut s1 = i32s::default(arch);
+            let mut s2 = i32s::default(arch);
+            let mut s3 = i32s::default(arch);
+            let mask = u32s::splat(arch, 0x000f000f);
+            while i + 16 < blocks {
+                // SAFETY: We have checked that `i + 16 < blocks` which means the
+                // 16-element range `px_u32.add(i)..px_u32.add(i + 16)` (in `u32` units)
+                // is dereferenceable.
+                //
+                // The load has no alignment requirements.
+                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
+
+                // SAFETY: The same logic applies to `y` because:
+                // 1. It has the same type as `x`.
+                // 2. We've verified that it has the same length as `x`.
+                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
+
+                let wx: i16s = (vx & mask).reinterpret_simd();
+                let wy: i16s = (vy & mask).reinterpret_simd();
+                let d = wx - wy;
+                s0 = s0.dot_simd(d, d);
+
+                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
+                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
+                let d = wx - wy;
+                s1 = s1.dot_simd(d, d);
+
+                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
+                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
+                let d = wx - wy;
+                s2 = s2.dot_simd(d, d);
+
+                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
+                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
+                let d = wx - wy;
+                s3 = s3.dot_simd(d, d);
+
+                i += 16;
+            }
+
+            // Compute the number of bytes still to process:
+            // * `len / 2` — total full bytes in the input (2 nibbles/byte).
+            // * `4 * i` — bytes already consumed (4 bytes per u32 × `i` u32s).
+            //
+            // The loop invariant `i + 16 < blocks` (with `blocks = len.div_ceil(8)`)
+            // guarantees `4 * i <= len / 2`, so the subtraction cannot underflow.
+            let remainder = len / 2 - 4 * i;
+
+            if remainder > 0 {
+                // SAFETY: `remainder` bytes are dereferenceable starting at
+                // `px_u32.add(i).cast::<u8>()` (i.e. byte offset `4 * i` from `px_u32`).
+                //
+                // The predicated load is guaranteed not to access memory beyond the
+                // first `remainder` bytes and has no alignment requirements.
+                let vx =
+                    unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
+                let vx: u32s = vx.reinterpret_simd();
+
+                // SAFETY: The same logic applies to `y` because:
+                // 1. It has the same type as `x`.
+                // 2. We've verified that it has the same length as `x`.
+                let vy =
+                    unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
+                let vy: u32s = vy.reinterpret_simd();
+
+                let wx: i16s = (vx & mask).reinterpret_simd();
+                let wy: i16s = (vy & mask).reinterpret_simd();
+                let d = wx - wy;
+                s0 = s0.dot_simd(d, d);
+
+                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
+                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
+                let d = wx - wy;
+                s1 = s1.dot_simd(d, d);
+
+                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
+                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
+                let d = wx - wy;
+                s2 = s2.dot_simd(d, d);
+
+                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
+                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
+                let d = wx - wy;
+                s3 = s3.dot_simd(d, d);
+            }
+
+            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
+            i = (4 * i) + remainder;
+        }
+
+        // Convert bytes to nibble indexes.
+        i *= 2;
+
+        // Deal with the remainder the slow way (at most 1 element).
+        debug_assert!(len - i <= 1);
+        if i != len {
+            // SAFETY: `i` is guaranteed to be less than `x.len()`.
+            let ix = unsafe { x.get_unchecked(i) } as i32;
+            // SAFETY: `i` is guaranteed to be less than `y.len()`.
+            let iy = unsafe { y.get_unchecked(i) } as i32;
+            let d = ix - iy;
+            s += (d * d) as u32;
+        }
+
+        Ok(MV::new(s))
+    }
+}
+
 /// Compute the squared L2 distance between `x` and `y`.
 ///
 /// Returns an error if the arguments have different lengths.
@@ -797,7 +949,7 @@ impl_fallback_l2!(7, 6, 5, 4, 3, 2);
 retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);
 
 #[cfg(target_arch = "x86_64")]
-retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);
+retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 3, 2);
 dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
 
 #[cfg(target_arch = "aarch64")]
@@ -1115,6 +1267,138 @@ impl Target2, USlice<'_,
     }
 }
 
+/// Compute the inner product between `x` and `y`.
+///
+/// Returns an error if the arguments have different lengths.
+///
+/// # Implementation Notes
+///
+/// This implementation is optimized around the AVX-512 VNNI `_mm512_dpbusd_epi32`
+/// instruction, which computes the pairwise dot product between vectors of 8-bit integers
+/// and accumulates groups of 4 into an `i32` accumulator.
+///
+/// For 4-bit values, each byte holds 2 nibbles. We load data as `u32x16`, mask with
+/// `0x0f0f0f0f` to extract the low nibbles as bytes, and shift right by 4 then mask to
+/// extract the high nibbles. This yields `u8x64` / `i8x64` operands for VNNI, requiring
+/// only 2 shift positions.
+///
+/// Since AVX-512 does not have an 8-bit shift instruction, we load data as `u32x16`
+/// (which has a native shift) and bit-cast to `u8x64` as needed.
+#[cfg(target_arch = "x86_64")]
+impl Target2<diskann_wide::arch::x86_64::V4, USlice<'_, 4>, USlice<'_, 4>>
+    for InnerProduct
+{
+    #[inline(always)]
+    fn run(
+        self,
+        arch: diskann_wide::arch::x86_64::V4,
+        x: USlice<'_, 4>,
+        y: USlice<'_, 4>,
+    ) -> MathematicalResult {
+        let len = check_lengths!(x, y)?;
+
+        diskann_wide::alias!(i32s = ::i32x16);
+        diskann_wide::alias!(u32s = ::u32x16);
+        diskann_wide::alias!(u8s = ::u8x64);
+        diskann_wide::alias!(i8s = ::i8x64);
+
+        let px_u32: *const u32 = x.as_ptr().cast();
+        let py_u32: *const u32 = y.as_ptr().cast();
+
+        let mut i = 0;
+        let mut s: u32 = 0;
+
+        // Number of u32 blocks (rounded up). Each u32 holds 8 nibbles = 8 four-bit values.
+        // We use `div_ceil` so a partial trailing u32 block (1..=7 stray nibbles) is still
+        // covered by the predicated `u8` load below; the scalar fallback only ever needs to
+        // handle at most a single dangling nibble.
+        let blocks = len.div_ceil(8);
+        if i < blocks {
+            let mut s0 = i32s::default(arch);
+            let mut s1 = i32s::default(arch);
+            let mask = u32s::splat(arch, 0x0f0f0f0f);
+            while i + 16 < blocks {
+                // SAFETY: We have checked that `i + 16 < blocks` which means the
+                // 16-element range `px_u32.add(i)..px_u32.add(i + 16)` (in `u32` units)
+                // is dereferenceable.
+                //
+                // The load has no alignment requirements.
+                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
+
+                // SAFETY: The same logic applies to `y` because:
+                // 1. It has the same type as `x`.
+                // 2. We've verified that it has the same length as `x`.
+                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
+
+                // VNNI `vpdpbusd` requires (unsigned, signed) operands in this order;
+                // `dot_simd(u8s, i8s)` lowers to that intrinsic on V4. 4-bit values fit in
+                // the positive half of `i8`, so the unsigned->signed reinterpret is safe.
+                let wx: u8s = (vx & mask).reinterpret_simd();
+                let wy: i8s = (vy & mask).reinterpret_simd();
+                s0 = s0.dot_simd(wx, wy);
+
+                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
+                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
+                s1 = s1.dot_simd(wx, wy);
+
+                i += 16;
+            }
+
+            // Compute the number of bytes still to process:
+            // * `len / 2` — total full bytes in the input (2 nibbles/byte).
+            // * `4 * i` — bytes already consumed (4 bytes per u32 × `i` u32s).
+            //
+            // The loop invariant `i + 16 < blocks` (with `blocks = len.div_ceil(8)`)
+            // guarantees `4 * i <= len / 2`, so the subtraction cannot underflow.
+            let remainder = len / 2 - 4 * i;
+
+            if remainder > 0 {
+                // SAFETY: `remainder` bytes are dereferenceable starting at
+                // `px_u32.add(i).cast::<u8>()` (i.e. byte offset `4 * i` from `px_u32`).
+                //
+                // The predicated load is guaranteed not to access memory beyond the
+                // first `remainder` bytes and has no alignment requirements.
+                let vx =
+                    unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
+                let vx: u32s = vx.reinterpret_simd();
+
+                // SAFETY: The same logic applies to `y` because:
+                // 1. It has the same type as `x`.
+                // 2. We've verified that it has the same length as `x`.
+                let vy =
+                    unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
+                let vy: u32s = vy.reinterpret_simd();
+
+                let wx: u8s = (vx & mask).reinterpret_simd();
+                let wy: i8s = (vy & mask).reinterpret_simd();
+                s0 = s0.dot_simd(wx, wy);
+
+                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
+                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
+                s1 = s1.dot_simd(wx, wy);
+            }
+
+            s = (s0 + s1).sum_tree() as u32;
+            i = (4 * i) + remainder;
+        }
+
+        // Convert bytes to nibble indexes.
+        i *= 2;
+
+        // Deal with the remainder the slow way (at most 1 element).
+        debug_assert!(len - i <= 1);
+        if i != len {
+            // SAFETY: `i` is guaranteed to be less than `x.len()`.
+            let ix = unsafe { x.get_unchecked(i) } as u32;
+            // SAFETY: `i` is guaranteed to be less than `y.len()`.
+            let iy = unsafe { y.get_unchecked(i) } as u32;
+            s += ix * iy;
+        }
+
+        Ok(MV::new(s))
+    }
+}
+
 /// Compute the inner product between `x` and `y`.
 ///
 /// Returns an error if the arguments have different lengths.
@@ -1346,7 +1630,6 @@ retarget!(
     7,
     6,
     5,
-    4,
     3,
     (8, 4),
     (8, 2),
@@ -2659,7 +2942,7 @@ mod tests {
         (Key::new(4, Scalar), Bounds::new(64, 64)),
         // Need a higher miri-amount due to the larger block size
         (Key::new(4, X86_64_V3), Bounds::new(256, 150)),
-        (Key::new(4, X86_64_V4), Bounds::new(256, 150)),
+        (Key::new(4, X86_64_V4), Bounds::new(512, 300)),
         (Key::new(4, Neon), Bounds::new(64, 64)),
         (Key::new(5, Scalar), Bounds::new(64, 64)),
         (Key::new(5, X86_64_V3), Bounds::new(256, 96)),
diff --git a/diskann-quantization/src/spherical/__codegen/x86_64.rs b/diskann-quantization/src/spherical/__codegen/x86_64.rs
index 0c920fc58..d89c16f49 100644
--- a/diskann-quantization/src/spherical/__codegen/x86_64.rs
+++ b/diskann-quantization/src/spherical/__codegen/x86_64.rs
@@ -335,6 +335,90 @@ pub fn twobit_v3_cosine_full_data(
 // 4-bit //
 ///////////
 
+//----//
+// V4 //
+//----//
+
+#[inline(never)]
+pub fn fourbit_v4_l2_data_data(arch: V4, dim: usize) -> Result {
+    let reify = Reify::<_, _, AsData<4>, AsData<4>>::new(
+        vectors::CompensatedSquaredL2::new(dim),
+        dim,
+        arch,
+    );
+    DistanceComputer::new(reify, GlobalAllocator)
+}
+
+#[inline(never)]
+pub fn fourbit_v4_ip_data_data(
+    arch: V4,
+    shift: &[f32],
+    dim: usize,
+) -> Result {
+    let reify = Reify::<_, _, AsData<4>, AsData<4>>::new(
+        vectors::CompensatedIP::new(shift, dim),
+        dim,
+        arch,
+    );
+    DistanceComputer::new(reify, GlobalAllocator)
+}
+
+#[inline(never)]
+pub fn fourbit_v4_cosine_data_data(
+    arch: V4,
+    shift: &[f32],
+    dim: usize,
+) -> Result {
+    let reify = Reify::<_, _, AsData<4>, AsData<4>>::new(
+        vectors::CompensatedCosine::new(vectors::CompensatedIP::new(shift, dim)),
+        dim,
+        arch,
+    );
+    DistanceComputer::new(reify, GlobalAllocator)
+}
+
+#[inline(never)]
+pub fn fourbit_v4_l2_query_data(arch: V4, dim: usize) -> Result {
+    let reify = Reify::<_, _, AsQuery<4, Dense>, AsData<4>>::new(
+        vectors::CompensatedSquaredL2::new(dim),
+        dim,
+        arch,
+    );
+    DistanceComputer::new(reify, GlobalAllocator)
+}
+
+#[inline(never)]
+pub fn fourbit_v4_ip_query_data(
+    arch: V4,
+    shift: &[f32],
+    dim: usize,
+) -> Result {
+    let reify = Reify::<_, _, AsQuery<4, Dense>, AsData<4>>::new(
+        vectors::CompensatedIP::new(shift, dim),
+        dim,
+        arch,
+    );
+    DistanceComputer::new(reify, GlobalAllocator)
+}
+
+#[inline(never)]
+pub fn fourbit_v4_cosine_query_data(
+    arch: V4,
+    shift: &[f32],
+    dim: usize,
+) -> Result {
+    let reify = Reify::<_, _, AsQuery<4, Dense>, AsData<4>>::new(
+        vectors::CompensatedCosine::new(vectors::CompensatedIP::new(shift, dim)),
+        dim,
+        arch,
+    );
+    DistanceComputer::new(reify, GlobalAllocator)
+}
+
+//----//
+// V3 //
+//----//
+
 #[inline(never)]
 pub fn fourbit_v3_l2_data_data(arch: V3, dim: usize) -> Result {
     let reify = Reify::<_, _, AsData<4>, AsData<4>>::new(
diff --git a/diskann-quantization/src/spherical/iface.rs b/diskann-quantization/src/spherical/iface.rs
index d0f3d4fa4..e16aa827c 100644
--- a/diskann-quantization/src/spherical/iface.rs
+++ b/diskann-quantization/src/spherical/iface.rs
@@ -1378,12 +1378,12 @@ cfg_if::cfg_if! {
         dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
         dispatch_map!(2, AsData<2>, V4); // specialized
-        dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
+        dispatch_map!(4, AsData<4>, V4); // specialized
         dispatch_map!(8, AsData<8>, V4, downcast_to_v3);
 
         dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
         dispatch_map!(2, AsQuery<2>, V4); // specialized
-        dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
+        dispatch_map!(4, AsQuery<4>, V4); // specialized
         dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
     } else if #[cfg(target_arch = "aarch64")] {
         fn downcast(arch: Neon) -> Scalar {
diff --git a/diskann-wide/src/emulated.rs b/diskann-wide/src/emulated.rs
index 7477207ec..a91d5d767 100644
--- a/diskann-wide/src/emulated.rs
+++ b/diskann-wide/src/emulated.rs
@@ -610,6 +610,7 @@ macro_rules! impl_little_endian_transmute_cast {
 }
 
 impl_little_endian_transmute_cast!( => );
+impl_little_endian_transmute_cast!( => );
 impl_little_endian_transmute_cast!( => );
 impl_little_endian_transmute_cast!( => );
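As a cross-check for the new 4-bit squared-L2 kernel, here is a minimal scalar sketch of the same computation. It is a hypothetical helper, not part of the patch, and it assumes the low nibble of each byte holds the even-indexed element; the nibble order only affects element indexing, not the sum.

// Scalar reference for the packed 4-bit squared-L2 distance. `len` counts
// 4-bit elements; `x` and `y` hold two elements per byte.
fn squared_l2_u4_scalar(x: &[u8], y: &[u8], len: usize) -> u32 {
    assert_eq!(x.len(), y.len());
    assert!(len <= 2 * x.len());
    let mut s = 0u32;
    for i in 0..len {
        // Even indices take the low nibble, odd indices the high nibble.
        let shift = 4 * (i % 2) as u32;
        let nx = (x[i / 2] >> shift) & 0xf;
        let ny = (y[i / 2] >> shift) & 0xf;
        let d = nx as i32 - ny as i32;
        s += (d * d) as u32;
    }
    s
}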
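The inner-product kernel leans on the VNNI instruction named in its doc comment. Below is a per-lane scalar model of one `_mm512_dpbusd_epi32` step, for illustration only; the lane layout follows the Intel documentation, and the i32 intermediates stand in for the hardware's exact widening behavior, which is equivalent for nibble-sized inputs.

// One dpbusd step per 32-bit lane: acc[lane] += sum over k of
// (unsigned byte a[4*lane + k]) * (signed byte b[4*lane + k]).
// For 4-bit inputs (0..=15) the u8 -> i8 reinterpret on `b` is lossless.
fn dpbusd_epi32_model(acc: &mut [i32; 16], a: &[u8; 64], b: &[i8; 64]) {
    for lane in 0..16 {
        let mut sum = 0i32;
        for k in 0..4 {
            sum += i32::from(a[4 * lane + k]) * i32::from(b[4 * lane + k]);
        }
        acc[lane] += sum;
    }
}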
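The tail bookkeeping shared by both kernels is easiest to check with concrete lengths. The small sketch below is a hypothetical helper (assuming `len` counts 4-bit elements) that reproduces the split between the main loop, the predicated byte load, and the scalar fallback.

// Returns (u32s consumed by the main loop, whole bytes left for the
// predicated load, nibbles left for the scalar fallback; always 0 or 1).
fn tail_split(len: usize) -> (usize, usize, usize) {
    let blocks = len.div_ceil(8); // u32 blocks, rounded up
    let mut i = 0;
    while i + 16 < blocks {
        i += 16; // main loop: 16 u32s (64 bytes, 128 nibbles) per iteration
    }
    let remainder = len / 2 - 4 * i; // whole bytes left for the predicated load
    let scalar = len - 2 * (4 * i + remainder); // dangling nibble if `len` is odd
    (i, remainder, scalar)
}

fn main() {
    // 37 nibbles: 5 blocks, main loop skipped, 18-byte tail, 1 dangling nibble.
    assert_eq!(tail_split(37), (0, 18, 1));
    // 300 nibbles: 38 blocks, two main-loop iterations, 22-byte tail, nothing left.
    assert_eq!(tail_split(300), (32, 22, 0));
}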