Skip to content

Commit 4ca96c1

Browse files
authored
Introduce a vectorized soarDistance function (#129744)
This commit replaces the method #soarResidual with the method #soarDistance, which performs better for computing soar distances.
1 parent 1d913f3 commit 4ca96c1

File tree

6 files changed

+27
-21
lines changed

6 files changed

+27
-21
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,20 +237,21 @@ public static void subtract(float[] v1, float[] v2, float[] result) {
237237
}
238238

239239
/**
240-
* calculates the spill-over score for a vector and a centroid, given its residual with
241-
* its actually nearest centroid
240+
* calculates the soar distance for a vector and a centroid
242241
* @param v1 the vector
243242
* @param centroid the centroid
244243
* @param originalResidual the residual with the actually nearest centroid
245-
* @return the spill-over score (soar)
244+
* @param soarLambda the lambda parameter
245+
* @param rnorm distance to the nearest centroid
246+
* @return the soar distance
246247
*/
247-
public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
248+
public static float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
248249
if (v1.length != centroid.length) {
249250
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
250251
}
251252
if (originalResidual.length != v1.length) {
252253
throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
253254
}
254-
return IMPL.soarResidual(v1, centroid, originalResidual);
255+
return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
255256
}
256257
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.apache.lucene.util.BitUtil;
1313
import org.apache.lucene.util.Constants;
14+
import org.apache.lucene.util.VectorUtil;
1415

1516
final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
1617

@@ -139,15 +140,16 @@ public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float
139140
}
140141

141142
@Override
142-
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
143+
public float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
143144
assert v1.length == centroid.length;
144145
assert v1.length == originalResidual.length;
146+
float dsq = VectorUtil.squareDistance(v1, centroid);
145147
float proj = 0;
146148
for (int i = 0; i < v1.length; i++) {
147149
float djk = v1[i] - centroid[i];
148150
proj = fma(djk, originalResidual[i], proj);
149151
}
150-
return proj;
152+
return dsq + soarLambda * proj * proj / rnorm;
151153
}
152154

153155
public static int ipByteBitImpl(byte[] q, byte[] d) {

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ public interface ESVectorUtilSupport {
3737

3838
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
3939

40-
float soarResidual(float[] v1, float[] centroid, float[] originalResidual);
40+
float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);
4141

4242
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,14 +368,17 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
368368
}
369369

370370
@Override
371-
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
371+
public float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
372372
assert v1.length == centroid.length;
373373
assert v1.length == originalResidual.length;
374374
float proj = 0;
375+
float dsq = 0;
375376
int i = 0;
376377
if (v1.length > 2 * FLOAT_SPECIES.length()) {
377378
FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
378379
FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
380+
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
381+
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
379382
int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length();
380383
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
381384
// one
@@ -384,13 +387,15 @@ public float soarResidual(float[] v1, float[] centroid, float[] originalResidual
384387
FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
385388
FloatVector djkVec0 = v1Vec0.sub(centroidVec0);
386389
projVec1 = fma(djkVec0, originalResidualVec0, projVec1);
390+
acc1 = fma(djkVec0, djkVec0, acc1);
387391

388392
// two
389393
FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length());
390394
FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length());
391395
FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length());
392396
FloatVector djkVec1 = v1Vec1.sub(centroidVec1);
393397
projVec2 = fma(djkVec1, originalResidualVec1, projVec2);
398+
acc2 = fma(djkVec1, djkVec1, acc2);
394399
}
395400
// vector tail
396401
for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) {
@@ -399,15 +404,18 @@ public float soarResidual(float[] v1, float[] centroid, float[] originalResidual
399404
FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
400405
FloatVector djkVec = v1Vec.sub(centroidVec);
401406
projVec1 = fma(djkVec, originalResidualVec, projVec1);
407+
acc1 = fma(djkVec, djkVec, acc1);
402408
}
403409
proj += projVec1.add(projVec2).reduceLanes(ADD);
410+
dsq += acc1.add(acc2).reduceLanes(ADD);
404411
}
405412
// tail
406413
for (; i < v1.length; i++) {
407414
float djk = v1[i] - centroid[i];
408415
proj = fma(djk, originalResidual[i], proj);
416+
dsq = fma(djk, djk, dsq);
409417
}
410-
return proj;
418+
return dsq + soarLambda * proj * proj / rnorm;
411419
}
412420

413421
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ public void testOsqGridPoints() {
268268
}
269269
}
270270

271-
public void testSoarOverspillScore() {
271+
public void testSoarDistance() {
272272
int size = random().nextInt(128, 512);
273273
float deltaEps = 1e-5f * size;
274274
var vector = new float[size];
@@ -279,8 +279,10 @@ public void testSoarOverspillScore() {
279279
centroid[i] = random().nextFloat();
280280
preResidual[i] = random().nextFloat();
281281
}
282-
var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
283-
var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
282+
float soarLambda = random().nextFloat();
283+
float rnorm = random().nextFloat();
284+
var expected = defaultedProvider.getVectorUtilSupport().soarDistance(vector, centroid, preResidual, soarLambda, rnorm);
285+
var result = defOrPanamaProvider.getVectorUtilSupport().soarDistance(vector, centroid, preResidual, soarLambda, rnorm);
284286
assertEquals(expected, result, deltaEps);
285287
}
286288

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
202202
continue;
203203
}
204204
float[] neighborCentroid = centroids[neighbor];
205-
float soar = distanceSoar(diffs, vector, neighborCentroid, vectorCentroidDist);
205+
float soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
206206
if (soar < minSoar) {
207207
bestAssignment = neighbor;
208208
minSoar = soar;
@@ -215,13 +215,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
215215
return spilledAssignments;
216216
}
217217

218-
private float distanceSoar(float[] residual, float[] vector, float[] centroid, float rnorm) {
219-
// TODO: combine these to be more efficient
220-
float dsq = VectorUtil.squareDistance(vector, centroid);
221-
float rproj = ESVectorUtil.soarResidual(vector, centroid, residual);
222-
return dsq + soarLambda * rproj * rproj / rnorm;
223-
}
224-
225218
/**
226219
* cluster using a lloyd k-means algorithm that is not neighbor aware
227220
*

0 commit comments

Comments
 (0)