Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions diskann-disk/src/search/pq/pq_scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/
//! Aligned allocator

use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
use diskann::{ANNError, ANNResult};

use diskann_quantization::alloc::{AlignedAllocator, Poly};

Expand All @@ -23,13 +23,17 @@ pub struct PQScratch {
/// This is used to store the pq coordinates of the candidate vectors.
pub aligned_pq_coord_scratch: Poly<[u8], AlignedAllocator>,

/// Query scratch buffer stored as `f32`. `set` initializes it by copying/converting the
/// raw query values; `PQTable.PreprocessQuery` can then rotate or otherwise preprocess it.
/// Query scratch buffer stored as `f32`, sized by the PQ table's logical dimension.
/// `set` populates it from a caller-provided `&[f32]`; `PQTable::preprocess_query` can
/// then rotate or otherwise preprocess it.
pub rotated_query: Vec<f32>,
}

impl PQScratch {
/// Create a new pq scratch
/// Create a new pq scratch.
///
/// `dim` is the PQ table's logical dimension (`PQData::get_dim()`); the
/// internal `rotated_query` buffer is sized to exactly this many `f32` slots.
pub fn new(
graph_degree: usize,
dim: usize,
Expand All @@ -54,31 +58,24 @@ impl PQScratch {
})
}

