Skip to content

Commit db28cf9

Browse files
committed
Added preconditioning to help with edge cases where quantization causes vectors to become indistinguishable
1 parent 6dd1949 commit db28cf9

File tree

17 files changed

+467
-376
lines changed

17 files changed

+467
-376
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.test.knn;
11+
12+
import org.apache.lucene.store.Directory;
13+
import org.apache.lucene.store.FSDirectory;
14+
import org.apache.lucene.store.IOContext;
15+
import org.apache.lucene.store.MMapDirectory;
16+
import org.apache.lucene.store.NativeFSLockFactory;
17+
import org.apache.lucene.store.ReadAdvice;
18+
import org.elasticsearch.index.StandardIOBehaviorHint;
19+
import org.elasticsearch.index.store.FsDirectoryFactory;
20+
21+
import java.io.IOException;
22+
import java.nio.file.Path;
23+
import java.util.Optional;
24+
import java.util.function.BiFunction;
25+
26+
public class KnnFileUtils {
27+
28+
static Directory getDirectory(Path indexPath) throws IOException {
29+
Directory dir = FSDirectory.open(indexPath);
30+
if (dir instanceof MMapDirectory mmapDir) {
31+
mmapDir.setReadAdvice(getReadAdviceFunc()); // enable madvise
32+
return new FsDirectoryFactory.HybridDirectory(NativeFSLockFactory.INSTANCE, mmapDir, 64);
33+
}
34+
return dir;
35+
}
36+
37+
private static BiFunction<String, IOContext, Optional<ReadAdvice>> getReadAdviceFunc() {
38+
return (name, context) -> {
39+
if (context.hints().contains(StandardIOBehaviorHint.INSTANCE) || name.endsWith(".cfs")) {
40+
return Optional.of(ReadAdvice.NORMAL);
41+
}
42+
return MMapDirectory.ADVISE_BY_CONTEXT.apply(name, context);
43+
};
44+
}
45+
}

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.apache.lucene.store.FSDirectory;
3030
import org.elasticsearch.cli.ProcessInfo;
3131
import org.elasticsearch.common.Strings;
32+
import org.elasticsearch.common.io.Channels;
3233
import org.elasticsearch.common.logging.LogConfigurator;
3334
import org.elasticsearch.common.settings.Settings;
3435
import org.elasticsearch.core.PathUtils;
@@ -52,6 +53,9 @@
5253
import java.io.InputStream;
5354
import java.io.UncheckedIOException;
5455
import java.lang.management.ThreadInfo;
56+
import java.nio.ByteBuffer;
57+
import java.nio.ByteOrder;
58+
import java.nio.channels.FileChannel;
5559
import java.nio.file.Files;
5660
import java.nio.file.Path;
5761
import java.util.ArrayList;
@@ -116,7 +120,7 @@ private static String formatIndexPath(CmdLineArgs args) {
116120
return INDEX_DIR + "/" + args.docVectors().get(0).getFileName() + "-" + String.join("-", suffix) + ".index";
117121
}
118122

