Add v4 distance kernels (4-bit SquaredL2 / InnerProduct) #1045
m3hm3t wants to merge 6 commits into microsoft:main
67c8cd3 to fec9bfc
- Use 'diskann_wide::alias!' for type aliases.
- Add V3<->V4 cross-references and minor doc fixes.
- Add a debug_assert on the trailing scalar fallback bound.
- Gate the predicated tail behind 'if remainder > 0'.
- Note (vs the initial cleanup): keep the V4 4-bit InnerProduct main loop bound as 'i + 16 < blocks'. Because IP uses 'blocks = len.div_ceil(8)' (so a partial trailing byte can be mopped up by the predicated u8x64 tail), relaxing this to '<=' would cause the SIMD loop to read 4 bytes past the buffer end whenever 'blocks' is a multiple of 16 but 'len' is not a multiple of 8 (e.g. len = 121). The L2 kernel is unaffected because it uses 'blocks = len / 8' (rounding down).

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
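The loop-bound note above can be checked with a small scalar model. This is only a sketch: `main_loop_iterations` and `last_block_is_partial` are hypothetical helpers written for illustration, not functions from the PR, and they model the loop counter only, not the SIMD loads.

```rust
/// Counts how many 16-block SIMD iterations a main loop would take.
/// `relaxed = false` models `while i + 16 < blocks`,
/// `relaxed = true` models `while i + 16 <= blocks`.
/// (Hypothetical helper, for illustration only.)
pub fn main_loop_iterations(blocks: usize, relaxed: bool) -> usize {
    let in_bounds = |i: usize| {
        if relaxed {
            i + 16 <= blocks
        } else {
            i + 16 < blocks
        }
    };
    let mut i = 0;
    let mut iterations = 0;
    while in_bounds(i) {
        i += 16;
        iterations += 1;
    }
    iterations
}

/// True when the last 8-nibble block is only partially backed by data,
/// which is the case the strict bound protects against.
pub fn last_block_is_partial(len: usize) -> bool {
    len % 8 != 0
}
```

For `len = 121`, `blocks = 121.div_ceil(8) = 16`. The strict bound takes zero full-width iterations and leaves all 16 blocks to the predicated tail, while the relaxed bound takes one full-width iteration even though the last block is partial.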
fec9bfc to cc4074a
The new V4 4-bit SquaredL2 kernel in diskann-quantization reinterprets 'u32x16' loads as 'i16x32' to feed '_mm512_madd_epi16'. The intrinsic-backed conversion 'u32x16: SIMDReinterpret<i16x32>' was already provided in 'arch/x86_64/v4/conversion.rs', but the parallel impl on the emulated representation 'Emulated<u32, 16, A>: SIMDReinterpret<Emulated<i16, 32, A>>' was missing.

This impl is exercised by the coverage build (--cfg=miri + instrument-coverage), which runs the SquaredL2 V4 kernel against the Emulated<…, V4> types. CI failed with E0277 'unsatisfied trait bound' at 8 call sites in diskann-quantization/src/bits/distances.rs.

The new entry mirrors the existing '<u32, 8> => <i16, 16>' impl used by V3.

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
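As a rough picture of what a little-endian lane reinterpret from 16 `u32` lanes to 32 `i16` lanes does, here is a plain-array sketch. It does not use the `Emulated` or `SIMDReinterpret` types from `diskann-wide`; the function name and the array-based signature are assumptions made for illustration.

```rust
/// Reinterprets 16 `u32` lanes as 32 `i16` lanes in little-endian order:
/// lane `2 * k` holds the low 16 bits of `src[k]` and lane `2 * k + 1`
/// holds the high 16 bits. (Illustrative stand-in, not the crate's impl.)
pub fn reinterpret_u32x16_as_i16x32(src: [u32; 16]) -> [i16; 32] {
    let mut dst = [0i16; 32];
    for (k, word) in src.iter().enumerate() {
        let bytes = word.to_le_bytes();
        dst[2 * k] = i16::from_le_bytes([bytes[0], bytes[1]]);
        dst[2 * k + 1] = i16::from_le_bytes([bytes[2], bytes[3]]);
    }
    dst
}
```

After masking each `u32` so that every 16-bit half holds one nibble, this view exposes the nibbles as `i16` values that a pairwise multiply-add can consume.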
Codecov Report
❌ Your patch status has failed because the patch coverage (0.52%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.
@@ Coverage Diff @@
## main #1045 +/- ##
==========================================
+ Coverage 89.43% 89.49% +0.06%
==========================================
Files 449 461 +12
Lines 83779 85682 +1903
==========================================
+ Hits 74926 76684 +1758
- Misses 8853 8998 +145
Pull request overview
Adds native x86_64 V4 (AVX-512) distance kernels for 4-bit packed vectors in diskann-quantization, and switches spherical dispatch for 4-bit AsData / AsQuery from V4→V3 downcasting to the new V4-specialized paths.
Changes:
- Implement V4 4-bit `SquaredL2` and `InnerProduct` kernels in `bits/distances.rs`, and update V4 retarget/dispatch to use them.
- Update spherical dispatcher so 4-bit `AsData<4>` / `AsQuery<4>` on V4 no longer downcasts to V3.
- Extend codegen/disassembly helpers and emulation reinterpret support for new V4 lane shapes.
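For orientation, the following scalar model sketches what the two 4-bit kernels compute, following the shift-and-mask scheme given in the PR description (four nibble positions per `u32` for `SquaredL2`, two per byte for `InnerProduct`). The function names, the `0x000F_000F` mask, and the one-word/one-byte granularity are assumptions for illustration; the real kernels work on 512-bit vectors.

```rust
/// Squared L2 over the 8 nibbles packed in one `u32` pair. Each of the four
/// shifts (0, 4, 8, 12) leaves one nibble in each 16-bit half of the word.
pub fn squared_l2_u32(a: u32, b: u32) -> i32 {
    const MASK: u32 = 0x000F_000F; // assumed mask: one nibble per 16-bit half
    let mut acc = 0i32;
    for shift in [0u32, 4, 8, 12] {
        let x = (a >> shift) & MASK;
        let y = (b >> shift) & MASK;
        // View each masked word as two i16 lanes (low half, high half).
        let (x0, x1) = ((x & 0xFFFF) as i16, (x >> 16) as i16);
        let (y0, y1) = ((y & 0xFFFF) as i16, (y >> 16) as i16);
        let (d0, d1) = (x0 - y0, x1 - y1);
        // Pairwise multiply-add into one i32, the per-lane behaviour of
        // `_mm512_madd_epi16`.
        acc += i32::from(d0) * i32::from(d0) + i32::from(d1) * i32::from(d1);
    }
    acc
}

/// Inner product over the two nibbles packed in one byte pair, using the
/// two byte-granular positions (low nibble, high nibble).
pub fn inner_product_u8(a: u8, b: u8) -> u32 {
    u32::from(a & 0x0F) * u32::from(b & 0x0F) + u32::from(a >> 4) * u32::from(b >> 4)
}

/// Naive reference that walks the 8 nibbles one at a time.
pub fn squared_l2_naive(a: u32, b: u32) -> i32 {
    (0..8)
        .map(|k| {
            let x = ((a >> (4 * k)) & 0xF) as i32;
            let y = ((b >> (4 * k)) & 0xF) as i32;
            (x - y) * (x - y)
        })
        .sum()
}
```

Because everything is integer arithmetic, the shift-and-mask form and the naive form agree exactly, which is the property the bit-exact tests rely on.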
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `diskann-wide/src/emulated.rs` | Adds emulated little-endian reinterpret support for u32x16 -> i16x32 used by V4 kernels (Miri/emulation). |
| `diskann-quantization/src/spherical/iface.rs` | Switches V4 spherical dispatch for 4-bit data/query layouts to the specialized V4 path. |
| `diskann-quantization/src/spherical/__codegen/x86_64.rs` | Adds V4 4-bit spherical distance-computer factory functions for codegen/disassembly coverage. |
| `diskann-quantization/src/bits/distances.rs` | Introduces V4 4-bit SquaredL2 and InnerProduct implementations and updates V4 retarget lists accordingly. |
| `diskann-quantization/src/__codegen/x86_64.rs` | Adds V4 4-bit bitslice distance entrypoints for disassembly inspection. |
Three Copilot review nits from PR microsoft#1045:

- L2 doc (~line 578): the claim 'AVX-512 does not have 16-bit integer bit-shift instructions' is incorrect; V4 enables avx512bw, which provides 'vpsrlw'. Reword to give the actual reason: the project's 'diskann_wide' type abstractions provide a native shift on 'u32x16' but not on 'i16x32', so we shift in the u32 view and reinterpret.
- L2 SAFETY comment (~line 612) and IP SAFETY comment (~line 1330): these mixed pointer-element arithmetic with byte sizes ('px_u32 + i + 16 * size_of::<u32>()' on a *const u32). Rewrite consistently in u32-element units: 'px_u32.add(i)..px_u32.add(i + 16) (in u32 units)'.

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
Two Copilot review nits from PR microsoft#1045:

- L2 predicated tail (~line 651): grammar fix ('an unaligned starting' -> rewritten to refer concretely to 'remainder values of type u32 ... starting at px_u32.add(i)').
- IP predicated tail (~line 1369): the SAFETY comment said the load starts 'at an offset of i', but the pointer is 'px_u32.add(i).cast::<u8>()', so the byte offset from px_u32 is 4*i. Rewrite the comment in consistent units (refer directly to the byte pointer and call out the 4*i byte offset).

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
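The unit mismatch these SAFETY comments were fixing is easy to reproduce in isolation. The helper below is a hypothetical illustration (it is not code from the PR): it advances a `*const u32` by `i` elements and reports how many bytes that actually moved.

```rust
/// Returns the byte distance between `px_u32` and `px_u32.add(i)`.
/// `add` on a `*const u32` counts in `u32` elements, so the byte
/// offset is `i * size_of::<u32>()`, i.e. `4 * i`.
pub fn byte_offset_after_add(data: &[u32], i: usize) -> usize {
    assert!(i <= data.len(), "offset must stay within the slice");
    let px_u32: *const u32 = data.as_ptr();
    // SAFETY: `i <= data.len()`, so the result is in bounds or one past the end.
    let advanced: *const u32 = unsafe { px_u32.add(i) };
    // Compare the two addresses as byte pointers.
    (advanced.cast::<u8>() as usize) - (px_u32.cast::<u8>() as usize)
}
```

So a load through `px_u32.add(i).cast::<u8>()` begins `4 * i` bytes past `px_u32`, and a range written as `px_u32.add(i)..px_u32.add(i + 16)` spans 64 bytes.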
hildebrandmw left a comment
Thanks! This looks great! Can you please also update the bounds in BITSLICE_TEST_BOUNDS for the (Key::new(4, X86_64_V4), Bounds::new(256, 150))? I think maybe 512, 300 is needed (sorry Miri runs) since the width of the main loop is effectively doubled.
///
/// We perform shifts on the `u32x16` view rather than the `i16x32` view because the
/// `diskann_wide` type abstractions provide a native shift on `u32x16` but not on
/// `i16x32`. The reinterpret-after-shift is well-defined because the same shift amount
Thanks for pointing this out. Fixed in 08beb30 by removing the incorrect paragraph. The kernel still performs its shifts on the `u32x16` view, since masking to 4 bits makes the result equivalent.
let remainder = blocks - i;

if remainder > 0 {
Can we do the same thing here for the remainder that is done for 4-bit inner product? A masked 8-bit load instead of a masked 32-bit load with at most one left over for the final epilogue?
Done in 08beb30. The L2 V4 tail now mirrors the IP V4 layout: `blocks = len.div_ceil(8)`, a main loop of `while i + 16 < blocks`, and a predicated `u8x64` load for the trailing bytes.
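To see how this layout divides the work, here is a small planning sketch. `TailPlan` and `plan_tail` are hypothetical names, and the split assumes the predicated load covers only whole bytes (two nibbles each), leaving an odd final nibble to the scalar fallback.

```rust
/// How a vector of `len` 4-bit values is split between the three stages.
/// (Hypothetical type, for illustration only.)
#[derive(Debug, PartialEq, Eq)]
pub struct TailPlan {
    /// Full-width iterations, each consuming 16 `u32` blocks (64 bytes).
    pub main_iterations: usize,
    /// Whole bytes left for the predicated `u8x64` load (at most 64).
    pub tail_bytes: usize,
    /// Nibbles left for the scalar fallback (0 or 1).
    pub scalar_nibbles: usize,
}

pub fn plan_tail(len: usize) -> TailPlan {
    let blocks = len.div_ceil(8); // 8 nibbles per `u32` block
    let mut i = 0;
    let mut main_iterations = 0;
    while i + 16 < blocks {
        i += 16;
        main_iterations += 1;
    }
    // Each block consumed by the main loop accounts for 4 bytes.
    let tail_bytes = len / 2 - 4 * i;
    TailPlan {
        main_iterations,
        tail_bytes,
        scalar_nibbles: len % 2,
    }
}
```

Because the strict bound always leaves between 1 and 16 blocks when `blocks > 0`, `tail_bytes` never exceeds 64 and fits a single masked `u8x64` load.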
///
/// # Implementation Notes
///
/// Unlike the V3 4-bit `InnerProduct` impl (which uses `_mm256_madd_epi16` over `i16` lanes
In general, please keep doc comments to what the kernel does and not as a contrast to what another kernel does. The former is more durable long-term.
Reworded in 08beb30 to describe what this kernel does.
// covered by the predicated `u8` load below; the scalar fallback only ever needs to
// handle at most a single dangling nibble.
//
// (The L2 kernel uses `len / 8` because its scalar fallback can handle up to 7
Again, please keep comments focused on what this kernel does, not what other kernels do.
- L2 V4: drop V3-contrast / misleading shift docs; reword to focus on what this kernel does (per @hildebrandmw's review).
- L2 V4: switch the trailing tail to a predicated u8 load (matching the IP V4 kernel), so the scalar fallback only handles at most one dangling nibble instead of up to 7. blocks now uses div_ceil(8) and the main loop uses '< blocks', mirroring the IP layout.
- IP V4: drop V3-contrast docs / parenthetical, focus on what this kernel does.
- Bump BITSLICE_TEST_BOUNDS for (4, X86_64_V4) from (256, 150) to (512, 300) so miri exercises the doubled main-loop width.

Co-authored-by: B Harsha Kashyap <hakashya@microsoft.com>
Co-authored-by: Krishnakumar Ravi (KK) <kkravi@microsoft.com>
Thanks for the review @hildebrandmw! Addressed everything in 08beb30:
hildebrandmw left a comment
Thanks @m3hm3t, this looks great! I would echo what Harsha asked and wonder if you have some before/after numbers for performance on a suitable AVX-512 machine 😄.
Reference Issues/PRs
N/A

What does this implement/fix? Briefly explain your changes.
Adds AVX-512 (V4) specialized distance kernels for 4-bit packed vectors in `diskann-quantization`. Previously, V4 dispatched 4-bit `USlice × USlice` distance computations through the V3 (AVX2) kernel via `downcast_to_v3`. This PR adds native 512-bit specializations and rewires the dispatcher.

What's new
- `SquaredL2` (V4, 4-bit): loads 16 `u32` lanes per iteration, masks/shifts to 4 nibble positions, reinterprets as `i16x32`, and accumulates squared differences via `_mm512_madd_epi16` (`dot_simd`) into 4 independent `i32x16` accumulators.
- `InnerProduct` (V4, 4-bit): built around AVX-512 VNNI (`_mm512_dpbusd_epi32`) over `u8x64` / `i8x64` operands. Only 2 shift positions thanks to byte-granular accumulation, halving per-iteration shift/mask work compared to the `madd_epi16` approach.
- `fourbit_v4_{l2,ip,cosine}_{data,query}_data` factories; flips `dispatch_map!` for `AsData<4>` / `AsQuery<4>` on V4 from `downcast_to_v3` to the specialized path. Cosine variants come for free via `CompensatedCosine(CompensatedIP(…))`.
- `blocks = len.div_ceil(8)` and a predicated `u8x64` load for the trailing bytes, so the scalar fallback only ever handles at most one dangling nibble.

Performance notes
- `u8` load, so the scalar fallback only ever runs for at most one dangling nibble.

Testing
Numerical correctness is exact (integer arithmetic). The existing `test_bitslice_distances_4bit` exercises Scalar / V3 / V4 paths under `V4::new_checked_miri()` and asserts bit-exact equality against the scalar reference across `dim ∈ 0..MAX_DIM` with randomized inputs. CI runs the V4 path under Intel SDE Sapphire Rapids.

Any other comments?

Commits
- a6734519: Add v4 distance kernels (initial implementation).
- cc4074a6: quantization: clean up V4 4-bit distance kernels (alias macro, doc cross-refs, gated tail, L2 `<= blocks`, IP loop bound preserved as `<`).
- 1019ec61: wide: add `Emulated<u32, 16> => Emulated<i16, 32>` reinterpret for V4 (required by the coverage build's `--cfg=miri` path; mirrors the existing V3 `<u32, 8> => <i16, 16>` impl).
- e26e4a91: quantization: address review comments on V4 4-bit kernel docs (correct `avx512bw` shift claim; rewrite SAFETY pointer-arithmetic ranges in consistent `u32`-element units).
- 45f5b7ff: quantization: address review comments on V4 4-bit predicated tail SAFETY (grammar fix in L2 tail; rewrite IP tail SAFETY in byte units to match the `px_u32.add(i).cast::<u8>()` pointer).
- 08beb308: quantization: address V4 4-bit review feedback (drop V3-contrast docs; refactor L2 tail to use a predicated `u8` load matching IP; bump `BITSLICE_TEST_BOUNDS` for `(4, X86_64_V4)` to `(512, 300)` to exercise the doubled main-loop width under miri).

Co-authors