Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions diskann-quantization/src/__codegen/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ pub fn bits_v4_ip_bu2_bu2(arch: V4, x: USlice<'_, 2>, y: USlice<'_, 2>) -> MR<u3
arch.run2_inline(distances::InnerProduct, x, y)
}

/// Squared L2 distance between two 4-bit quantized slices on the x86-64 V4
/// (AVX-512) target, dispatched through `run2_inline`.
///
/// NOTE(review): `#[inline(never)]` presumably keeps this as a distinct symbol
/// in the `__codegen` module so the emitted machine code can be inspected or
/// benchmarked in isolation — confirm against the codegen harness conventions.
#[inline(never)]
pub fn bits_v4_l2_bu4_bu4(arch: V4, x: USlice<'_, 4>, y: USlice<'_, 4>) -> MR<u32> {
    arch.run2_inline(distances::SquaredL2, x, y)
}

/// Inner product between two 4-bit quantized slices on the x86-64 V4
/// (AVX-512) target, dispatched through `run2_inline`.
///
/// NOTE(review): `#[inline(never)]` presumably keeps this as a distinct symbol
/// in the `__codegen` module so the emitted machine code can be inspected or
/// benchmarked in isolation — confirm against the codegen harness conventions.
#[inline(never)]
pub fn bits_v4_ip_bu4_bu4(arch: V4, x: USlice<'_, 4>, y: USlice<'_, 4>) -> MR<u32> {
    arch.run2_inline(distances::InnerProduct, x, y)
}

//------------//
// Transposed //
//------------//
Expand Down
289 changes: 286 additions & 3 deletions diskann-quantization/src/bits/distances.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,158 @@ impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_,
}
}

/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized for x86 with the AVX-512 vector extension.
///
/// We load data as `u32x16`, shift and mask to extract 4-bit nibbles at 16-bit granularity
/// (`0x000f000f` mask), reinterpret as `i16x32`, compute differences, and use
/// `_mm512_madd_epi16` via `dot_simd` to accumulate squared differences into 4 independent
/// `i32x16` accumulators (one per nibble position within each 16-bit half-word).
/// Reinterpreting after a shift is well-defined because the same shift amount is applied
/// uniformly to all lanes.
///
/// NOTE(review): the accumulators are `i32`; each `madd` step adds at most
/// `2 * 15^2 = 450` per lane, so per-lane sums stay well inside `i32` for any
/// realistic quantized-vector length — confirm the upstream bound on `len`.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V4>::i32x16);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V4>::u32x16);
        diskann_wide::alias!(u8s = <diskann_wide::arch::x86_64::V4>::u8x64);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V4>::i16x32);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts `u32` blocks in the SIMD phase, then is rescaled to bytes
        // and finally to nibble indexes for the scalar epilogue.
        let mut i = 0;
        let mut s: u32 = 0;

        // Number of u32 blocks (rounded up). Each u32 holds 8 nibbles = 8 four-bit values.
        // We use `div_ceil` so a partial trailing byte (1..=7 stray nibbles) is still
        // covered by the predicated `u8` load below; the scalar fallback only ever needs to
        // handle at most a single dangling nibble.
        let blocks = len.div_ceil(8);
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low nibble of each 16-bit half-word; shifting by
            // 4/8/12 before masking selects the remaining three positions.
            let mask = u32s::splat(arch, 0x000f000f);
            // Strictly `<` (not `<=`): with `div_ceil`, the final u32 block may be
            // only partially backed by memory, so the unpredicated 64-byte load is
            // taken only while at least one further block exists past the loaded range.
            while i + 16 < blocks {
                // SAFETY: We have checked that `i + 16 < blocks` which means the
                // 16-element range `px_u32.add(i)..px_u32.add(i + 16)` (in `u32` units)
                // is dereferenceable.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Nibble position 0 (bits 0..4 of each 16-bit half-word).
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                // Nibble position 1 (bits 4..8).
                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                // Nibble position 2 (bits 8..12).
                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                // Nibble position 3 (bits 12..16).
                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 16;
            }

            // Compute the number of bytes still to process:
            // * `len / 2` — total full bytes in the input (2 nibbles/byte).
            // * `4 * i` — bytes already consumed (4 bytes per u32 × `i` u32s).
            //
            // On loop exit `i < blocks` still holds (each accepted iteration had
            // `i + 16 < blocks`), so `4 * i <= len / 2` and the subtraction cannot
            // underflow; likewise `blocks - i <= 16` bounds `remainder` by 64, so it
            // always fits in one predicated `u8x64` load.
            let remainder = len / 2 - 4 * i;

            if remainder > 0 {
                // SAFETY: `remainder` bytes are dereferenceable starting at
                // `px_u32.add(i).cast::<u8>()` (i.e. byte offset `4 * i` from `px_u32`).
                //
                // The predicated load is guaranteed not to access memory beyond the
                // first `remainder` bytes and has no alignment requirements; unloaded
                // lanes are zero, contributing a zero difference.
                let vx =
                    unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
                let vx: u32s = vx.reinterpret_simd();

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy =
                    unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
                let vy: u32s = vy.reinterpret_simd();

                // Same 4-position nibble extraction as the main loop.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);
            }

            // Horizontal reduction of the four accumulators into a scalar, then
            // rescale `i` from u32 blocks to bytes consumed (including the tail).
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
            i = (4 * i) + remainder;
        }

        // Convert bytes to nibble indexes.
        i *= 2;

        // Deal with the remainder the slow way (at most 1 element): only an odd
        // `len` leaves a single dangling nibble past the last full byte.
        debug_assert!(len - i <= 1);
        if i != len {
            // SAFETY: `i` is guaranteed to be less than `x.len()`.
            let ix = unsafe { x.get_unchecked(i) } as i32;
            // SAFETY: `i` is guaranteed to be less than `y.len()`.
            let iy = unsafe { y.get_unchecked(i) } as i32;
            let d = ix - iy;
            s += (d * d) as u32;
        }

        Ok(MV::new(s))
    }
}