119-
static Codec createCodec(CmdLineArgs args) {
123+
static Codec createCodec(CmdLineArgs args, List<Path> docsPaths) throws IOException {
120124
final KnnVectorsFormat format;
121125
int quantizeBits = args.quantizeBits();
122126
if (args.indexType() == IndexType.IVF) {
@@ -128,8 +132,29 @@ static Codec createCodec(CmdLineArgs args) {
128132
"IVF index type only supports 1, 2 or 4 bits quantization, but got: " + quantizeBits
129133
);
130134
};
135+
136+
// get the actual dims
137+
int dims = args.dimensions();
138+
if (dims == -1) {
139+
Path docsPath = docsPaths.getFirst();
140+
try (FileChannel in = FileChannel.open(docsPath)) {
141+
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
142+
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
143+
if (bytesRead < 4) {
144+
long docsPathSizeInBytes = in.size();
145+
throw new IllegalArgumentException(
146+
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
147+
);
148+
}
149+
dims = preamble.getInt(0);
150+
if (dims <= 0) {
151+
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dims);
152+
}
153+
}
154+
}
155+
131156
format = new ESNextDiskBBQVectorsFormat(
132-
args.dimensions(),
157+
dims,
133158
encoding,
134159
args.ivfClusterSize(),
135160
ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER,
@@ -265,7 +290,7 @@ public static void main(String[] args) throws Exception {
265290
}
266291
logger.info("Running with Java: " + Runtime.version());
267292
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
268-
Codec codec = createCodec(cmdLineArgs);
293+
Codec codec = createCodec(cmdLineArgs, cmdLineArgs.docVectors());
269294
Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
270295
MergePolicy mergePolicy = getMergePolicy(cmdLineArgs);
271296
if (cmdLineArgs.reindex() || cmdLineArgs.forceMerge()) {

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,8 @@
3434
import org.apache.lucene.index.VectorEncoding;
3535
import org.apache.lucene.index.VectorSimilarityFunction;
3636
import org.apache.lucene.store.Directory;
37-
import org.apache.lucene.store.FSDirectory;
38-
import org.apache.lucene.store.IOContext;
39-
import org.apache.lucene.store.MMapDirectory;
40-
import org.apache.lucene.store.NativeFSLockFactory;
41-
import org.apache.lucene.store.ReadAdvice;
4237
import org.apache.lucene.util.PrintStreamInfoStream;
4338
import org.elasticsearch.common.io.Channels;
44-
import org.elasticsearch.index.StandardIOBehaviorHint;
45-
import org.elasticsearch.index.store.FsDirectoryFactory;
4639

4740
import java.io.IOException;
4841
import java.io.UncheckedIOException;
@@ -54,15 +47,14 @@
5447
import java.util.ArrayList;
5548
import java.util.List;
5649
import java.util.Objects;
57-
import java.util.Optional;
5850
import java.util.concurrent.ExecutionException;
5951
import java.util.concurrent.ExecutorService;
6052
import java.util.concurrent.Executors;
6153
import java.util.concurrent.Future;
6254
import java.util.concurrent.TimeUnit;
6355
import java.util.concurrent.atomic.AtomicInteger;
64-
import java.util.function.BiFunction;
6556

57+
import static org.elasticsearch.test.knn.KnnFileUtils.getDirectory;
6658
import static org.elasticsearch.test.knn.KnnIndexTester.logger;
6759

6860
class KnnIndexer {
@@ -243,24 +235,6 @@ public boolean isEnabled(String component) {
243235
results.forceMergeTimeMS = TimeUnit.NANOSECONDS.toMillis(elapsedNSec);
244236
}
245237

246-
static Directory getDirectory(Path indexPath) throws IOException {
247-
Directory dir = FSDirectory.open(indexPath);
248-
if (dir instanceof MMapDirectory mmapDir) {
249-
mmapDir.setReadAdvice(getReadAdviceFunc()); // enable madvise
250-
return new FsDirectoryFactory.HybridDirectory(NativeFSLockFactory.INSTANCE, mmapDir, 64);
251-
}
252-
return dir;
253-
}
254-
255-
private static BiFunction<String, IOContext, Optional<ReadAdvice>> getReadAdviceFunc() {
256-
return (name, context) -> {
257-
if (context.hints().contains(StandardIOBehaviorHint.INSTANCE) || name.endsWith(".cfs")) {
258-
return Optional.of(ReadAdvice.NORMAL);
259-
}
260-
return MMapDirectory.ADVISE_BY_CONTEXT.apply(name, context);
261-
};
262-
}
263-
264238
static class IndexerThread extends Thread {
265239
private final IndexWriter iw;
266240
private final AtomicInteger numDocsIndexed;

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
import java.util.function.IntConsumer;
9191

9292
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
93+
import static org.elasticsearch.test.knn.KnnFileUtils.getDirectory;
9394
import static org.elasticsearch.test.knn.KnnIndexTester.logger;
9495
import static org.elasticsearch.test.knn.KnnIndexer.ID_FIELD;
9596
import static org.elasticsearch.test.knn.KnnIndexer.VECTOR_FIELD;
@@ -180,7 +181,7 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
180181
);
181182
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
182183
long startNS;
183-
try (Directory dir = KnnIndexer.getDirectory(indexPath)) {
184+
try (Directory dir = getDirectory(indexPath)) {
184185
try (DirectoryReader reader = DirectoryReader.open(dir)) {
185186
IndexSearcher searcher = searchThreads > 1 ? new IndexSearcher(reader, executorService) : new IndexSearcher(reader);
186187
byte[] targetBytes = new byte[dim];

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.index.codec.vectors.diskbbq;
1111

1212
import org.apache.lucene.index.FieldInfo;
13+
import org.apache.lucene.index.FieldInfos;
1314
import org.apache.lucene.index.FloatVectorValues;
1415
import org.apache.lucene.index.SegmentReadState;
1516
import org.apache.lucene.index.VectorEncoding;
@@ -27,6 +28,7 @@
2728
import org.elasticsearch.simdvec.ESVectorUtil;
2829

2930
import java.io.IOException;
31+
import java.util.List;
3032

3133
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
3234
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
@@ -46,7 +48,7 @@ public class ES920DiskBBQVectorsReader extends IVFVectorsReader {
4648
}
4749

4850
@Override
49-
public void doInitExtraFiles(SegmentReadState state, int version) throws IOException {
51+
public void doInitExtraFiles(SegmentReadState state, int version, FieldInfos fieldInfo) throws IOException {
5052
// no extra files to init
5153
}
5254

@@ -179,11 +181,22 @@ protected FieldEntry doReadField(
179181
}
180182

181183
@Override
182-
protected float[] preconditionVector(float[] vector) {
184+
protected void doAdditionalIntegrityChecks() throws IOException {
185+
// no-op
186+
}
187+
188+
@Override
189+
protected float[] preconditionVector(FieldInfo fieldInfo, float[] vector) {
183190
// no-op
184191
return vector;
185192
}
186193

194+
@Override
195+
protected List<IndexInput> getAdditionalCloseables() {
196+
// no-op
197+
return List.of();
198+
}
199+
187200
private static CentroidIterator getCentroidIteratorNoParent(
188201
FieldInfo fieldInfo,
189202
IndexInput centroids,

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
package org.elasticsearch.index.codec.vectors.diskbbq;
1111

12-
import org.apache.lucene.codecs.CodecUtil;
1312
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
1413
import org.apache.lucene.index.FieldInfo;
1514
import org.apache.lucene.index.FloatVectorValues;
@@ -22,7 +21,6 @@
2221
import org.apache.lucene.util.hnsw.IntToIntFunction;
2322
import org.apache.lucene.util.packed.PackedInts;
2423
import org.apache.lucene.util.packed.PackedLongValues;
25-
import org.elasticsearch.core.IOUtils;
2624
import org.elasticsearch.core.SuppressForbidden;
2725
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
2826
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
@@ -39,6 +37,7 @@
3937
import java.nio.ByteBuffer;
4038
import java.nio.ByteOrder;
4139
import java.util.Arrays;
40+
import java.util.List;
4241

4342
import static org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans.NO_SOAR_ASSIGNMENT;
4443

@@ -379,12 +378,24 @@ public CentroidSupplier createCentroidSupplier(
379378
return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
380379
}
381380

381+
@Override
382+
public boolean createPreconditioner() throws IOException {
383+
// no-op
384+
return false;
385+
}
386+
382387
@Override
383388
public FloatVectorValues preconditionVectors(FloatVectorValues floatVectorValues) throws IOException {
384389
// no-op
385390
return floatVectorValues;
386391
}
387392

393+
@Override
394+
public List<float[]> preconditionVectors(List<float[]> vectors) {
395+
// no-op
396+
return vectors;
397+
}
398+
388399
@Override
389400
public void writeCentroids(
390401
FieldInfo fieldInfo,

0 commit comments

Comments
 (0)