From 2630e265958f4e615d7dee2f070756f80cc87880 Mon Sep 17 00:00:00 2001 From: Wei Wu Date: Fri, 8 May 2026 14:17:46 +0800 Subject: [PATCH 1/4] PQ: size scratch by logical dim, formalize dim contract Follow-up to @hildebrandmw's review on #960. - `PQScratch::rotated_query` is sized by `PQData::get_dim()` (PQ logical dim) instead of `graph_header.metadata().dims` (slot byte count, exceeds logical dim for `MinMaxElement` due to trailing min/max metadata). - PQ entries take `&[f32]`, accept `len >= dim`, slice `[..dim]`. Callers decode via `VectorRepr::as_f32` once at the boundary; PQ subtree is f32-only internally. - Kernels (`preprocess_query`, `populate_chunk_distances_impl`, `direct_distance_impl`) `debug_assert_eq!` on entry, matching `pq_dist_lookup_single`. The two `_impl` helpers become private. - `DirectCosine::populate` uses `copy_from_slice` (the previous zip silently truncated, no longer applicable). - Drop redundant `Copy` and `U: Into` bounds on touched fns. --- diskann-disk/src/search/pq/pq_scratch.rs | 47 +++++++------- .../src/search/pq/quantizer_preprocess.rs | 15 ++--- .../src/search/provider/disk_provider.rs | 15 ++--- .../src/storage/quant/pq/pq_dataset.rs | 8 +++ .../async_/bf_tree/quant_vector_provider.rs | 2 +- .../async_/experimental/multi_pq_async.rs | 6 +- .../fast_memory_quant_vector_provider.rs | 2 +- .../async_/memory_quant_vector_provider.rs | 7 +-- diskann-providers/src/model/mod.rs | 6 +- .../src/model/pq/distance/cosine.rs | 32 +++------- .../src/model/pq/distance/dynamic.rs | 62 ++++++++++++------- .../src/model/pq/distance/innerproduct.rs | 25 +++----- diskann-providers/src/model/pq/distance/l2.rs | 23 +++---- .../src/model/pq/distance/multi.rs | 21 +++---- .../src/model/pq/distance/test_utils.rs | 15 ++--- .../src/model/pq/fixed_chunk_pq_table.rs | 11 ++-- diskann-providers/src/model/pq/mod.rs | 2 +- 17 files changed, 142 insertions(+), 157 deletions(-) diff --git a/diskann-disk/src/search/pq/pq_scratch.rs 
b/diskann-disk/src/search/pq/pq_scratch.rs index 0707fae3a..b4bbc635a 100644 --- a/diskann-disk/src/search/pq/pq_scratch.rs +++ b/diskann-disk/src/search/pq/pq_scratch.rs @@ -4,7 +4,7 @@ */ //! Aligned allocator -use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult}; +use diskann::{ANNError, ANNResult}; use diskann_quantization::alloc::{AlignedAllocator, Poly}; @@ -23,13 +23,17 @@ pub struct PQScratch { /// This is used to store the pq coordinates of the candidate vectors. pub aligned_pq_coord_scratch: Poly<[u8], AlignedAllocator>, - /// Query scratch buffer stored as `f32`. `set` initializes it by copying/converting the - /// raw query values; `PQTable.PreprocessQuery` can then rotate or otherwise preprocess it. + /// Query scratch buffer stored as `f32`, sized by the PQ table's logical dimension. + /// `set` populates it from a caller-provided `&[f32]`; `PQTable::preprocess_query` can + /// then rotate or otherwise preprocess it. pub rotated_query: Vec, } impl PQScratch { - /// Create a new pq scratch + /// Create a new pq scratch. + /// + /// `dim` is the PQ table's logical dimension (`PQData::get_dim()`); the + /// internal `rotated_query` buffer is sized to exactly this many `f32` slots. pub fn new( graph_degree: usize, dim: usize, @@ -54,31 +58,24 @@ impl PQScratch { }) } - /// Copy `query` into `rotated_query`, converting to `f32`. + /// Copy the first `dim` elements of `query` into `rotated_query`. /// - /// `dim` is the element count in the `T` representation. The decompressed - /// `f32` length returned by `T::as_f32` may differ (e.g. `MinMaxElement` - /// expands to more `f32`s than its raw element count), so the destination - /// slice is sized by that actual length. + /// `query` must already be in full-precision `f32` representation; quantized + /// inputs (e.g. `MinMaxElement`) should be decoded via `VectorRepr::as_f32` + /// at the caller boundary before invoking this method. 
/// - /// Returns `DimensionMismatchError` if `dim > query.len()` or the - /// decompressed vector does not fit in `rotated_query`. - pub fn set(&mut self, dim: usize, query: &[T]) -> ANNResult<()> { - if dim > query.len() { + /// Accepts oversized `query` (only the first `dim` elements are used) for + /// backwards compatibility with callers that hold alignment-padded buffers. + /// Returns `DimensionMismatchError` if `query.len() < rotated_query.len()`. + pub fn set(&mut self, query: &[f32]) -> ANNResult<()> { + let dim = self.rotated_query.len(); + if query.len() < dim { return Err(ANNError::log_dimension_mismatch_error(format!( "PQScratch::set: expected query of length >= {dim}, got {}", query.len() ))); } - let query = T::as_f32(&query[..dim]).into_ann_result()?; - if query.len() > self.rotated_query.len() { - return Err(ANNError::log_dimension_mismatch_error(format!( - "PQScratch::set: decompressed query of length {} does not fit rotated_query buffer of length {}", - query.len(), - self.rotated_query.len() - ))); - } - self.rotated_query[..query.len()].copy_from_slice(&query); + self.rotated_query.copy_from_slice(&query[..dim]); Ok(()) } } @@ -116,11 +113,11 @@ mod tests { ); // Test set() method - let query: Vec = (1..=dim).map(|i| i as u8).collect(); - pq_scratch.set::(query.len(), &query).unwrap(); + let query: Vec = (1..=dim).map(|i| i as f32).collect(); + pq_scratch.set(&query).unwrap(); (0..query.len()).for_each(|i| { - assert_eq!(pq_scratch.rotated_query[i], query[i] as f32); + assert_eq!(pq_scratch.rotated_query[i], query[i]); }); } } diff --git a/diskann-disk/src/search/pq/quantizer_preprocess.rs b/diskann-disk/src/search/pq/quantizer_preprocess.rs index cc454ea7b..88766f992 100644 --- a/diskann-disk/src/search/pq/quantizer_preprocess.rs +++ b/diskann-disk/src/search/pq/quantizer_preprocess.rs @@ -23,7 +23,6 @@ pub fn quantizer_preprocess( ) -> ANNResult<()> { match &pq_data.pq_table() { PQTable::Transposed(table) => { - let dim = table.dim(); let 
expected_len = table.ncenters() * table.nchunks(); let dst = diskann_utils::views::MutMatrixView::try_from( &mut (*pq_scratch.aligned_pqtable_dist_scratch)[..expected_len], @@ -40,13 +39,13 @@ pub fn quantizer_preprocess( // as L2 until a more thorough evaluation can be made. Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { table.process_into::( - &pq_scratch.rotated_query[..dim], + &pq_scratch.rotated_query, dst, ); } Metric::InnerProduct => { table.process_into::( - &pq_scratch.rotated_query[..dim], + &pq_scratch.rotated_query, dst, ); } @@ -60,21 +59,17 @@ pub fn quantizer_preprocess( // We're keeping that behavior here - treating `Cosine` and `CosineNormalized` // as L2 until a more thorough evaluation can be made. Metric::L2 | Metric::Cosine | Metric::CosineNormalized => { - // The scratch only stores the aligned dimension. However, preprocessing - // wants the actual dimension used, so we have to shrink the rotated query - // accordingly. - let dim = table.get_dim(); - table.preprocess_query(&mut pq_scratch.rotated_query[..dim]); + table.preprocess_query(&mut pq_scratch.rotated_query); // Compute the distance between each chunk of the query to each pq centroids. 
table.populate_chunk_distances( - pq_scratch.rotated_query.as_slice(), + &pq_scratch.rotated_query, &mut pq_scratch.aligned_pqtable_dist_scratch, )?; } Metric::InnerProduct => { table.populate_chunk_inner_products( - pq_scratch.rotated_query.as_slice(), + &pq_scratch.rotated_query, &mut pq_scratch.aligned_pqtable_dist_scratch, )?; } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 1344605f4..33938caea 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -17,6 +17,7 @@ use std::{ use crate::data_model::GraphDataType; use diskann::{ + error::IntoANNResult, graph::{ self, glue::{ @@ -501,7 +502,7 @@ where #[derive(Clone)] struct DiskSearchScratchArgs<'a, ProviderFactory> { graph_degree: usize, - dim: usize, + pq_dim: usize, num_pq_chunks: usize, num_pq_centers: usize, vertex_factory: &'a ProviderFactory, @@ -519,7 +520,7 @@ where fn try_create(args: &DiskSearchScratchArgs) -> Result { let pq_scratch = PQScratch::new( args.graph_degree, - args.dim, + args.pq_dim, args.num_pq_chunks, args.num_pq_centers, )?; @@ -621,7 +622,7 @@ where scratch_pool, &DiskSearchScratchArgs { graph_degree: provider.graph_header.max_degree::()?, - dim: provider.graph_header.metadata().dims, + pq_dim: provider.pq_data.get_dim(), num_pq_chunks: provider.pq_data.get_num_chunks(), num_pq_centers: provider.pq_data.get_num_centers(), vertex_factory: vertex_provider_factory, @@ -629,9 +630,9 @@ where }, )?; - scratch - .pq_scratch - .set(provider.graph_header.metadata().dims, query)?; + // Decode caller's native vector representation into `f32`; downstream PQ kernels operate purely on `&[f32]`. 
+ let f32_query = Data::VectorDataType::as_f32(query).into_ann_result()?; + scratch.pq_scratch.set(&f32_query)?; let start_vertex_id = provider.graph_header.metadata().medoid as u32; let timer = Instant::now(); @@ -875,7 +876,7 @@ where let pq_data = disk_index_reader.get_pq_data(); let scratch_pool_args = DiskSearchScratchArgs { graph_degree: graph_header.max_degree::()?, - dim: graph_header.metadata().dims, + pq_dim: pq_data.get_dim(), num_pq_chunks: pq_data.get_num_chunks(), num_pq_centers: pq_data.get_num_centers(), vertex_factory: &vertex_provider_factory, diff --git a/diskann-disk/src/storage/quant/pq/pq_dataset.rs b/diskann-disk/src/storage/quant/pq/pq_dataset.rs index 049825a19..bcff23f9d 100644 --- a/diskann-disk/src/storage/quant/pq/pq_dataset.rs +++ b/diskann-disk/src/storage/quant/pq/pq_dataset.rs @@ -60,6 +60,14 @@ impl PQData { &self.pq_pivot_table } + /// Return the logical dimension of the original (pre-quantization) vectors. + pub fn get_dim(&self) -> usize { + match &self.pq_pivot_table { + PQTable::Transposed(table) => table.dim(), + PQTable::Fixed(table) => table.get_dim(), + } + } + /// Return the number of chunks in the underlying PQ schema. 
pub fn get_num_chunks(&self) -> usize { match &self.pq_pivot_table { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs index c3c08bb58..d9e05933c 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/quant_vector_provider.rs @@ -116,7 +116,7 @@ impl QuantVectorProvider { /// Create a query computer for the provided query vector pub fn query_computer(&self, query: &[T]) -> ANNResult where - T: Copy + VectorRepr, + T: VectorRepr, { QueryComputer::new( self.pq_chunk_table.clone(), diff --git a/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs b/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs index c77e257d2..cd512f0ab 100644 --- a/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs +++ b/diskann-providers/src/model/graph/provider/async_/experimental/multi_pq_async.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::{ArcSwap, Guard}; -use diskann::{ANNError, ANNResult}; +use diskann::{ANNError, ANNResult, error::IntoANNResult, utils::VectorRepr}; use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; use rand::{Rng, SeedableRng, rngs::StdRng}; @@ -77,7 +77,7 @@ impl TestMultiPQProviderAsync { pub fn get_query_computer(&self, query: &[T]) -> ANNResult> where - T: Copy + Into, + T: VectorRepr, { let table = self.multi_table().map_err(|err| { ANNError::log_index_error(format_args!("Table construction failed with: {}", err)) @@ -85,7 +85,7 @@ impl TestMultiPQProviderAsync { Ok(NoneToInfinity(QueryComputer::new( table, self.metric, - query, + &T::as_f32(query).into_ann_result()?, )?)) } diff --git a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs 
b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs index 63aad4566..49103d80f 100644 --- a/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/fast_memory_quant_vector_provider.rs @@ -101,7 +101,7 @@ impl FastMemoryQuantVectorProviderAsync { /// Create a query computer for the provided query vector. pub fn query_computer(&self, query: &[T]) -> ANNResult where - T: Copy + VectorRepr, + T: VectorRepr, { QueryComputer::new( self.pq_chunk_table.clone(), diff --git a/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs index 06a22536e..9dd6ffc4d 100644 --- a/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/memory_quant_vector_provider.rs @@ -12,9 +12,8 @@ use std::sync::Arc; use crate::storage::{StorageReadProvider, StorageWriteProvider}; use arc_swap::{ArcSwap, Guard}; -#[cfg(test)] use diskann::utils::VectorRepr; -use diskann::{ANNError, ANNResult}; +use diskann::{ANNError, ANNResult, error::IntoANNResult}; #[cfg(test)] use diskann_quantization::CompressInto; use diskann_utils::object_pool::ObjectPool; @@ -88,12 +87,12 @@ impl MemoryQuantVectorProviderAsync { /// Create a query computer for the provided query vector. 
pub fn query_computer(&self, query: &[T]) -> ANNResult where - T: Copy + Into, + T: VectorRepr, { QueryComputer::new( self.pq_chunk_table.clone(), self.metric, - query, + &T::as_f32(query).into_ann_result()?, Some(self.vec_pool.clone()), ) } diff --git a/diskann-providers/src/model/mod.rs b/diskann-providers/src/model/mod.rs index f6ae2be75..39e3cfa22 100644 --- a/diskann-providers/src/model/mod.rs +++ b/diskann-providers/src/model/mod.rs @@ -12,9 +12,9 @@ pub mod pq; pub use pq::{ FixedChunkPQTable, GeneratePivotArguments, MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS, accum_row_inplace, calculate_chunk_offsets_auto, compute_pq_distance, - compute_pq_distance_for_pq_coordinates, direct_distance_impl, distance, - generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch, - generate_pq_pivots, generate_pq_pivots_from_membuf, + compute_pq_distance_for_pq_coordinates, distance, generate_pq_data_from_pivots_from_membuf, + generate_pq_data_from_pivots_from_membuf_batch, generate_pq_pivots, + generate_pq_pivots_from_membuf, }; pub mod statistics; diff --git a/diskann-providers/src/model/pq/distance/cosine.rs b/diskann-providers/src/model/pq/distance/cosine.rs index dc69e4489..bac0c42d7 100644 --- a/diskann-providers/src/model/pq/distance/cosine.rs +++ b/diskann-providers/src/model/pq/distance/cosine.rs @@ -31,10 +31,8 @@ impl DirectCosine where T: Deref, { - pub(crate) fn new(parent: T, query: &[U]) -> ANNResult - where - U: Into + Copy, - { + /// Caller must ensure `query.len() == parent.get_dim()` (validated by `QueryComputer::new`). + pub(crate) fn new(parent: T, query: &[f32]) -> ANNResult { let mut object = Self::new_unpopulated(parent); object.populate(query)?; Ok(object) @@ -47,24 +45,10 @@ where } } - fn populate(&mut self, query: &[U]) -> ANNResult<()> - where - U: Into + Copy, - { - // Make sure the query we are getting is the expected length. - // - // Alignment means that the size of `query` gets increased ... 
- // This makes is VERY hard to do error checking on dimension propagation. - assert!(self.query.len() <= query.len()); - - // Preprocessing currently just converts the query to f32 so we don't have - // to do that every time we want to compute a distance. - // - // If the query is *already* f32, then we can skip this memcpy and use the original - // query for distance computations. - std::iter::zip(self.query.iter_mut(), query.iter()).for_each(|(dst, src)| { - *dst = (*src).into(); - }); + fn populate(&mut self, query: &[f32]) -> ANNResult<()> { + // Stash a copy of the query so subsequent `evaluate` calls can reuse it + // without converting on every call. + self.query.copy_from_slice(query); Ok(()) } @@ -142,8 +126,8 @@ mod tests { }; // DirectCosine - test_utils::test_cosine_inner( - |table: &FixedChunkPQTable, query: &[T]| { + test_utils::test_cosine_inner::( + |table: &FixedChunkPQTable, query: &[f32]| { DirectCosine::new(table, query).unwrap() }, &table, diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs index c41955efb..e45c2531b 100644 --- a/diskann-providers/src/model/pq/distance/dynamic.rs +++ b/diskann-providers/src/model/pq/distance/dynamic.rs @@ -5,7 +5,7 @@ use std::{ops::Deref, sync::Arc}; -use diskann::ANNResult; +use diskann::{ANNError, ANNResult}; use diskann_utils::object_pool::ObjectPool; use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; @@ -67,15 +67,22 @@ where /// /// Even though PQ does not necessarily preserve the norms of compressed vectors, using L2 /// for Cosine Normalized seems to work well enough in practice to work as a temporary fix. 
- pub fn new( + pub fn new( table: T, metric: Metric, - query: &[U], + query: &[f32], pool: Option>>>, - ) -> ANNResult - where - U: Into + Copy, - { + ) -> ANNResult { + // Accept oversized `query` (only the first `dim` elements are used) for + // backwards compatibility with callers that hold alignment-padded buffers. + let dim = table.get_dim(); + if query.len() < dim { + return Err(ANNError::log_dimension_mismatch_error(format!( + "QueryComputer::new: expected query of length >= {dim}, got {}", + query.len() + ))); + } + let query = &query[..dim]; let result = match metric { Metric::L2 => Self::L2(TableL2::new(table, query, pool)?), Metric::InnerProduct => Self::IP(TableIP::new(table, query, pool)?), @@ -209,13 +216,22 @@ where { #[inline(always)] fn evaluate_similarity(&self, fp: &[f32], q: &[u8]) -> f32 { + // Accept oversized `fp` (only the first `dim` elements are used) for + // backwards compatibility with callers that hold alignment-padded buffers. + let dim = self.table.get_dim(); + assert!( + fp.len() >= dim, + "DistanceComputer: full-precision query length {} < dim {}", + fp.len(), + dim, + ); assert_eq!( q.len(), self.table.get_num_chunks(), "{}", INVALID_PQ_DIMENSION ); - (self.vtable.distance_fn)(&self.table, fp, q) + (self.vtable.distance_fn)(&self.table, &fp[..dim], q) } } @@ -306,8 +322,8 @@ mod tests { absolute: 0.0, }; - test_utils::test_l2_inner( - |table: &FixedChunkPQTable, query: &[T]| { + test_utils::test_l2_inner::( + |table: &FixedChunkPQTable, query: &[f32]| { QueryComputer::new(table, Metric::L2, query, None).unwrap() }, &table, @@ -317,10 +333,10 @@ mod tests { errors, ); - test_utils::test_l2_inner( - |table: &FixedChunkPQTable, query: &[T]| PreprocessedWrapper { + test_utils::test_l2_inner::( + |table: &FixedChunkPQTable, query: &[f32]| PreprocessedWrapper { table: DistanceComputer::new(table, Metric::L2), - query: query.iter().map(|i| >::into(*i)).collect(), + query: query.to_vec(), }, &table, num_trials, @@ -369,8 +385,8 @@ mod 
tests { absolute: 5.0e-3, }; - test_utils::test_ip_inner( - |table: &FixedChunkPQTable, query: &[T]| { + test_utils::test_ip_inner::( + |table: &FixedChunkPQTable, query: &[f32]| { QueryComputer::new(table, Metric::InnerProduct, query, None).unwrap() }, &table, @@ -380,10 +396,10 @@ mod tests { errors, ); - test_utils::test_ip_inner( - |table: &FixedChunkPQTable, query: &[T]| PreprocessedWrapper { + test_utils::test_ip_inner::( + |table: &FixedChunkPQTable, query: &[f32]| PreprocessedWrapper { table: DistanceComputer::new(table, Metric::InnerProduct), - query: query.iter().map(|i| >::into(*i)).collect(), + query: query.to_vec(), }, &table, num_trials, @@ -432,8 +448,8 @@ mod tests { absolute: 0.0, }; - test_utils::test_cosine_inner( - |table: &FixedChunkPQTable, query: &[T]| { + test_utils::test_cosine_inner::( + |table: &FixedChunkPQTable, query: &[f32]| { QueryComputer::new(table, Metric::Cosine, query, None).unwrap() }, &table, @@ -443,10 +459,10 @@ mod tests { errors, ); - test_utils::test_cosine_inner( - |table: &FixedChunkPQTable, query: &[T]| PreprocessedWrapper { + test_utils::test_cosine_inner::( + |table: &FixedChunkPQTable, query: &[f32]| PreprocessedWrapper { table: DistanceComputer::new(table, Metric::Cosine), - query: query.iter().map(|i| >::into(*i)).collect(), + query: query.to_vec(), }, &table, num_trials, diff --git a/diskann-providers/src/model/pq/distance/innerproduct.rs b/diskann-providers/src/model/pq/distance/innerproduct.rs index 7d9abfc7e..27c578051 100644 --- a/diskann-providers/src/model/pq/distance/innerproduct.rs +++ b/diskann-providers/src/model/pq/distance/innerproduct.rs @@ -48,14 +48,12 @@ impl TableIP where T: Deref, { - pub(crate) fn new( + /// Caller must ensure `query.len() == parent.get_dim()` (validated by `QueryComputer::new`). 
+ pub(crate) fn new( parent: T, - query: &[U], + query: &[f32], pool: Option>>>, - ) -> ANNResult - where - U: Into + Copy, - { + ) -> ANNResult { let mut object = Self::new_unpopulated(parent, pool); object.populate(query)?; Ok(object) @@ -73,17 +71,10 @@ where } } - fn populate + Copy>(&mut self, query: &[U]) -> ANNResult<()> { - // Ensure that the query has the expected length. - // - // Alignment means that the size of `query` gets increased ... - // This makes is VERY hard to do error checking on dimension propagation. - assert!(self.parent.get_dim() <= query.len()); - let local_query: Vec = query.iter().map(|x| (*x).into()).collect(); - + fn populate(&mut self, query: &[f32]) -> ANNResult<()> { // Compute the partial distances into the lookup-table. self.parent - .populate_chunk_inner_products(&local_query, &mut self.lookup_table) + .populate_chunk_inner_products(query, &mut self.lookup_table) } /// Compute the distance between a PQ code that the query provided to the most recent @@ -163,8 +154,8 @@ mod tests { }; // Basic `TableIP` - test_utils::test_ip_inner( - |table: &FixedChunkPQTable, query: &[T]| { + test_utils::test_ip_inner::( + |table: &FixedChunkPQTable, query: &[f32]| { TableIP::new(table, query, None).unwrap() }, &table, diff --git a/diskann-providers/src/model/pq/distance/l2.rs b/diskann-providers/src/model/pq/distance/l2.rs index 543534796..69bd479bc 100644 --- a/diskann-providers/src/model/pq/distance/l2.rs +++ b/diskann-providers/src/model/pq/distance/l2.rs @@ -49,14 +49,12 @@ impl TableL2 where T: Deref, { - pub(crate) fn new( + /// Caller must ensure `query.len() == parent.get_dim()` (validated by `QueryComputer::new`). 
+ pub(crate) fn new( parent: T, - query: &[U], + query: &[f32], pool: Option>>>, - ) -> ANNResult - where - U: Into + Copy, - { + ) -> ANNResult { let mut object = Self::new_unpopulated(parent, pool); object.populate(query)?; Ok(object) @@ -74,13 +72,8 @@ where } } - fn populate + Copy>(&mut self, query: &[U]) -> ANNResult<()> { - // Ensure that the query has the expected length. - // - // Alignment means that the size of `query` gets increased ... - // This makes is VERY hard to do error checking on dimension propagation. - assert!(self.parent.get_dim() <= query.len()); - let mut local_query: Vec = query.iter().map(|x| (*x).into()).collect(); + fn populate(&mut self, query: &[f32]) -> ANNResult<()> { + let mut local_query: Vec = query.to_vec(); // This function does the following: // 1. Centers the data (if the centorid is non-zero). @@ -168,8 +161,8 @@ mod tests { }; // Basic `TableL2` - test_utils::test_l2_inner( - |table: &FixedChunkPQTable, query: &[T]| { + test_utils::test_l2_inner::( + |table: &FixedChunkPQTable, query: &[f32]| { TableL2::new(table, query, None).unwrap() }, &table, diff --git a/diskann-providers/src/model/pq/distance/multi.rs b/diskann-providers/src/model/pq/distance/multi.rs index a591059c3..38ae97c80 100644 --- a/diskann-providers/src/model/pq/distance/multi.rs +++ b/diskann-providers/src/model/pq/distance/multi.rs @@ -370,10 +370,7 @@ where I: PQVersion, { /// Construct a new `MultiQueryComputer` with the requested metric and query. - pub fn new(table: MultiTable, metric: Metric, query: &[U]) -> ANNResult - where - U: Into + Copy, - { + pub fn new(table: MultiTable, metric: Metric, query: &[f32]) -> ANNResult { let s = match table { MultiTable::One { table, version } => Self::One { computer: { QueryComputer::new(table, metric, query, None)? 
}, @@ -868,7 +865,7 @@ mod tests { } fn test_query_computer_multi_with_one<'a, T, R>( - mut create: impl FnMut(usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>, + mut create: impl FnMut(usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>, table: &'a FixedChunkPQTable, config: &test_utils::TableConfig, reference: &::Distance, @@ -888,7 +885,7 @@ mod tests { let version: usize = version.into_usize(); let invalid_version = version.wrapping_add(1); - let computer = create(version, &input); + let computer = create(version, &input_f32); assert_eq!( computer.versions(), @@ -935,11 +932,11 @@ mod tests { absolute: 0.0, }; - let create = |version: usize, query: &[T]| { + let create = |version: usize, query: &[f32]| { let schema = MultiTable::one(&table, version); MultiQueryComputer::new(schema, metric, query).unwrap() }; - test_query_computer_multi_with_one( + test_query_computer_multi_with_one::( create, &table, &config, @@ -956,7 +953,7 @@ mod tests { #[allow(clippy::too_many_arguments)] fn test_query_computer_multi_with_two<'a, T, R>( - create: impl Fn(usize, usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>, + create: impl Fn(usize, usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>, new: &'a FixedChunkPQTable, old: &'a FixedChunkPQTable, new_config: &test_utils::TableConfig, @@ -990,7 +987,7 @@ mod tests { let new_version = new_version.into_usize(); let invalid_version = invalid_version.into_usize(); - let computer = create(new_version, old_version, &input); + let computer = create(new_version, old_version, &input_f32); assert_eq!( computer.versions(), @@ -1066,7 +1063,7 @@ mod tests { let new = test_utils::seed_pivot_table(new_config); let num_trials = 20; - let create = |new_version: usize, old_version: usize, query: &[T]| { + let create = |new_version: usize, old_version: usize, query: &[f32]| { let schema = MultiTable::two(&new, &old, new_version, old_version).unwrap(); MultiQueryComputer::new(schema, 
metric, query).unwrap() }; @@ -1076,7 +1073,7 @@ mod tests { absolute: 0.0, }; - test_query_computer_multi_with_two( + test_query_computer_multi_with_two::( create, &new, &old, diff --git a/diskann-providers/src/model/pq/distance/test_utils.rs b/diskann-providers/src/model/pq/distance/test_utils.rs index 76c240e1f..35c6e80a0 100644 --- a/diskann-providers/src/model/pq/distance/test_utils.rs +++ b/diskann-providers/src/model/pq/distance/test_utils.rs @@ -185,7 +185,7 @@ where /// Next, if OPQ is used, we need to ensure that the matrix multiplication is applied /// to the query vector before we can obtain expected results. pub(super) fn test_l2_inner<'a, T, F, R>( - create: impl Fn(&'a FixedChunkPQTable, &[T]) -> F, + create: impl Fn(&'a FixedChunkPQTable, &[f32]) -> F, table: &'a FixedChunkPQTable, num_trials: usize, config: TableConfig, @@ -198,11 +198,12 @@ pub(super) fn test_l2_inner<'a, T, F, R>( { for _ in 0..num_trials { let input: Vec = T::generate(config.dim, rng); - let mut input_f32: Vec = input.iter().map(|x| (*x).into()).collect(); + let input_query: Vec = input.iter().map(|x| (*x).into()).collect(); + let mut input_f32 = input_query.clone(); table.preprocess_query(&mut input_f32); - let computer = create(table, &input); + let computer = create(table, &input_query); for _ in 0..num_trials { let code = generate_random_code(config.num_pivots, config.pq_chunks, rng); let expected_vector = @@ -223,7 +224,7 @@ pub(super) fn test_l2_inner<'a, T, F, R>( } pub(super) fn test_ip_inner<'a, T, F, R>( - create: impl Fn(&'a FixedChunkPQTable, &[T]) -> F, + create: impl Fn(&'a FixedChunkPQTable, &[f32]) -> F, table: &'a FixedChunkPQTable, num_trials: usize, config: TableConfig, @@ -238,7 +239,7 @@ pub(super) fn test_ip_inner<'a, T, F, R>( let input: Vec = T::generate(config.dim, rng); let input_f32: Vec = input.iter().map(|x| (*x).into()).collect(); - let computer = create(table, &input); + let computer = create(table, &input_f32); for _ in 0..num_trials { let code = 
generate_random_code(config.num_pivots, config.pq_chunks, rng); let expected_vector = @@ -259,7 +260,7 @@ pub(super) fn test_ip_inner<'a, T, F, R>( } pub(super) fn test_cosine_inner<'a, T, F, R>( - create: impl Fn(&'a FixedChunkPQTable, &[T]) -> F, + create: impl Fn(&'a FixedChunkPQTable, &[f32]) -> F, table: &'a FixedChunkPQTable, num_trials: usize, config: TableConfig, @@ -274,7 +275,7 @@ pub(super) fn test_cosine_inner<'a, T, F, R>( let input: Vec = T::generate(config.dim, rng); let input_f32: Vec = input.iter().map(|x| (*x).into()).collect(); - let computer = create(table, &input); + let computer = create(table, &input_f32); for _ in 0..num_trials { let code = generate_random_code(config.num_pivots, config.pq_chunks, rng); let expected_vector = diff --git a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs index 1f58205cb..e9e64f529 100644 --- a/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs +++ b/diskann-providers/src/model/pq/fixed_chunk_pq_table.rs @@ -32,7 +32,7 @@ pub struct FixedChunkPQTable { // These free functions use internals of the `FixedChunkPQTable`. // // We should clean up the API in the FFI. -pub fn direct_distance_impl( +fn direct_distance_impl( pq_table: &[f32], chunk_offsets: &[usize], dim: usize, @@ -42,6 +42,7 @@ pub fn direct_distance_impl( where T: distance::simd::ResumableSIMDSchema, { + debug_assert_eq!(query_vec.len(), dim); let mut accumulator = distance::simd::Resumable::new(T::init(ARCH)); let mut start = chunk_offsets[0]; let num_pq_chunks = chunk_offsets.len() - 1; @@ -167,12 +168,13 @@ impl FixedChunkPQTable { /// Shifting the query according to mean or the whole corpus. The output is a rotated query vector, /// which is later used to calculate the distance between each query chunk and each centroid using populate_chunk_distances. 
pub fn preprocess_query(&self, rotated_query_vec: &mut [f32]) { + debug_assert_eq!(rotated_query_vec.len(), self.centroids.len()); for (query, &centroid) in rotated_query_vec.iter_mut().zip(self.centroids.iter()) { *query -= centroid; } } - pub fn populate_chunk_distances_impl( + fn populate_chunk_distances_impl( &self, rotated_query_vec: &[f32], aligned_pq_table_dist_scratch: &mut [f32], @@ -182,6 +184,8 @@ impl FixedChunkPQTable { { let num_centers = self.get_num_centers(); let num_chunks = self.get_num_chunks(); + let dim = self.get_dim(); + debug_assert_eq!(rotated_query_vec.len(), dim); if aligned_pq_table_dist_scratch.len() < num_chunks * num_centers { return Err(ANNError::log_pq_error( "aligned_pq_table_dist_scratch.len() should at least be num_pq_chunks * num_centers", @@ -190,7 +194,6 @@ impl FixedChunkPQTable { let offsets: &[usize] = self.table.view_offsets().into(); let table: &[f32] = self.table.view_pivots().into(); - let dim = self.get_dim(); for centroid_index in 0..num_centers { let table_start = dim * centroid_index; @@ -277,7 +280,7 @@ impl FixedChunkPQTable { /// Calculate the distance between query and given centroid by inner product /// * `query_vec` - query vector: 1 * dim /// * `base_vec` - given centroid array: 1 * num_pq_chunks - pub fn inner_product_raw(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 { + fn inner_product_raw(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 { direct_distance_impl::>( self.table.view_pivots().as_slice(), self.table.view_offsets().as_slice(), diff --git a/diskann-providers/src/model/pq/mod.rs b/diskann-providers/src/model/pq/mod.rs index 6338e39ec..9e8385bc7 100644 --- a/diskann-providers/src/model/pq/mod.rs +++ b/diskann-providers/src/model/pq/mod.rs @@ -5,7 +5,7 @@ mod fixed_chunk_pq_table; pub use fixed_chunk_pq_table::{ FixedChunkPQTable, compute_pq_distance, compute_pq_distance_for_pq_coordinates, - direct_distance_impl, pq_dist_lookup_single, + pq_dist_lookup_single, }; mod pq_construction; From 
a79123d60bd2cb78d1707298bb19a9aedede1b99 Mon Sep 17 00:00:00 2001 From: Wei Wu Date: Fri, 8 May 2026 15:35:32 +0800 Subject: [PATCH 2/4] Add coverage tests for QueryComputer/DistanceComputer dim checks --- .../src/model/pq/distance/dynamic.rs | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs index e45c2531b..70207091a 100644 --- a/diskann-providers/src/model/pq/distance/dynamic.rs +++ b/diskann-providers/src/model/pq/distance/dynamic.rs @@ -548,4 +548,34 @@ mod tests { ); } } + + #[test] + fn query_computer_new_rejects_undersized_query() { + let config = test_utils::TableConfig { + dim: 16, + pq_chunks: 4, + num_pivots: 8, + start_value: 0.0, + }; + let table = test_utils::seed_pivot_table(config); + let short_query = vec![0.0f32; config.dim - 1]; + let err = QueryComputer::new(&table, Metric::L2, &short_query, None).unwrap_err(); + assert_eq!(err.kind(), diskann::ANNErrorKind::DimensionMismatchError); + } + + #[test] + #[should_panic(expected = "DistanceComputer: full-precision query length")] + fn distance_computer_panics_on_undersized_fp_query() { + let config = test_utils::TableConfig { + dim: 16, + pq_chunks: 4, + num_pivots: 8, + start_value: 0.0, + }; + let table = test_utils::seed_pivot_table(config); + let computer = DistanceComputer::new(&table, Metric::L2); + let short_fp = vec![0.0f32; config.dim - 1]; + let code = vec![0u8; config.pq_chunks]; + let _ = computer.evaluate_similarity(short_fp.as_slice(), code.as_slice()); + } } From 01065034bfd1e436729653cf813eba19882c7f76 Mon Sep 17 00:00:00 2001 From: Wei Wu Date: Fri, 8 May 2026 16:00:09 +0800 Subject: [PATCH 3/4] Drop unused T parameter from PQ distance test helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR moved entries to &[f32] so test_X_inner helpers no longer need to generate per-T data — drop the type parameter, generate 
Vec directly. Removes turbofish at all call sites and the rstest values parameterization. --- .../src/model/pq/distance/cosine.rs | 27 +++------- .../src/model/pq/distance/dynamic.rs | 53 +++++-------------- .../src/model/pq/distance/innerproduct.rs | 27 +++------- diskann-providers/src/model/pq/distance/l2.rs | 26 +++------ .../src/model/pq/distance/multi.rs | 43 ++++----------- .../src/model/pq/distance/test_utils.rs | 18 +++---- 6 files changed, 49 insertions(+), 145 deletions(-) diff --git a/diskann-providers/src/model/pq/distance/cosine.rs b/diskann-providers/src/model/pq/distance/cosine.rs index bac0c42d7..29c874354 100644 --- a/diskann-providers/src/model/pq/distance/cosine.rs +++ b/diskann-providers/src/model/pq/distance/cosine.rs @@ -82,27 +82,12 @@ where #[cfg(test)] mod tests { - use std::marker::PhantomData; - - use diskann_vector::Half; use rand::SeedableRng; - use rstest::rstest; - - use super::{ - super::test_utils::{self, TestDistribution}, - *, - }; - - #[rstest] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - fn test_cosine(#[case] _marker: PhantomData) - where - T: Into + TestDistribution, - { - // RNG + + use super::{super::test_utils, *}; + + #[test] + fn test_cosine() { let mut rng = rand::rngs::StdRng::seed_from_u64(0xc33529acbe474958); let num_trials = 20; @@ -126,7 +111,7 @@ mod tests { }; // DirectCosine - test_utils::test_cosine_inner::( + test_utils::test_cosine_inner( |table: &FixedChunkPQTable, query: &[f32]| { DirectCosine::new(table, query).unwrap() }, diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs index 70207091a..1660f97a3 100644 --- a/diskann-providers/src/model/pq/distance/dynamic.rs +++ b/diskann-providers/src/model/pq/distance/dynamic.rs @@ -251,21 +251,16 @@ where #[cfg(test)] mod tests { - use std::marker::PhantomData; - use approx::assert_relative_eq; use diskann_vector::{ - Half, Norm, 
PureDistanceFunction, + Norm, PureDistanceFunction, distance::{Cosine, CosineNormalized, InnerProduct, SquaredL2}, norm::FastL2Norm, }; use rand::SeedableRng; use rstest::rstest; - use super::{ - super::test_utils::{self, TestDistribution}, - *, - }; + use super::{super::test_utils, *}; // A wrapper for the `DistanceComputer` that enables it to behave like a // `PreprocessedDistanceFunction`. @@ -292,13 +287,8 @@ mod tests { // L2 // //////// - #[rstest] - fn test_l2( - #[values(PhantomData::, PhantomData::, PhantomData::, PhantomData::)] - _marker: PhantomData, - ) where - T: Into + TestDistribution, - { + #[test] + fn test_l2() { let mut rng = rand::rngs::StdRng::seed_from_u64(0x83aa68de5765b565); for dim in [50, 51] { for pq_chunks in [8, 19, 50] { @@ -322,7 +312,7 @@ mod tests { absolute: 0.0, }; - test_utils::test_l2_inner::( + test_utils::test_l2_inner( |table: &FixedChunkPQTable, query: &[f32]| { QueryComputer::new(table, Metric::L2, query, None).unwrap() }, @@ -333,7 +323,7 @@ mod tests { errors, ); - test_utils::test_l2_inner::( + test_utils::test_l2_inner( |table: &FixedChunkPQTable, query: &[f32]| PreprocessedWrapper { table: DistanceComputer::new(table, Metric::L2), query: query.to_vec(), @@ -353,15 +343,8 @@ mod tests { // InnerProduct // ////////////////// - #[rstest] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - fn test_innerproduct(#[case] _marker: PhantomData) - where - T: Into + TestDistribution, - { + #[test] + fn test_innerproduct() { let mut rng = rand::rngs::StdRng::seed_from_u64(0xc392d773dc8de593); for dim in [12, 15, 128] { for pq_chunks in [2, 5, 15] { @@ -385,7 +368,7 @@ mod tests { absolute: 5.0e-3, }; - test_utils::test_ip_inner::( + test_utils::test_ip_inner( |table: &FixedChunkPQTable, query: &[f32]| { QueryComputer::new(table, Metric::InnerProduct, query, None).unwrap() }, @@ -396,7 +379,7 @@ mod tests { errors, ); - test_utils::test_ip_inner::( + test_utils::test_ip_inner( 
|table: &FixedChunkPQTable, query: &[f32]| PreprocessedWrapper { table: DistanceComputer::new(table, Metric::InnerProduct), query: query.to_vec(), @@ -416,16 +399,8 @@ mod tests { // Cosine // //////////// - #[rstest] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - fn test_cosine(#[case] _marker: PhantomData) - where - T: Into + TestDistribution, - { - // RNG + #[test] + fn test_cosine() { let mut rng = rand::rngs::StdRng::seed_from_u64(0xc33529acbe474958); let num_trials = 20; @@ -448,7 +423,7 @@ mod tests { absolute: 0.0, }; - test_utils::test_cosine_inner::( + test_utils::test_cosine_inner( |table: &FixedChunkPQTable, query: &[f32]| { QueryComputer::new(table, Metric::Cosine, query, None).unwrap() }, @@ -459,7 +434,7 @@ mod tests { errors, ); - test_utils::test_cosine_inner::( + test_utils::test_cosine_inner( |table: &FixedChunkPQTable, query: &[f32]| PreprocessedWrapper { table: DistanceComputer::new(table, Metric::Cosine), query: query.to_vec(), diff --git a/diskann-providers/src/model/pq/distance/innerproduct.rs b/diskann-providers/src/model/pq/distance/innerproduct.rs index 27c578051..c982589c9 100644 --- a/diskann-providers/src/model/pq/distance/innerproduct.rs +++ b/diskann-providers/src/model/pq/distance/innerproduct.rs @@ -109,27 +109,12 @@ where #[cfg(test)] mod tests { - use std::marker::PhantomData; - - use diskann_vector::Half; use rand::SeedableRng; - use rstest::rstest; - - use super::{ - super::test_utils::{self, TestDistribution}, - *, - }; - - #[rstest] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - #[case(PhantomData::)] - fn test_ip(#[case] _marker: PhantomData) - where - T: Into + TestDistribution, - { - // RNG + + use super::{super::test_utils, *}; + + #[test] + fn test_ip() { let mut rng = rand::rngs::StdRng::seed_from_u64(0x2e767adc3d5d630f); for dim in [12, 15, 128] { @@ -154,7 +139,7 @@ mod tests { }; // Basic `TableIP` - test_utils::test_ip_inner::( + 
test_utils::test_ip_inner( |table: &FixedChunkPQTable, query: &[f32]| { TableIP::new(table, query, None).unwrap() }, diff --git a/diskann-providers/src/model/pq/distance/l2.rs b/diskann-providers/src/model/pq/distance/l2.rs index 69bd479bc..d1889ffd1 100644 --- a/diskann-providers/src/model/pq/distance/l2.rs +++ b/diskann-providers/src/model/pq/distance/l2.rs @@ -117,24 +117,12 @@ where #[cfg(test)] mod tests { - use std::marker::PhantomData; - - use diskann_vector::Half; use rand::SeedableRng; - use rstest::rstest; - - use super::{ - super::test_utils::{self, TestDistribution}, - *, - }; - - #[rstest] - fn test_l2( - #[values(PhantomData::, PhantomData::, PhantomData::, PhantomData::)] - _marker: PhantomData, - ) where - T: Into + TestDistribution, - { + + use super::{super::test_utils, *}; + + #[test] + fn test_l2() { let mut rng = rand::rngs::StdRng::seed_from_u64(5); for dim in [12, 17, 100, 101] { for pq_chunks in [1, 17, 19, 20] { @@ -153,15 +141,13 @@ mod tests { let table = test_utils::seed_pivot_table(config); let num_trials = 10; - // RNG - let errors = test_utils::RelativeAndAbsolute { relative: 5e-7, absolute: 0.0, }; // Basic `TableL2` - test_utils::test_l2_inner::( + test_utils::test_l2_inner( |table: &FixedChunkPQTable, query: &[f32]| { TableL2::new(table, query, None).unwrap() }, diff --git a/diskann-providers/src/model/pq/distance/multi.rs b/diskann-providers/src/model/pq/distance/multi.rs index 38ae97c80..58584216f 100644 --- a/diskann-providers/src/model/pq/distance/multi.rs +++ b/diskann-providers/src/model/pq/distance/multi.rs @@ -462,11 +462,9 @@ where /// get sent to the right location and that the error handling is correct. 
#[cfg(test)] mod tests { - use std::marker::PhantomData; - use approx::assert_relative_eq; use diskann::utils::{IntoUsize, VectorRepr}; - use diskann_vector::{Half, PreprocessedDistanceFunction}; + use diskann_vector::PreprocessedDistanceFunction; use rand::{Rng, SeedableRng, distr::Distribution}; use rstest::rstest; @@ -475,13 +473,6 @@ mod tests { *, }; - fn to_f32(x: &[T]) -> Vec - where - T: Into + Copy, - { - x.iter().map(|i| (*i).into()).collect() - } - ///////////////////////// // Versioned PQ Vector // ///////////////////////// @@ -864,7 +855,7 @@ mod tests { ); } - fn test_query_computer_multi_with_one<'a, T, R>( + fn test_query_computer_multi_with_one<'a, R>( mut create: impl FnMut(usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>, table: &'a FixedChunkPQTable, config: &test_utils::TableConfig, @@ -873,13 +864,11 @@ mod tests { rng: &mut R, errors: test_utils::RelativeAndAbsolute, ) where - T: Into + TestDistribution, R: Rng, { let standard = rand::distr::StandardUniform {}; for _ in 0..num_trials { - let input: Vec = T::generate(config.dim, rng); - let input_f32 = to_f32(&input); + let input_f32: Vec = f32::generate(config.dim, rng); let version: u64 = standard.sample(rng); let version: usize = version.into_usize(); @@ -908,13 +897,9 @@ mod tests { } #[rstest] - fn test_query_computer_one( - #[values(PhantomData::, PhantomData::, PhantomData::, PhantomData::)] - _datatype: PhantomData, + fn test_query_computer_one( #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric, - ) where - T: Into + TestDistribution, - { + ) { let mut rng = rand::rngs::StdRng::seed_from_u64(0x6b53bef1bc26571e); let config = test_utils::TableConfig { @@ -936,7 +921,7 @@ mod tests { let schema = MultiTable::one(&table, version); MultiQueryComputer::new(schema, metric, query).unwrap() }; - test_query_computer_multi_with_one::( + test_query_computer_multi_with_one( create, &table, &config, @@ -952,7 +937,7 @@ mod tests { 
///////////////////////////////// #[allow(clippy::too_many_arguments)] - fn test_query_computer_multi_with_two<'a, T, R>( + fn test_query_computer_multi_with_two<'a, R>( create: impl Fn(usize, usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>, new: &'a FixedChunkPQTable, old: &'a FixedChunkPQTable, @@ -963,13 +948,11 @@ mod tests { rng: &mut R, errors: test_utils::RelativeAndAbsolute, ) where - T: Into + TestDistribution, R: Rng, { let standard = rand::distr::StandardUniform {}; for _ in 0..num_trials { - let input: Vec = T::generate(old_config.dim, rng); - let input_f32: Vec = to_f32(&input); + let input_f32: Vec = f32::generate(old_config.dim, rng); // Create a computer with two random versions. let old_version: u64 = standard.sample(rng); @@ -1036,13 +1019,9 @@ mod tests { } #[rstest] - fn test_query_computer_two( - #[values(PhantomData::, PhantomData::, PhantomData::, PhantomData::)] - _datatype: PhantomData, + fn test_query_computer_two( #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric, - ) where - T: Into + TestDistribution, - { + ) { let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f); let old_config = test_utils::TableConfig { @@ -1073,7 +1052,7 @@ mod tests { absolute: 0.0, }; - test_query_computer_multi_with_two::( + test_query_computer_multi_with_two( create, &new, &old, diff --git a/diskann-providers/src/model/pq/distance/test_utils.rs b/diskann-providers/src/model/pq/distance/test_utils.rs index 35c6e80a0..542da6020 100644 --- a/diskann-providers/src/model/pq/distance/test_utils.rs +++ b/diskann-providers/src/model/pq/distance/test_utils.rs @@ -184,7 +184,7 @@ where /// /// Next, if OPQ is used, we need to ensure that the matrix multiplication is applied /// to the query vector before we can obtain expected results. 
-pub(super) fn test_l2_inner<'a, T, F, R>( +pub(super) fn test_l2_inner<'a, F, R>( create: impl Fn(&'a FixedChunkPQTable, &[f32]) -> F, table: &'a FixedChunkPQTable, num_trials: usize, @@ -192,13 +192,11 @@ pub(super) fn test_l2_inner<'a, T, F, R>( rng: &mut R, errors: RelativeAndAbsolute, ) where - T: Into + TestDistribution, F: for<'any> PreprocessedDistanceFunction<&'any [u8], f32>, R: Rng, { for _ in 0..num_trials { - let input: Vec = T::generate(config.dim, rng); - let input_query: Vec = input.iter().map(|x| (*x).into()).collect(); + let input_query: Vec = f32::generate(config.dim, rng); let mut input_f32 = input_query.clone(); table.preprocess_query(&mut input_f32); @@ -223,7 +221,7 @@ pub(super) fn test_l2_inner<'a, T, F, R>( } } -pub(super) fn test_ip_inner<'a, T, F, R>( +pub(super) fn test_ip_inner<'a, F, R>( create: impl Fn(&'a FixedChunkPQTable, &[f32]) -> F, table: &'a FixedChunkPQTable, num_trials: usize, @@ -231,13 +229,11 @@ pub(super) fn test_ip_inner<'a, T, F, R>( rng: &mut R, errors: RelativeAndAbsolute, ) where - T: Into + TestDistribution, F: for<'any> PreprocessedDistanceFunction<&'any [u8], f32>, R: Rng, { for _ in 0..num_trials { - let input: Vec = T::generate(config.dim, rng); - let input_f32: Vec = input.iter().map(|x| (*x).into()).collect(); + let input_f32: Vec = f32::generate(config.dim, rng); let computer = create(table, &input_f32); for _ in 0..num_trials { @@ -259,7 +255,7 @@ pub(super) fn test_ip_inner<'a, T, F, R>( } } -pub(super) fn test_cosine_inner<'a, T, F, R>( +pub(super) fn test_cosine_inner<'a, F, R>( create: impl Fn(&'a FixedChunkPQTable, &[f32]) -> F, table: &'a FixedChunkPQTable, num_trials: usize, @@ -267,13 +263,11 @@ pub(super) fn test_cosine_inner<'a, T, F, R>( rng: &mut R, errors: RelativeAndAbsolute, ) where - T: Into + TestDistribution, F: for<'any> PreprocessedDistanceFunction<&'any [u8], f32>, R: Rng, { for _ in 0..num_trials { - let input: Vec = T::generate(config.dim, rng); - let input_f32: Vec = 
input.iter().map(|x| (*x).into()).collect(); + let input_f32: Vec = f32::generate(config.dim, rng); let computer = create(table, &input_f32); for _ in 0..num_trials { From bc20758a82bd97e4f25182b5097cf844d5b2e6cb Mon Sep 17 00:00:00 2001 From: Wei Wu Date: Mon, 11 May 2026 10:04:03 +0800 Subject: [PATCH 4/4] Tighten inmem PQ entries to strict query.len() == dim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per #1044 review: callers passing oversized queries to inmem PQ entries was a bug worth surfacing. QueryComputer::new (and via delegation MultiQueryComputer::new) return Result::Err on mismatch; DistanceComputer::evaluate_similarity asserts equality since the trait method has no Result return. PQScratch::set kept on >= dim tolerance for now — disk-side surface. --- .../src/model/pq/distance/dynamic.rs | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/diskann-providers/src/model/pq/distance/dynamic.rs b/diskann-providers/src/model/pq/distance/dynamic.rs index 1660f97a3..3cc277834 100644 --- a/diskann-providers/src/model/pq/distance/dynamic.rs +++ b/diskann-providers/src/model/pq/distance/dynamic.rs @@ -73,16 +73,13 @@ where query: &[f32], pool: Option>>>, ) -> ANNResult { - // Accept oversized `query` (only the first `dim` elements are used) for - // backwards compatibility with callers that hold alignment-padded buffers. 
let dim = table.get_dim(); - if query.len() < dim { + if query.len() != dim { return Err(ANNError::log_dimension_mismatch_error(format!( - "QueryComputer::new: expected query of length >= {dim}, got {}", + "QueryComputer::new: expected query of length {dim}, got {}", query.len() ))); } - let query = &query[..dim]; let result = match metric { Metric::L2 => Self::L2(TableL2::new(table, query, pool)?), Metric::InnerProduct => Self::IP(TableIP::new(table, query, pool)?), @@ -216,14 +213,12 @@ where { #[inline(always)] fn evaluate_similarity(&self, fp: &[f32], q: &[u8]) -> f32 { - // Accept oversized `fp` (only the first `dim` elements are used) for - // backwards compatibility with callers that hold alignment-padded buffers. - let dim = self.table.get_dim(); - assert!( - fp.len() >= dim, - "DistanceComputer: full-precision query length {} < dim {}", + assert_eq!( fp.len(), - dim, + self.table.get_dim(), + "DistanceComputer: full-precision query length {} != dim {}", + fp.len(), + self.table.get_dim(), ); assert_eq!( q.len(), @@ -231,7 +226,7 @@ where "{}", INVALID_PQ_DIMENSION ); - (self.vtable.distance_fn)(&self.table, &fp[..dim], q) + (self.vtable.distance_fn)(&self.table, fp, q) } } @@ -525,7 +520,7 @@ mod tests { } #[test] - fn query_computer_new_rejects_undersized_query() { + fn query_computer_new_rejects_mismatched_query() { let config = test_utils::TableConfig { dim: 16, pq_chunks: 4, @@ -540,7 +535,7 @@ mod tests { #[test] #[should_panic(expected = "DistanceComputer: full-precision query length")] - fn distance_computer_panics_on_undersized_fp_query() { + fn distance_computer_panics_on_mismatched_fp_query() { let config = test_utils::TableConfig { dim: 16, pq_chunks: 4,