/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
Expand Down Expand Up @@ -797,7 +949,7 @@ impl_fallback_l2!(7, 6, 5, 4, 3, 2);
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);

#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 3, 2);

dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -1115,6 +1267,138 @@ impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_,
}
}

/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around the AVX-512 VNNI `_mm512_dpbusd_epi32`
/// instruction, which computes the pairwise dot product between vectors of 8-bit integers
/// and accumulates groups of 4 into an `i32` accumulator.
///
/// For 4-bit values, each byte holds 2 nibbles. We load data as `u32x16`, mask with
/// `0x0f0f0f0f` to extract the low nibbles as bytes, and shift right by 4 then mask to
/// extract the high nibbles. This yields `u8x64` / `i8x64` operands for VNNI, requiring
/// only 2 shift positions.
///
/// Since AVX-512 does not have an 8-bit shift instruction, we load data as `u32x16`
/// (which has a native shift) and bit-cast to `u8x64` as needed.
///
/// NOTE(review): the accumulators are `i32`; each VNNI step adds at most
/// `4 * 15^2 = 900` per lane, so per-lane sums stay well inside `i32` for any
/// realistic quantized-vector length — confirm the upstream bound on `len`.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V4>::i32x16);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V4>::u32x16);
        diskann_wide::alias!(u8s = <diskann_wide::arch::x86_64::V4>::u8x64);
        diskann_wide::alias!(i8s = <diskann_wide::arch::x86_64::V4>::i8x64);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts `u32` blocks in the SIMD phase, then is rescaled to bytes
        // and finally to nibble indexes for the scalar epilogue.
        let mut i = 0;
        let mut s: u32 = 0;

        // Number of u32 blocks (rounded up). Each u32 holds 8 nibbles = 8 four-bit values.
        // We use `div_ceil` so a partial trailing byte (1..=7 stray nibbles) is still
        // covered by the predicated `u8` load below; the scalar fallback only ever needs to
        // handle at most a single dangling nibble.
        let blocks = len.div_ceil(8);
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            // Selects the low nibble of every byte; shifting right by 4 first
            // selects the high nibble of every byte.
            let mask = u32s::splat(arch, 0x0f0f0f0f);
            // Strictly `<` (not `<=`): with `div_ceil`, the final u32 block may be
            // only partially backed by memory, so the unpredicated 64-byte load is
            // taken only while at least one further block exists past the loaded range.
            while i + 16 < blocks {
                // SAFETY: We have checked that `i + 16 < blocks` which means the
                // 16-element range `px_u32.add(i)..px_u32.add(i + 16)` (in `u32` units)
                // is dereferenceable.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // VNNI `vpdpbusd` requires (unsigned, signed) operands in this order;
                // `dot_simd(u8s, i8s)` lowers to that intrinsic on V4. 4-bit values fit in
                // the positive half of `i8`, so the unsigned->signed reinterpret is safe.
                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                // High nibbles of every byte.
                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                i += 16;
            }

            // Compute the number of bytes still to process:
            // * `len / 2` — total full bytes in the input (2 nibbles/byte).
            // * `4 * i` — bytes already consumed (4 bytes per u32 × `i` u32s).
            //
            // On loop exit `i < blocks` still holds (each accepted iteration had
            // `i + 16 < blocks`), so `4 * i <= len / 2` and the subtraction cannot
            // underflow; likewise `blocks - i <= 16` bounds `remainder` by 64, so it
            // always fits in one predicated `u8x64` load.
            let remainder = len / 2 - 4 * i;

            if remainder > 0 {
                // SAFETY: `remainder` bytes are dereferenceable starting at
                // `px_u32.add(i).cast::<u8>()` (i.e. byte offset `4 * i` from `px_u32`).
                //
                // The predicated load is guaranteed not to access memory beyond the
                // first `remainder` bytes and has no alignment requirements; unloaded
                // lanes are zero, contributing zero to the dot product.
                let vx =
                    unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
                let vx: u32s = vx.reinterpret_simd();

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy =
                    unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
                let vy: u32s = vy.reinterpret_simd();

                // Same low/high nibble extraction as the main loop.
                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);
            }

            // Horizontal reduction of the two accumulators into a scalar, then
            // rescale `i` from u32 blocks to bytes consumed (including the tail).
            s = (s0 + s1).sum_tree() as u32;
            i = (4 * i) + remainder;
        }

        // Convert bytes to nibble indexes.
        i *= 2;

        // Deal with the remainder the slow way (at most 1 element): only an odd
        // `len` leaves a single dangling nibble past the last full byte.
        debug_assert!(len - i <= 1);
        if i != len {
            // SAFETY: `i` is guaranteed to be less than `x.len()`.
            let ix = unsafe { x.get_unchecked(i) } as u32;
            // SAFETY: `i` is guaranteed to be less than `y.len()`.
            let iy = unsafe { y.get_unchecked(i) } as u32;
            s += ix * iy;
        }

        Ok(MV::new(s))
    }
}

/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
Expand Down Expand Up @@ -1346,7 +1630,6 @@ retarget!(
7,
6,
5,
4,
3,
(8, 4),
(8, 2),
Expand Down Expand Up @@ -2659,7 +2942,7 @@ mod tests {
(Key::new(4, Scalar), Bounds::new(64, 64)),
// Need a higher miri-amount due to the larget block size
(Key::new(4, X86_64_V3), Bounds::new(256, 150)),
(Key::new(4, X86_64_V4), Bounds::new(256, 150)),
(Key::new(4, X86_64_V4), Bounds::new(512, 300)),
(Key::new(4, Neon), Bounds::new(64, 64)),
(Key::new(5, Scalar), Bounds::new(64, 64)),
(Key::new(5, X86_64_V3), Bounds::new(256, 96)),
Expand Down
Loading
Loading