Skip to content

Commit c869512

Browse files
authored
fix: respect fragment filters during distributed worker training (#6358)
This fixes a bug in distributed vector index builds where any worker that fell back to builder-local training would sample its training data from the entire dataset instead of from the worker's selected fragments. The change threads `fragment_filter` through IVF and quantizer training, and adds an `IvfSq` regression test verifying that worker-local SQ bounds are derived only from the shard's fragments.
1 parent b7f6baa commit c869512

2 files changed

Lines changed: 138 additions & 5 deletions

File tree

rust/lance/src/index/vector/builder.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
385385
dim,
386386
self.distance_type,
387387
ivf_params,
388-
None,
388+
self.fragment_filter.as_deref(),
389389
self.progress.clone(),
390390
)
391391
.await
@@ -414,9 +414,13 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
414414
"loading training data for quantizer. sample size: {}",
415415
sample_size_hint
416416
);
417-
let training_data =
418-
utils::maybe_sample_training_data(dataset, &self.column, sample_size_hint, None)
419-
.await?;
417+
let training_data = utils::maybe_sample_training_data(
418+
dataset,
419+
&self.column,
420+
sample_size_hint,
421+
self.fragment_filter.as_deref(),
422+
)
423+
.await?;
420424
info!(
421425
"Finished loading training data in {:02} seconds",
422426
start.elapsed().as_secs_f32()

rust/lance/src/index/vector/ivf/v2.rs

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,9 @@ mod tests {
660660
use lance_index::vector::quantizer::QuantizerMetadata;
661661
use lance_index::vector::sq::builder::SQBuildParams;
662662
use lance_index::vector::{
663-
pq::storage::ProductQuantizationMetadata, storage::STORAGE_METADATA_KEY,
663+
pq::storage::ProductQuantizationMetadata,
664+
sq::storage::{SQ_METADATA_KEY, ScalarQuantizationMetadata},
665+
storage::STORAGE_METADATA_KEY,
664666
};
665667
use lance_index::{INDEX_AUXILIARY_FILE_NAME, metrics::NoOpMetricsCollector};
666668
use lance_index::{optimize::OptimizeOptions, scalar::IndexReader};
@@ -767,6 +769,37 @@ mod tests {
767769
serde_json::from_str(&metadata_entries[0]).unwrap()
768770
}
769771

772+
async fn get_sq_metadata(
773+
dataset: &Dataset,
774+
scheduler: Arc<ScanScheduler>,
775+
index_uuid: &str,
776+
) -> ScalarQuantizationMetadata {
777+
let index_path = dataset
778+
.indices_dir()
779+
.child(index_uuid)
780+
.child(INDEX_AUXILIARY_FILE_NAME);
781+
let file_scheduler = scheduler
782+
.open_file(&index_path, &CachedFileSize::unknown())
783+
.await
784+
.unwrap();
785+
let reader = FileReader::try_open(
786+
file_scheduler,
787+
None,
788+
Arc::<DecoderPlugins>::default(),
789+
&LanceCache::no_cache(),
790+
FileReaderOptions::default(),
791+
)
792+
.await
793+
.unwrap();
794+
if let Some(metadata) = reader.schema().metadata.get(SQ_METADATA_KEY) {
795+
serde_json::from_str(metadata).unwrap()
796+
} else {
797+
let metadata = reader.schema().metadata.get(STORAGE_METADATA_KEY).unwrap();
798+
let metadata_entries: Vec<String> = serde_json::from_str(metadata).unwrap();
799+
serde_json::from_str(&metadata_entries[0]).unwrap()
800+
}
801+
}
802+
770803
async fn assert_rq_rotation_type(dataset: &Dataset, expected: RQRotationType) {
771804
let obj_store = Arc::new(ObjectStore::local());
772805
let scheduler = ScanScheduler::new(obj_store, SchedulerConfig::default_for_testing());
@@ -948,6 +981,49 @@ mod tests {
948981
)
949982
}
950983

984+
fn make_fragment_offset_batches(
985+
rows_per_fragment: usize,
986+
offsets: &[f32],
987+
) -> (Arc<Schema>, Vec<RecordBatch>) {
988+
let schema = Arc::new(Schema::new(vec![
989+
Field::new("id", DataType::UInt64, false),
990+
Field::new(
991+
"vector",
992+
DataType::FixedSizeList(
993+
Arc::new(Field::new("item", DataType::Float32, true)),
994+
DIM as i32,
995+
),
996+
false,
997+
),
998+
]));
999+
1000+
let mut next_id = 0_u64;
1001+
let batches = offsets
1002+
.iter()
1003+
.map(|offset| {
1004+
let ids = Arc::new(UInt64Array::from_iter_values(
1005+
next_id..next_id + rows_per_fragment as u64,
1006+
));
1007+
next_id += rows_per_fragment as u64;
1008+
1009+
let mut values = Vec::with_capacity(rows_per_fragment * DIM);
1010+
for _ in 0..rows_per_fragment {
1011+
for dim in 0..DIM {
1012+
values.push(*offset + dim as f32);
1013+
}
1014+
}
1015+
1016+
let vectors = Arc::new(
1017+
FixedSizeListArray::try_new_from_values(Float32Array::from(values), DIM as i32)
1018+
.unwrap(),
1019+
);
1020+
RecordBatch::try_new(schema.clone(), vec![ids, vectors]).unwrap()
1021+
})
1022+
.collect();
1023+
1024+
(schema, batches)
1025+
}
1026+
9511027
struct VectorIndexTestContext {
9521028
stats_json: String,
9531029
stats: serde_json::Value,
@@ -2164,6 +2240,59 @@ mod tests {
21642240
.unwrap();
21652241
assert!(result.num_rows() > 0);
21662242
}
2243+
2244+
#[tokio::test]
2245+
async fn test_distributed_ivf_sq_worker_training_respects_fragment_filter() {
2246+
const ROWS_PER_FRAGMENT: usize = 64;
2247+
const FRAGMENT_OFFSETS: [f32; 2] = [0.0, 1000.0];
2248+
2249+
let test_dir = TempStrDir::default();
2250+
let dataset_uri = format!("{}/distributed_sq_fragment_filter", test_dir.as_str());
2251+
let (schema, batches) = make_fragment_offset_batches(ROWS_PER_FRAGMENT, &FRAGMENT_OFFSETS);
2252+
let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
2253+
let mut dataset = Dataset::write(
2254+
batches,
2255+
&dataset_uri,
2256+
Some(WriteParams {
2257+
max_rows_per_file: ROWS_PER_FRAGMENT,
2258+
mode: WriteMode::Overwrite,
2259+
..Default::default()
2260+
}),
2261+
)
2262+
.await
2263+
.unwrap();
2264+
2265+
let fragments = dataset.get_fragments();
2266+
assert_eq!(fragments.len(), FRAGMENT_OFFSETS.len());
2267+
2268+
let ivf_params =
2269+
IvfBuildParams::try_with_centroids(2, build_centroids_for_offsets(&FRAGMENT_OFFSETS))
2270+
.unwrap();
2271+
let params = VectorIndexParams::with_ivf_sq_params(
2272+
DistanceType::L2,
2273+
ivf_params,
2274+
SQBuildParams::default(),
2275+
);
2276+
2277+
let segment = dataset
2278+
.create_index_builder(&["vector"], IndexType::Vector, &params)
2279+
.name("sq_fragment_filter".to_string())
2280+
.fragments(vec![fragments[0].id() as u32])
2281+
.execute_uncommitted()
2282+
.await
2283+
.unwrap();
2284+
2285+
let scheduler = ScanScheduler::new(
2286+
Arc::new(dataset.object_store().clone()),
2287+
SchedulerConfig::default_for_testing(),
2288+
);
2289+
let sq_meta = get_sq_metadata(&dataset, scheduler, &segment.uuid.to_string()).await;
2290+
2291+
assert_eq!(sq_meta.bounds.start, 0.0);
2292+
assert_eq!(sq_meta.bounds.end, (DIM - 1) as f64);
2293+
assert_lt!(sq_meta.bounds.end, FRAGMENT_OFFSETS[1] as f64);
2294+
}
2295+
21672296
async fn test_index(
21682297
params: VectorIndexParams,
21692298
nlist: usize,

0 commit comments

Comments (0)