Add v4 distance kernels (4-bit SquaredL2 / InnerProduct) #1045

Open
m3hm3t wants to merge 6 commits into microsoft:main from m3hm3t:hakashya/avx512/add4bit
@m3hm3t m3hm3t commented May 8, 2026

  • Does this PR have a descriptive title that could go in our release notes?
  • Does this PR add any new dependencies? — No
  • Does this PR modify any existing APIs? — No (internal dispatch table only; no public API surface change)
  • Is the change to the API backwards compatible? — N/A (no API change)
  • Should this result in any changes to our documentation, either updating existing docs or adding new ones? — No

Reference Issues/PRs

N/A

What does this implement/fix? Briefly explain your changes.

Adds AVX-512 (V4) specialized distance kernels for 4-bit packed vectors in diskann-quantization. Previously, V4 dispatched 4-bit USlice × USlice distance computations through the V3 (AVX2) kernel via downcast_to_v3. This PR adds native 512-bit specializations and rewires the dispatcher.

What's new
  • SquaredL2 (V4, 4-bit) — loads 16 u32 lanes per iteration, masks/shifts to 4 nibble positions, reinterprets as i16x32, and accumulates squared differences via _mm512_madd_epi16 (dot_simd) into 4 independent i32x16 accumulators.
  • InnerProduct (V4, 4-bit) — built around AVX-512 VNNI (_mm512_dpbusd_epi32) over u8x64 / i8x64 operands. Only 2 shift positions thanks to byte-granular accumulation, halving per-iteration shift/mask work compared to the madd_epi16 approach.
  • Spherical dispatcher — adds fourbit_v4_{l2,ip,cosine}_{data,query}_data factories; flips dispatch_map! for AsData<4> / AsQuery<4> on V4 from downcast_to_v3 to the specialized path. Cosine variants come for free via CompensatedCosine(CompensatedIP(…)).
  • Tail handling — both kernels use blocks = len.div_ceil(8) and a predicated u8x64 load for the trailing bytes, so the scalar fallback only ever handles at most one dangling nibble.
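As a reading aid for the list above, here is a minimal scalar sketch of what the two kernels compute over 4-bit packed data. It is not the PR's code: the function names are hypothetical, and the packing order (even elements in the low nibble, odd elements in the high nibble of each byte) is an assumption. The `_bytewise` variant models the byte-granular, two-shift-position idea described for the InnerProduct kernel, finishing with at most one dangling nibble.

```rust
/// Read element `i` from a 4-bit packed buffer.
/// Assumption (not from the PR): element 2k is the low nibble of byte k,
/// element 2k + 1 is the high nibble.
fn nibble(packed: &[u8], i: usize) -> u8 {
    let byte = packed[i / 2];
    if i % 2 == 0 { byte & 0x0F } else { byte >> 4 }
}

/// Sum of squared differences over the first `len` 4-bit elements.
pub fn squared_l2_4bit(x: &[u8], y: &[u8], len: usize) -> u32 {
    (0..len)
        .map(|i| {
            let d = i32::from(nibble(x, i)) - i32::from(nibble(y, i));
            (d * d) as u32
        })
        .sum()
}

/// Inner product over the first `len` 4-bit elements, one element at a time.
pub fn inner_product_4bit(x: &[u8], y: &[u8], len: usize) -> u32 {
    (0..len)
        .map(|i| u32::from(nibble(x, i)) * u32::from(nibble(y, i)))
        .sum()
}

/// Same result, but walking whole bytes with only two shift positions
/// (0 and 4), then handling at most one dangling nibble.
pub fn inner_product_4bit_bytewise(x: &[u8], y: &[u8], len: usize) -> u32 {
    let full = len / 2; // number of bytes holding two valid elements
    let mut acc = 0u32;
    for (&a, &b) in x[..full].iter().zip(&y[..full]) {
        acc += u32::from(a & 0x0F) * u32::from(b & 0x0F); // shift position 0
        acc += u32::from(a >> 4) * u32::from(b >> 4); // shift position 4
    }
    if len % 2 == 1 {
        // Dangling low nibble of the final byte.
        acc += u32::from(x[full] & 0x0F) * u32::from(y[full] & 0x0F);
    }
    acc
}
```

Because both operands use the same packing, the sums do not depend on which nibble of a byte is treated as the "first" element.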
Performance notes
  • 4-bit on V4 hardware previously paid the cost of running two AVX2 kernels per 512-bit worth of data; native V4 doubles per-instruction throughput.
  • Both V4 4-bit kernels finish the trailing bytes with a single predicated u8 load, so the scalar fallback only ever runs for at most one dangling nibble.
Testing

Numerical correctness is exact (integer arithmetic). The existing test_bitslice_distances_4bit exercises Scalar / V3 / V4 paths under V4::new_checked_miri() and asserts bit-exact equality against the scalar reference across dim ∈ 0..MAX_DIM with randomized inputs. CI runs the V4 path under Intel SDE Sapphire Rapids.

Any other comments?

Commits
  1. a6734519 — Add v4 distance kernels (initial implementation).
  2. cc4074a6 — quantization: clean up V4 4-bit distance kernels (alias macro, doc cross-refs, gated tail, L2 <= blocks, IP loop bound preserved as <).
  3. 1019ec61 — wide: add Emulated<u32, 16> => Emulated<i16, 32> reinterpret for V4 (required by the coverage build's --cfg=miri path; mirrors the existing V3 <u32, 8> => <i16, 16> impl).
  4. e26e4a91 — quantization: address review comments on V4 4-bit kernel docs (correct avx512bw shift claim; rewrite SAFETY pointer-arithmetic ranges in consistent u32-element units).
  5. 45f5b7ff — quantization: address review comments on V4 4-bit predicated tail SAFETY (grammar fix in L2 tail; rewrite IP tail SAFETY in byte units to match the px_u32.add(i).cast::<u8>() pointer).
  6. 08beb308 — quantization: address V4 4-bit review feedback (drop V3-contrast docs; refactor L2 tail to use a predicated u8 load matching IP; bump BITSLICE_TEST_BOUNDS for (4, X86_64_V4) to (512, 300) to exercise the doubled main-loop width under miri).
Co-authors
  • B Harsha Kashyap @hakashya — initial implementation
  • Krishnakumar Ravi (KK) @kk-src — original idea / approach

@m3hm3t m3hm3t force-pushed the hakashya/avx512/add4bit branch from 67c8cd3 to fec9bfc May 8, 2026 08:45
- Use diskann_wide::alias! for type aliases.
- Add V3<->V4 cross-references and minor doc fixes.
- Add a debug_assert on the trailing scalar fallback bound.
- Gate the predicated tail behind 'if remainder > 0'.
- Note (vs the initial cleanup): keep the V4 4-bit InnerProduct main loop
  bound as 'i + 16 < blocks'. Because IP uses 'blocks = len.div_ceil(8)'
  (so a partial trailing byte can be mopped up by the predicated u8x64
  tail), relaxing this to '<=' would cause the SIMD loop to read 4 bytes
  past the buffer end whenever 'blocks' is a multiple of 16 but 'len' is
  not a multiple of 8 (e.g. len = 121). The L2 kernel is unaffected
  because it uses 'blocks = len / 8' (rounding down).

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
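The loop-bound note in this commit message can be checked with a small model. The sketch below is illustrative only: `main_loop_overreads` is a hypothetical helper, and it assumes that `len` 4-bit elements are backed by exactly `len.div_ceil(2)` bytes. It walks the main loop with either bound and reports whether any full 16 x u32 (64-byte) load would extend past the buffer.

```rust
/// Returns true if a main loop using the given bound would issue a full
/// 16 x u32 (64-byte) load that extends past the packed buffer.
/// Assumption: the buffer holds exactly `len.div_ceil(2)` bytes.
pub fn main_loop_overreads(len: usize, inclusive_bound: bool) -> bool {
    let blocks = len.div_ceil(8); // u32 blocks, rounding up
    let bytes = len.div_ceil(2); // bytes backing `len` nibbles
    let mut i = 0;
    loop {
        let take = if inclusive_bound {
            i + 16 <= blocks
        } else {
            i + 16 < blocks
        };
        if !take {
            return false;
        }
        // The load covers u32 elements i..i + 16, i.e. bytes 4*i..4*(i + 16).
        if (i + 16) * 4 > bytes {
            return true;
        }
        i += 16;
    }
}
```

With `len = 121` the model has `blocks = 16`, so the `<=` bound takes one full-width iteration whose last lane straddles the end of the buffer, while the `<` bound leaves everything to the predicated tail.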
@m3hm3t m3hm3t force-pushed the hakashya/avx512/add4bit branch from fec9bfc to cc4074a May 8, 2026 08:48
The new V4 4-bit SquaredL2 kernel in diskann-quantization reinterprets
'u32x16' loads as 'i16x32' to feed '_mm512_madd_epi16'. The intrinsic-
backed conversion 'u32x16: SIMDReinterpret<i16x32>' was already provided
in 'arch/x86_64/v4/conversion.rs', but the parallel impl on the emulated
representation 'Emulated<u32, 16, A>: SIMDReinterpret<Emulated<i16, 32, A>>'
was missing.

This impl is exercised by the coverage build (--cfg=miri +
instrument-coverage) which runs the SquaredL2 V4 kernel against the
Emulated<…, V4> types. CI failed with E0277 'unsatisfied trait bound'
at 8 call sites in diskann-quantization/src/bits/distances.rs.

The new entry mirrors the existing '<u32, 8> => <i16, 16>' impl used by V3.

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
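For readers unfamiliar with the reinterpret this commit adds, the following is a portable sketch of the same idea on plain arrays, assuming little-endian lane order. It does not use the `diskann_wide` types; `reinterpret_u32x16_as_i16x32` is a hypothetical stand-in for what the emulated `SIMDReinterpret` impl is expected to do.

```rust
/// Portable model of viewing 16 `u32` lanes as 32 `i16` lanes.
/// Assumption: little-endian layout, so the low 16 bits of each `u32`
/// become the even-indexed `i16` lane and the high 16 bits the odd one.
pub fn reinterpret_u32x16_as_i16x32(src: [u32; 16]) -> [i16; 32] {
    let mut out = [0i16; 32];
    for (k, word) in src.iter().enumerate() {
        let b = word.to_le_bytes();
        out[2 * k] = i16::from_le_bytes([b[0], b[1]]); // low half
        out[2 * k + 1] = i16::from_le_bytes([b[2], b[3]]); // high half
    }
    out
}
```

The bit pattern is unchanged; only the lane boundaries move, which is why a lane value such as `0x8000` reads back as a negative `i16`.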

codecov-commenter commented May 8, 2026

Codecov Report

❌ Patch coverage is 98.94180% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 89.49%. Comparing base (3a20042) to head (08beb30).
⚠️ Report is 24 commits behind head on main.

Files with missing lines | Patch % | Lines
diskann-quantization/src/bits/distances.rs | 98.94% | 2 Missing ⚠️

❌ Your patch status has failed because the patch coverage (0.52%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1045      +/-   ##
==========================================
+ Coverage   89.43%   89.49%   +0.06%     
==========================================
  Files         449      461      +12     
  Lines       83779    85682    +1903     
==========================================
+ Hits        74926    76684    +1758     
- Misses       8853     8998     +145     
Flag | Coverage Δ
miri | 89.49% <98.94%> (+0.06%) ⬆️
unittests | 89.12% <0.52%> (-0.15%) ⬇️

Files with missing lines | Coverage Δ
diskann-quantization/src/spherical/iface.rs | 92.90% <ø> (+0.32%) ⬆️
diskann-wide/src/emulated.rs | 98.29% <ø> (ø)
diskann-quantization/src/bits/distances.rs | 98.93% <98.94%> (+<0.01%) ⬆️

... and 95 files with indirect coverage changes

@m3hm3t m3hm3t marked this pull request as ready for review May 8, 2026 11:17
@m3hm3t m3hm3t requested review from a team and Copilot May 8, 2026 11:17
Contributor

Copilot AI left a comment

Pull request overview

Adds native x86_64 V4 (AVX-512) distance kernels for 4-bit packed vectors in diskann-quantization, and switches spherical dispatch for 4-bit AsData / AsQuery from V4→V3 downcasting to the new V4-specialized paths.

Changes:

  • Implement V4 4-bit SquaredL2 and InnerProduct kernels in bits/distances.rs, and update V4 retarget/dispatch to use them.
  • Update spherical dispatcher so 4-bit AsData<4> / AsQuery<4> on V4 no longer downcasts to V3.
  • Extend codegen/disassembly helpers and emulation reinterpret support for new V4 lane shapes.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

File | Description
diskann-wide/src/emulated.rs | Adds emulated little-endian reinterpret support for u32x16 -> i16x32 used by V4 kernels (Miri/emulation).
diskann-quantization/src/spherical/iface.rs | Switches V4 spherical dispatch for 4-bit data/query layouts to the specialized V4 path.
diskann-quantization/src/spherical/__codegen/x86_64.rs | Adds V4 4-bit spherical distance-computer factory functions for codegen/disassembly coverage.
diskann-quantization/src/bits/distances.rs | Introduces V4 4-bit SquaredL2 and InnerProduct implementations and updates V4 retarget lists accordingly.
diskann-quantization/src/__codegen/x86_64.rs | Adds V4 4-bit bitslice distance entrypoints for disassembly inspection.

Comment thread diskann-quantization/src/bits/distances.rs Outdated
Comment thread diskann-quantization/src/bits/distances.rs Outdated
Comment thread diskann-quantization/src/bits/distances.rs Outdated
Three Copilot review nits from PR microsoft#1045:

- L2 doc (~line 578): the claim 'AVX-512 does not have 16-bit integer
  bit-shift instructions' is incorrect; V4 enables avx512bw which
  provides 'vpsrlw'. Reword to give the actual reason: the project's
  'diskann_wide' type abstractions provide a native shift on 'u32x16'
  but not on 'i16x32', so we shift in u32 view and reinterpret.

- L2 SAFETY comment (~line 612) and IP SAFETY comment (~line 1330):
  these mixed pointer-element arithmetic with byte sizes
  ('px_u32 + i + 16 * size_of::<u32>()' on a *const u32). Rewrite
  consistently in u32-element units:
  'px_u32.add(i)..px_u32.add(i + 16) (in u32 units)'.

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Comment thread diskann-quantization/src/bits/distances.rs Outdated
Comment thread diskann-quantization/src/bits/distances.rs Outdated
Two Copilot review nits from PR microsoft#1045:

- L2 predicated tail (~line 651): grammar fix ('an unaligned starting' ->
  rewritten to refer concretely to 'remainder values of type u32 ...
  starting at px_u32.add(i)').

- IP predicated tail (~line 1369): the SAFETY comment said the load
  starts 'at an offset of i', but the pointer is
  'px_u32.add(i).cast::<u8>()' so the byte offset from px_u32 is 4*i.
  Rewrite the comment in consistent units (refer directly to the byte
  pointer and call out the 4*i byte offset).

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
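The unit mismatch described in this commit message (element offsets versus byte offsets) can be illustrated with standard pointer methods. This is a sketch on an ordinary slice, not the kernel's SAFETY argument; `tail_byte_offset` is a hypothetical helper name.

```rust
/// Byte distance from the start of `data` to the pointer produced by
/// `data.as_ptr().add(i).cast::<u8>()`. `add` counts in `u32` elements,
/// so the byte offset comes out as `4 * i`.
pub fn tail_byte_offset(data: &[u32], i: usize) -> usize {
    assert!(i <= data.len(), "offset must stay within the slice");
    let base: *const u8 = data.as_ptr().cast::<u8>();
    // SAFETY: `i <= data.len()`, so the result is in bounds or one past the end.
    let tail: *const u8 = unsafe { data.as_ptr().add(i) }.cast::<u8>();
    // SAFETY: both pointers are derived from the same slice.
    let offset = unsafe { tail.offset_from(base) };
    offset as usize
}
```

Stating a range as "starting at byte offset `4 * i`" and "starting at `px_u32.add(i)`" describes the same address; the two should not be mixed within one expression.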
Contributor

@hildebrandmw hildebrandmw left a comment

Thanks! This looks great! Can you please also update the BITSLICE_TEST_BOUNDS entry (Key::new(4, X86_64_V4), Bounds::new(256, 150))? I think 512, 300 may be needed (sorry, Miri runs) since the width of the main loop is effectively doubled.

///
/// We perform shifts on the `u32x16` view rather than the `i16x32` view because the
/// `diskann_wide` type abstractions provide a native shift on `u32x16` but not on
/// `i16x32`. The reinterpret-after-shift is well-defined because the same shift amount
Contributor

Isn't there a mapping for i16x32 shifts?

Author

@m3hm3t m3hm3t May 8, 2026

Thanks for pointing that out. Fixed in 08beb30 by removing the wrong paragraph. Left the kernel doing shifts on the u32x16 view, since masking to 4 bits makes it equivalent.
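The equivalence claimed in this reply can be checked on a single 32-bit word. The sketch below is an illustration under assumptions: the mask `0x000F_000F` and both function names are hypothetical, and shifts are logical. Bits that cross from the high 16-bit half into the low half during a `u32` shift land above bit 3 of the low half, so the 4-bit mask discards them.

```rust
/// Nibble `k` (0..4) of each 16-bit half, extracted by shifting the whole
/// `u32` and masking. The mask value is an assumption for illustration.
pub fn nibbles_via_u32_shift(word: u32, k: u32) -> [u16; 2] {
    let v = (word >> (4 * k)) & 0x000F_000F;
    [v as u16, (v >> 16) as u16]
}

/// The same nibbles, extracted by shifting each 16-bit half independently.
pub fn nibbles_via_u16_shift(word: u32, k: u32) -> [u16; 2] {
    let lo = word as u16; // low 16 bits
    let hi = (word >> 16) as u16; // high 16 bits
    [(lo >> (4 * k)) & 0x000F, (hi >> (4 * k)) & 0x000F]
}
```

For example, with `word = 0xFEDC_BA98` and `k = 3`, both functions return `[0xB, 0xF]`.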


let remainder = blocks - i;

if remainder > 0 {
Contributor

Can we do the same thing here for the remainder that is done for 4-bit inner product? A masked 8-bit load instead of a masked 32-bit load with at most one left over for the final epilogue?

Author

@m3hm3t m3hm3t May 8, 2026

Done in 08beb30. L2 V4 tail now mirrors the IP V4 layout: blocks = len.div_ceil(8), main loop while i + 16 < blocks, predicated u8x64 load for the trailing bytes.
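The layout described in this reply can be modeled as index arithmetic. The sketch below is not the kernel: `plan_tail` is a hypothetical helper, and representing the predicate as a `u64` with one bit per byte lane is an assumption about how a 64-lane byte mask would be built. It returns how many bytes the main loop consumes, the mask for the trailing full bytes, and whether one dangling nibble is left for the scalar fallback.

```rust
/// Plan the tail for `len` 4-bit elements under the described layout:
/// `blocks = len.div_ceil(8)`, main loop `while i + 16 < blocks`.
/// Returns (bytes consumed by the main loop, byte mask for the predicated
/// 64-lane load, whether a single dangling nibble remains).
pub fn plan_tail(len: usize) -> (usize, u64, bool) {
    let blocks = len.div_ceil(8); // u32 blocks, rounding up
    let mut i = 0;
    while i + 16 < blocks {
        i += 16; // one full 16 x u32 (64-byte) iteration
    }
    let full_bytes = len / 2; // bytes holding two valid nibbles
    let tail_bytes = full_bytes - 4 * i; // at most 64 by construction
    debug_assert!(tail_bytes <= 64);
    let mask = if tail_bytes == 64 {
        u64::MAX
    } else {
        (1u64 << tail_bytes) - 1
    };
    (4 * i, mask, len % 2 == 1)
}
```

For `len = 121` this yields no main-loop iterations, a mask with the low 60 bits set, and one dangling nibble.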

///
/// # Implementation Notes
///
/// Unlike the V3 4-bit `InnerProduct` impl (which uses `_mm256_madd_epi16` over `i16` lanes
Contributor

In general, please keep doc comments to what the kernel does and not as a contrast to what another kernel does. The former is more durable long-term.

Author

@m3hm3t m3hm3t May 8, 2026

Reworded in 08beb30 to describe what this kernel does.

// covered by the predicated `u8` load below; the scalar fallback only ever needs to
// handle at most a single dangling nibble.
//
// (The L2 kernel uses `len / 8` because its scalar fallback can handle up to 7
Contributor

Again, please keep comments focused on what this kernel does, not what other kernels do.

Author

@m3hm3t m3hm3t May 8, 2026

Removed the L2-contrast parenthetical in 08beb30.

- L2 V4: drop V3-contrast / misleading shift docs; reword to focus on
  what this kernel does (per @hildebrandmw's review).
- L2 V4: switch the trailing tail to a predicated u8 load (matching the
  IP V4 kernel), so the scalar fallback only handles at most one
  dangling nibble instead of up to 7. blocks now uses div_ceil(8) and
  the main loop uses '< blocks', mirroring the IP layout.
- IP V4: drop V3-contrast docs / parenthetical, focus on what this
  kernel does.
- Bump BITSLICE_TEST_BOUNDS for (4, X86_64_V4) from (256, 150) to
  (512, 300) so miri exercises the doubled main-loop width.

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
Author

m3hm3t commented May 8, 2026

Thanks for the review @hildebrandmw! Addressed everything in 08beb30:

  • Bumped BITSLICE_TEST_BOUNDS for (4, X86_64_V4) from (256, 150) to (512, 300) so miri exercises the doubled main-loop width.

@m3hm3t m3hm3t requested a review from hildebrandmw May 8, 2026 18:24
@harsha-simhadri
Contributor

Thanks for the great contribution @m3hm3t and @Hakshaya. Have you run this through the benchmark simd tool on an AVX512 machine? Could you summarize the relevant numbers in the PR description?

Contributor

@hildebrandmw hildebrandmw left a comment

Thanks @m3hm3t, this looks great! I would echo what Harsha asked and wonder if you have some before/after numbers for performance on a suitable AVX-512 machine 😄.