/// Copy `query` into `rotated_query`, converting to `f32`.
/// Copy the first `dim` elements of `query` into `rotated_query`.
///
/// `dim` is the element count in the `T` representation. The decompressed
/// `f32` length returned by `T::as_f32` may differ (e.g. `MinMaxElement`
/// expands to more `f32`s than its raw element count), so the destination
/// slice is sized by that actual length.
/// `query` must already be in full-precision `f32` representation; quantized
/// inputs (e.g. `MinMaxElement`) should be decoded via `VectorRepr::as_f32`
/// at the caller boundary before invoking this method.
///
/// Returns `DimensionMismatchError` if `dim > query.len()` or the
/// decompressed vector does not fit in `rotated_query`.
pub fn set<T: VectorRepr>(&mut self, dim: usize, query: &[T]) -> ANNResult<()> {
if dim > query.len() {
/// Accepts oversized `query` (only the first `dim` elements are used) for
/// backwards compatibility with callers that hold alignment-padded buffers.
/// Returns `DimensionMismatchError` if `query.len() < rotated_query.len()`.
pub fn set(&mut self, query: &[f32]) -> ANNResult<()> {
let dim = self.rotated_query.len();
if query.len() < dim {
return Err(ANNError::log_dimension_mismatch_error(format!(
"PQScratch::set: expected query of length >= {dim}, got {}",
query.len()
)));
}
let query = T::as_f32(&query[..dim]).into_ann_result()?;
if query.len() > self.rotated_query.len() {
return Err(ANNError::log_dimension_mismatch_error(format!(
"PQScratch::set: decompressed query of length {} does not fit rotated_query buffer of length {}",
query.len(),
self.rotated_query.len()
)));
}
self.rotated_query[..query.len()].copy_from_slice(&query);
self.rotated_query.copy_from_slice(&query[..dim]);
Ok(())
}
}
Expand Down Expand Up @@ -116,11 +113,11 @@ mod tests {
);

// Test set() method
let query: Vec<u8> = (1..=dim).map(|i| i as u8).collect();
pq_scratch.set::<u8>(query.len(), &query).unwrap();
let query: Vec<f32> = (1..=dim).map(|i| i as f32).collect();
pq_scratch.set(&query).unwrap();

(0..query.len()).for_each(|i| {
assert_eq!(pq_scratch.rotated_query[i], query[i] as f32);
assert_eq!(pq_scratch.rotated_query[i], query[i]);
});
}
}
15 changes: 5 additions & 10 deletions diskann-disk/src/search/pq/quantizer_preprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ pub fn quantizer_preprocess(
) -> ANNResult<()> {
match &pq_data.pq_table() {
PQTable::Transposed(table) => {
let dim = table.dim();
let expected_len = table.ncenters() * table.nchunks();
let dst = diskann_utils::views::MutMatrixView::try_from(
&mut (*pq_scratch.aligned_pqtable_dist_scratch)[..expected_len],
Expand All @@ -40,13 +39,13 @@ pub fn quantizer_preprocess(
// as L2 until a more thorough evaluation can be made.
Metric::L2 | Metric::Cosine | Metric::CosineNormalized => {
table.process_into::<diskann_quantization::distances::SquaredL2>(
&pq_scratch.rotated_query[..dim],
&pq_scratch.rotated_query,
dst,
);
}
Metric::InnerProduct => {
table.process_into::<diskann_quantization::distances::InnerProduct>(
&pq_scratch.rotated_query[..dim],
&pq_scratch.rotated_query,
dst,
);
}
Expand All @@ -60,21 +59,17 @@ pub fn quantizer_preprocess(
// We're keeping that behavior here - treating `Cosine` and `CosineNormalized`
// as L2 until a more thorough evaluation can be made.
Metric::L2 | Metric::Cosine | Metric::CosineNormalized => {
// The scratch only stores the aligned dimension. However, preprocessing
// wants the actual dimension used, so we have to shrink the rotated query
// accordingly.
let dim = table.get_dim();
table.preprocess_query(&mut pq_scratch.rotated_query[..dim]);
table.preprocess_query(&mut pq_scratch.rotated_query);

// Compute the distance between each chunk of the query to each pq centroids.
table.populate_chunk_distances(
pq_scratch.rotated_query.as_slice(),
&pq_scratch.rotated_query,
&mut pq_scratch.aligned_pqtable_dist_scratch,
)?;
}
Metric::InnerProduct => {
table.populate_chunk_inner_products(
pq_scratch.rotated_query.as_slice(),
&pq_scratch.rotated_query,
&mut pq_scratch.aligned_pqtable_dist_scratch,
)?;
}
Expand Down
15 changes: 8 additions & 7 deletions diskann-disk/src/search/provider/disk_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use std::{

use crate::data_model::GraphDataType;
use diskann::{
error::IntoANNResult,
graph::{
self,
glue::{
Expand Down Expand Up @@ -501,7 +502,7 @@ where
#[derive(Clone)]
struct DiskSearchScratchArgs<'a, ProviderFactory> {
graph_degree: usize,
dim: usize,
pq_dim: usize,
num_pq_chunks: usize,
num_pq_centers: usize,
vertex_factory: &'a ProviderFactory,
Expand All @@ -519,7 +520,7 @@ where
fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
let pq_scratch = PQScratch::new(
args.graph_degree,
args.dim,
args.pq_dim,
args.num_pq_chunks,
args.num_pq_centers,
)?;
Expand Down Expand Up @@ -621,17 +622,17 @@ where
scratch_pool,
&DiskSearchScratchArgs {
graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
dim: provider.graph_header.metadata().dims,
pq_dim: provider.pq_data.get_dim(),
num_pq_chunks: provider.pq_data.get_num_chunks(),
num_pq_centers: provider.pq_data.get_num_centers(),
vertex_factory: vertex_provider_factory,
graph_header: &provider.graph_header,
},
)?;

scratch
.pq_scratch
.set(provider.graph_header.metadata().dims, query)?;
// Decode caller's native vector representation into `f32`; downstream PQ kernels operate purely on `&[f32]`.
let f32_query = Data::VectorDataType::as_f32(query).into_ann_result()?;
scratch.pq_scratch.set(&f32_query)?;
let start_vertex_id = provider.graph_header.metadata().medoid as u32;

let timer = Instant::now();
Expand Down Expand Up @@ -875,7 +876,7 @@ where
let pq_data = disk_index_reader.get_pq_data();
let scratch_pool_args = DiskSearchScratchArgs {
graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
dim: graph_header.metadata().dims,
pq_dim: pq_data.get_dim(),
num_pq_chunks: pq_data.get_num_chunks(),
num_pq_centers: pq_data.get_num_centers(),
vertex_factory: &vertex_provider_factory,
Expand Down
8 changes: 8 additions & 0 deletions diskann-disk/src/storage/quant/pq/pq_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ impl PQData {
&self.pq_pivot_table
}

/// Return the logical dimension of the original (pre-quantization) vectors.
pub fn get_dim(&self) -> usize {
match &self.pq_pivot_table {
PQTable::Transposed(table) => table.dim(),
PQTable::Fixed(table) => table.get_dim(),
}
}

/// Return the number of chunks in the underlying PQ schema.
pub fn get_num_chunks(&self) -> usize {
match &self.pq_pivot_table {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl QuantVectorProvider {
/// Create a query computer for the provided query vector
pub fn query_computer<T>(&self, query: &[T]) -> ANNResult<QueryComputer>
where
T: Copy + VectorRepr,
T: VectorRepr,
{
QueryComputer::new(
self.pq_chunk_table.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use std::sync::{Arc, Mutex};

use arc_swap::{ArcSwap, Guard};
use diskann::{ANNError, ANNResult};
use diskann::{ANNError, ANNResult, error::IntoANNResult, utils::VectorRepr};
use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
use rand::{Rng, SeedableRng, rngs::StdRng};

Expand Down Expand Up @@ -77,15 +77,15 @@ impl TestMultiPQProviderAsync {

pub fn get_query_computer<T>(&self, query: &[T]) -> ANNResult<NoneToInfinity<QueryComputer>>
where
T: Copy + Into<f32>,
T: VectorRepr,
{
let table = self.multi_table().map_err(|err| {
ANNError::log_index_error(format_args!("Table construction failed with: {}", err))
})?;
Ok(NoneToInfinity(QueryComputer::new(
table,
self.metric,
query,
&T::as_f32(query).into_ann_result()?,
)?))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl FastMemoryQuantVectorProviderAsync {
/// Create a query computer for the provided query vector.
pub fn query_computer<T>(&self, query: &[T]) -> ANNResult<QueryComputer>
where
T: Copy + VectorRepr,
T: VectorRepr,
{
QueryComputer::new(
self.pq_chunk_table.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ use std::sync::Arc;

use crate::storage::{StorageReadProvider, StorageWriteProvider};
use arc_swap::{ArcSwap, Guard};
#[cfg(test)]
use diskann::utils::VectorRepr;
use diskann::{ANNError, ANNResult};
use diskann::{ANNError, ANNResult, error::IntoANNResult};
#[cfg(test)]
use diskann_quantization::CompressInto;
use diskann_utils::object_pool::ObjectPool;
Expand Down Expand Up @@ -88,12 +87,12 @@ impl MemoryQuantVectorProviderAsync {
/// Create a query computer for the provided query vector.
pub fn query_computer<T>(&self, query: &[T]) -> ANNResult<QueryComputer>
where
T: Copy + Into<f32>,
T: VectorRepr,
{
QueryComputer::new(
self.pq_chunk_table.clone(),
self.metric,
query,
&T::as_f32(query).into_ann_result()?,
Some(self.vec_pool.clone()),
)
}
Expand Down
6 changes: 3 additions & 3 deletions diskann-providers/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ pub mod pq;
pub use pq::{
FixedChunkPQTable, GeneratePivotArguments, MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ,
NUM_PQ_CENTROIDS, accum_row_inplace, calculate_chunk_offsets_auto, compute_pq_distance,
compute_pq_distance_for_pq_coordinates, direct_distance_impl, distance,
generate_pq_data_from_pivots_from_membuf, generate_pq_data_from_pivots_from_membuf_batch,
generate_pq_pivots, generate_pq_pivots_from_membuf,
compute_pq_distance_for_pq_coordinates, distance, generate_pq_data_from_pivots_from_membuf,
generate_pq_data_from_pivots_from_membuf_batch, generate_pq_pivots,
generate_pq_pivots_from_membuf,
Comment on lines +15 to +17
};

pub mod statistics;
Expand Down
55 changes: 12 additions & 43 deletions diskann-providers/src/model/pq/distance/cosine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ impl<T> DirectCosine<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
pub(crate) fn new<U>(parent: T, query: &[U]) -> ANNResult<Self>
where
U: Into<f32> + Copy,
{
/// Caller must ensure `query.len() == parent.get_dim()` (validated by `QueryComputer::new`).
pub(crate) fn new(parent: T, query: &[f32]) -> ANNResult<Self> {
let mut object = Self::new_unpopulated(parent);
object.populate(query)?;
Ok(object)
Expand All @@ -47,24 +45,10 @@ where
}
}

fn populate<U>(&mut self, query: &[U]) -> ANNResult<()>
where
U: Into<f32> + Copy,
{
// Make sure the query we are getting is the expected length.
//
// Alignment means that the size of `query` gets increased ...
// This makes is VERY hard to do error checking on dimension propagation.
assert!(self.query.len() <= query.len());

// Preprocessing currently just converts the query to f32 so we don't have
// to do that every time we want to compute a distance.
//
// If the query is *already* f32, then we can skip this memcpy and use the original
// query for distance computations.
std::iter::zip(self.query.iter_mut(), query.iter()).for_each(|(dst, src)| {
*dst = (*src).into();
});
fn populate(&mut self, query: &[f32]) -> ANNResult<()> {
// Stash a copy of the query so subsequent `evaluate` calls can reuse it
// without converting on every call.
self.query.copy_from_slice(query);
Ok(())
}

Expand Down Expand Up @@ -98,27 +82,12 @@ where

#[cfg(test)]
mod tests {
use std::marker::PhantomData;

use diskann_vector::Half;
use rand::SeedableRng;
use rstest::rstest;

use super::{
super::test_utils::{self, TestDistribution},
*,
};

#[rstest]
#[case(PhantomData::<f32>)]
#[case(PhantomData::<Half>)]
#[case(PhantomData::<i8>)]
#[case(PhantomData::<u8>)]
fn test_cosine<T>(#[case] _marker: PhantomData<T>)
where
T: Into<f32> + TestDistribution,
{
// RNG

use super::{super::test_utils, *};

#[test]
fn test_cosine() {
let mut rng = rand::rngs::StdRng::seed_from_u64(0xc33529acbe474958);
let num_trials = 20;

Expand All @@ -143,7 +112,7 @@ mod tests {

// DirectCosine
test_utils::test_cosine_inner(
|table: &FixedChunkPQTable, query: &[T]| {
|table: &FixedChunkPQTable, query: &[f32]| {
DirectCosine::new(table, query).unwrap()
},
&table,
Expand Down
Loading
Loading