Skip to content

Commit f03de5b

Browse files
committed
[GR-4353] Improve stamps for floating point arithmetic
PullRequest: graal/18438
2 parents 117a91e + d767298 commit f03de5b

File tree

13 files changed

+910
-301
lines changed

13 files changed

+910
-301
lines changed

compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/core/test/MathCopySignStampTest.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -70,13 +70,13 @@ public static float floatCopySign(float magnitude, float sign) {
7070
@Test
7171
public void testFloatCopySign() throws InvalidInstalledCodeException {
7272
for (float f1 : floatValues) {
73-
FloatStamp s1 = new FloatStamp(32, f1, f1, !Float.isNaN(f1));
73+
FloatStamp s1 = FloatStamp.create(32, f1, f1, !Float.isNaN(f1));
7474
for (float f2 : floatValues) {
75-
FloatStamp s2 = new FloatStamp(32, f2, f2, !Float.isNaN(f2));
75+
FloatStamp s2 = FloatStamp.create(32, f2, f2, !Float.isNaN(f2));
7676
for (float f3 : floatValues) {
77-
FloatStamp s3 = new FloatStamp(32, f3, f3, !Float.isNaN(f3));
77+
FloatStamp s3 = FloatStamp.create(32, f3, f3, !Float.isNaN(f3));
7878
for (float f4 : floatValues) {
79-
FloatStamp s4 = new FloatStamp(32, f4, f4, !Float.isNaN(f4));
79+
FloatStamp s4 = FloatStamp.create(32, f4, f4, !Float.isNaN(f4));
8080
stampsToBind = new Stamp[]{s1.meet(s2), s3.meet(s4)};
8181
InstalledCode code = getCode(getResolvedJavaMethod("floatCopySign"), null, true);
8282
Assert.assertEquals(floatCopySign(f1, f3), (float) code.executeVarargs(f1, f3), 0);
@@ -125,13 +125,13 @@ public static double doubleCopySign(double magnitude, double sign) {
125125
@Test
126126
public void testDoubleCopySign() throws InvalidInstalledCodeException {
127127
for (double d1 : doubleValues) {
128-
FloatStamp s1 = new FloatStamp(64, d1, d1, !Double.isNaN(d1));
128+
FloatStamp s1 = FloatStamp.create(64, d1, d1, !Double.isNaN(d1));
129129
for (double d2 : doubleValues) {
130-
FloatStamp s2 = new FloatStamp(64, d2, d2, !Double.isNaN(d2));
130+
FloatStamp s2 = FloatStamp.create(64, d2, d2, !Double.isNaN(d2));
131131
for (double d3 : doubleValues) {
132-
FloatStamp s3 = new FloatStamp(64, d3, d3, !Double.isNaN(d3));
132+
FloatStamp s3 = FloatStamp.create(64, d3, d3, !Double.isNaN(d3));
133133
for (double d4 : doubleValues) {
134-
FloatStamp s4 = new FloatStamp(64, d4, d4, !Double.isNaN(d4));
134+
FloatStamp s4 = FloatStamp.create(64, d4, d4, !Double.isNaN(d4));
135135
stampsToBind = new Stamp[]{s1.meet(s2), s3.meet(s4)};
136136
InstalledCode code = getCode(getResolvedJavaMethod("doubleCopySign"), null, true);
137137
Assert.assertEquals(doubleCopySign(d1, d3), (double) code.executeVarargs(d1, d3), 0);

compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/core/test/MathSignumStampTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -68,9 +68,9 @@ public static float floatSignum(float f) {
6868
@Test
6969
public void testFloatSignum() throws InvalidInstalledCodeException {
7070
for (float f1 : floatValues) {
71-
FloatStamp s1 = new FloatStamp(32, f1, f1, !Float.isNaN(f1));
71+
FloatStamp s1 = FloatStamp.create(32, f1, f1, !Float.isNaN(f1));
7272
for (float f2 : floatValues) {
73-
FloatStamp s2 = new FloatStamp(32, f2, f2, !Float.isNaN(f2));
73+
FloatStamp s2 = FloatStamp.create(32, f2, f2, !Float.isNaN(f2));
7474
stampsToBind = new Stamp[]{s1.meet(s2)};
7575
InstalledCode code = getCode(getResolvedJavaMethod("floatSignum"), null, true);
7676
Assert.assertEquals(floatSignum(f1), (float) code.executeVarargs(f1), 0);
@@ -97,9 +97,9 @@ public static double doubleSignum(double d) {
9797
@Test
9898
public void testDoubleSignum() throws InvalidInstalledCodeException {
9999
for (double d1 : doubleValues) {
100-
FloatStamp s1 = new FloatStamp(64, d1, d1, !Double.isNaN(d1));
100+
FloatStamp s1 = FloatStamp.create(64, d1, d1, !Double.isNaN(d1));
101101
for (double d2 : doubleValues) {
102-
FloatStamp s2 = new FloatStamp(64, d2, d2, !Double.isNaN(d2));
102+
FloatStamp s2 = FloatStamp.create(64, d2, d2, !Double.isNaN(d2));
103103
stampsToBind = new Stamp[]{s1.meet(s2)};
104104
InstalledCode code = getCode(getResolvedJavaMethod("doubleSignum"), null, true);
105105
Assert.assertEquals(doubleSignum(d1), (double) code.executeVarargs(d1), 0);

compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/nodes/test/FloatStampTest.java

Lines changed: 231 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -28,18 +28,26 @@
2828
import static org.junit.Assert.assertFalse;
2929
import static org.junit.Assert.assertTrue;
3030

31+
import java.util.ArrayList;
32+
import java.util.HashSet;
33+
import java.util.Random;
34+
import java.util.function.BiFunction;
35+
import java.util.function.DoubleBinaryOperator;
36+
import java.util.function.DoubleUnaryOperator;
37+
import java.util.function.Function;
38+
3139
import jdk.graal.compiler.core.common.type.ArithmeticOpTable;
3240
import jdk.graal.compiler.core.common.type.FloatStamp;
3341
import jdk.graal.compiler.core.common.type.Stamp;
3442
import jdk.graal.compiler.core.common.type.StampFactory;
43+
import jdk.graal.compiler.core.test.GraalCompilerTest;
3544
import jdk.graal.compiler.graph.test.GraphTest;
3645
import jdk.graal.compiler.nodes.ConstantNode;
3746
import jdk.graal.compiler.nodes.NodeView;
47+
import jdk.vm.ci.meta.JavaKind;
3848
import org.junit.Assert;
3949
import org.junit.Test;
4050

41-
import jdk.vm.ci.meta.JavaKind;
42-
4351
/**
4452
* This class tests that float stamps are created correctly for constants.
4553
*/
@@ -53,9 +61,9 @@ public class FloatStampTest extends GraphTest {
5361

5462
private static FloatStamp createFloatStamp(int bits, double value) {
5563
if (Double.isNaN(value)) {
56-
return new FloatStamp(bits, Double.NaN, Double.NaN, false);
64+
return FloatStamp.createNaN(bits);
5765
}
58-
return new FloatStamp(bits, value, value, true);
66+
return FloatStamp.create(bits, value, value, true);
5967
}
6068

6169
@Test
@@ -152,8 +160,10 @@ public void testMeetJoin() {
152160

153161
@Test
154162
public void testIllegalJoin() {
155-
assertFalse(new FloatStamp(32, 0, Float.POSITIVE_INFINITY, true).join(new FloatStamp(32, Float.NEGATIVE_INFINITY, -Float.MIN_VALUE, true)).hasValues());
156-
assertFalse(new FloatStamp(32, Float.NaN, Float.NaN, false).join(new FloatStamp(32, 0, 0, true)).hasValues());
163+
assertFalse(FloatStamp.create(32, 0, Float.POSITIVE_INFINITY, true).join(FloatStamp.create(32, Float.NEGATIVE_INFINITY, -Float.MIN_VALUE, true)).hasValues());
164+
assertFalse(FloatStamp.create(32, Float.NaN, Float.NaN, false).join(FloatStamp.create(32, 0, 0, true)).hasValues());
165+
assertTrue(((FloatStamp) FloatStamp.create(32, 0, Float.POSITIVE_INFINITY, false).join(FloatStamp.create(32, Float.NEGATIVE_INFINITY, -Float.MIN_VALUE, false))).isNaN());
166+
assertTrue(((FloatStamp) FloatStamp.create(32, Float.NaN, Float.NaN, false).join(FloatStamp.create(32, 0, 0, false))).isNaN());
157167
}
158168

159169
@Test
@@ -195,4 +205,218 @@ public void testUnaryOpFoldEmpty() {
195205
}
196206
}
197207
}
208+
209+
@Test
210+
public void testFoldStamp() {
211+
runFoldStamp(32);
212+
runFoldStamp(64);
213+
}
214+
215+
static void runFoldStamp(int bits) {
216+
Random random = GraalCompilerTest.getRandomInstance();
217+
ArrayList<FloatStamp> stamps = generateStamps(bits, random);
218+
verify(bits, random, stamps);
219+
}
220+
221+
private static ArrayList<FloatStamp> generateStamps(int bits, Random random) {
222+
double[] specialValues;
223+
if (bits == Float.SIZE) {
224+
specialValues = new double[floatNonNaNs.length];
225+
for (int i = 0; i < floatNonNaNs.length; i++) {
226+
specialValues[i] = floatNonNaNs[i];
227+
}
228+
} else {
229+
specialValues = doubleNonNaNs;
230+
}
231+
ArrayList<FloatStamp> stamps = new ArrayList<>();
232+
FloatStamp nan = FloatStamp.createNaN(bits);
233+
stamps.add(nan);
234+
stamps.add(nan.empty());
235+
for (int i = 0; i < specialValues.length; i++) {
236+
double currentValue = specialValues[i];
237+
for (int j = i; j < specialValues.length; j++) {
238+
double otherValue = specialValues[i];
239+
if (Double.compare(currentValue, otherValue) > 0) {
240+
stamps.add(FloatStamp.create(bits, otherValue, currentValue, true));
241+
stamps.add(FloatStamp.create(bits, otherValue, currentValue, false));
242+
} else {
243+
stamps.add(FloatStamp.create(bits, currentValue, otherValue, true));
244+
stamps.add(FloatStamp.create(bits, currentValue, otherValue, false));
245+
}
246+
}
247+
248+
for (int j = 0; j < 10; j++) {
249+
double otherBound;
250+
if (bits == Float.SIZE) {
251+
otherBound = Float.intBitsToFloat(random.nextInt());
252+
} else {
253+
otherBound = Double.longBitsToDouble(random.nextLong());
254+
}
255+
if (Double.isNaN(otherBound)) {
256+
continue;
257+
}
258+
259+
if (Double.compare(currentValue, otherBound) < 0) {
260+
stamps.add(FloatStamp.create(bits, currentValue, otherBound, true));
261+
stamps.add(FloatStamp.create(bits, currentValue, otherBound, false));
262+
} else {
263+
stamps.add(FloatStamp.create(bits, otherBound, currentValue, true));
264+
stamps.add(FloatStamp.create(bits, otherBound, currentValue, false));
265+
}
266+
}
267+
}
268+
269+
for (int i = 0; i < 10; i++) {
270+
double first;
271+
double second;
272+
if (bits == Float.SIZE) {
273+
first = Float.intBitsToFloat(random.nextInt());
274+
second = Float.intBitsToFloat(random.nextInt());
275+
} else {
276+
first = Double.longBitsToDouble(random.nextLong());
277+
second = Double.longBitsToDouble(random.nextLong());
278+
}
279+
if (Double.isNaN(first) || Double.isNaN(second)) {
280+
continue;
281+
}
282+
283+
if (Double.compare(first, second) > 0) {
284+
double temp = first;
285+
first = second;
286+
second = temp;
287+
}
288+
289+
stamps.add(FloatStamp.create(bits, first, first, true));
290+
stamps.add(FloatStamp.create(bits, first, first, false));
291+
stamps.add(FloatStamp.create(bits, first, second, true));
292+
stamps.add(FloatStamp.create(bits, first, second, false));
293+
}
294+
295+
return stamps;
296+
}
297+
298+
private static HashSet<Double> sample(Random random, FloatStamp stamp) {
299+
HashSet<Double> samples = HashSet.newHashSet(20);
300+
if (stamp.isEmpty()) {
301+
return samples;
302+
}
303+
304+
if (!stamp.isNonNaN()) {
305+
samples.add(Double.NaN);
306+
if (stamp.isNaN()) {
307+
return samples;
308+
}
309+
}
310+
311+
samples.add(stamp.lowerBound());
312+
samples.add(stamp.upperBound());
313+
if (stamp.lowerBound() == stamp.upperBound()) {
314+
return samples;
315+
}
316+
317+
double neighbor = stamp.getBits() == Float.SIZE ? Math.nextUp((float) stamp.lowerBound()) : Math.nextUp(stamp.lowerBound());
318+
samples.add(neighbor);
319+
neighbor = stamp.getBits() == Float.SIZE ? Math.nextDown((float) stamp.upperBound()) : Math.nextDown(stamp.upperBound());
320+
samples.add(neighbor);
321+
322+
if (stamp.getBits() == Float.SIZE) {
323+
for (double d : floatNonNaNs) {
324+
if (stamp.contains(d)) {
325+
samples.add(d);
326+
}
327+
}
328+
} else {
329+
for (double d : doubleNonNaNs) {
330+
if (stamp.contains(d)) {
331+
samples.add(d);
332+
}
333+
}
334+
}
335+
336+
double lowerBound = stamp.lowerBound();
337+
double upperBound = stamp.upperBound();
338+
if (lowerBound == Double.NEGATIVE_INFINITY) {
339+
lowerBound = stamp.getBits() == Float.SIZE ? -Float.MAX_VALUE : -Double.MAX_VALUE;
340+
}
341+
if (upperBound == Double.POSITIVE_INFINITY) {
342+
upperBound = stamp.getBits() == Float.SIZE ? Float.MAX_VALUE : Double.MAX_VALUE;
343+
}
344+
for (int i = 0; i < 10; i++) {
345+
double current;
346+
if (stamp.getBits() == Float.SIZE) {
347+
current = random.nextFloat((float) lowerBound, (float) upperBound);
348+
} else {
349+
current = random.nextDouble(lowerBound, upperBound);
350+
}
351+
samples.add(current);
352+
}
353+
return samples;
354+
}
355+
356+
private static void verify(int bits, Random random, ArrayList<FloatStamp> stamps) {
357+
ArrayList<double[]> samples = new ArrayList<>(stamps.size());
358+
for (FloatStamp stamp : stamps) {
359+
HashSet<Double> sampleSet = sample(random, stamp);
360+
double[] sampleArray = new double[sampleSet.size()];
361+
int i = 0;
362+
for (double d : sampleSet) {
363+
sampleArray[i] = d;
364+
i++;
365+
}
366+
samples.add(sampleArray);
367+
}
368+
369+
for (int i = 0; i < stamps.size(); i++) {
370+
FloatStamp stamp = stamps.get(i);
371+
double[] sample = samples.get(i);
372+
verifyUnary(stamp, sample, FloatStamp.OPS.getAbs()::foldStamp, Math::abs);
373+
verifyUnary(stamp, sample, FloatStamp.OPS.getNeg()::foldStamp, x -> -x);
374+
verifyUnary(stamp, sample, FloatStamp.OPS.getSqrt()::foldStamp, bits == Float.SIZE ? x -> (float) Math.sqrt(x) : Math::sqrt);
375+
}
376+
377+
for (int i = 0; i < stamps.size(); i++) {
378+
FloatStamp stamp1 = stamps.get(i);
379+
double[] sample1 = samples.get(i);
380+
for (int j = i; j < stamps.size(); j++) {
381+
FloatStamp stamp2 = stamps.get(j);
382+
double[] sample2 = samples.get(j);
383+
verifyBinary(stamp1, stamp2, sample1, sample2, FloatStamp.OPS.getAdd()::foldStamp, bits == Float.SIZE ? (x, y) -> (float) x + (float) y : Double::sum);
384+
verifyBinary(stamp1, stamp2, sample1, sample2, FloatStamp.OPS.getDiv()::foldStamp, bits == Float.SIZE ? (x, y) -> (float) x / (float) y : (x, y) -> x / y);
385+
verifyBinary(stamp1, stamp2, sample1, sample2, FloatStamp.OPS.getMax()::foldStamp, Math::max);
386+
verifyBinary(stamp1, stamp2, sample1, sample2, FloatStamp.OPS.getMin()::foldStamp, Math::min);
387+
verifyBinary(stamp1, stamp2, sample1, sample2, FloatStamp.OPS.getMul()::foldStamp, bits == Float.SIZE ? (x, y) -> (float) x * (float) y : (x, y) -> x * y);
388+
verifyBinary(stamp1, stamp2, sample1, sample2, FloatStamp.OPS.getSub()::foldStamp, bits == Float.SIZE ? (x, y) -> (float) x - (float) y : (x, y) -> x - y);
389+
}
390+
}
391+
}
392+
393+
private static void verifyUnary(FloatStamp stamp, double[] samples, Function<FloatStamp, Stamp> compute, DoubleUnaryOperator op) {
394+
FloatStamp res = (FloatStamp) compute.apply(stamp);
395+
if (stamp.isEmpty()) {
396+
assertTrue(res.isEmpty());
397+
return;
398+
}
399+
400+
for (double x : samples) {
401+
double y = op.applyAsDouble(x);
402+
assertTrue(stamp.getBits() == Double.SIZE || (Double.compare((float) x, x) == 0 && Double.compare((float) y, y) == 0));
403+
assertTrue(res.contains(y));
404+
}
405+
}
406+
407+
private static void verifyBinary(FloatStamp stamp1, FloatStamp stamp2, double[] sample1, double[] sample2, BiFunction<FloatStamp, FloatStamp, Stamp> compute, DoubleBinaryOperator op) {
408+
FloatStamp res = (FloatStamp) compute.apply(stamp1, stamp2);
409+
if (stamp1.isEmpty() || stamp2.isEmpty()) {
410+
assertTrue(res.isEmpty());
411+
return;
412+
}
413+
414+
for (double x1 : sample1) {
415+
for (double x2 : sample2) {
416+
double y = op.applyAsDouble(x1, x2);
417+
assertTrue(stamp1.getBits() == Double.SIZE || (Double.compare((float) x1, x1) == 0 && Double.compare((float) x2, x2) == 0 && Double.compare((float) y, y) == 0));
418+
assertTrue(res.contains(y));
419+
}
420+
}
421+
}
198422
}

0 commit comments

Comments
 (0)