1
1
/*
2
- * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
2
+ * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
3
3
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4
4
*
5
5
* This code is free software; you can redistribute it and/or modify it
28
28
import static org .junit .Assert .assertFalse ;
29
29
import static org .junit .Assert .assertTrue ;
30
30
31
+ import java .util .ArrayList ;
32
+ import java .util .HashSet ;
33
+ import java .util .Random ;
34
+ import java .util .function .BiFunction ;
35
+ import java .util .function .DoubleBinaryOperator ;
36
+ import java .util .function .DoubleUnaryOperator ;
37
+ import java .util .function .Function ;
38
+
31
39
import jdk .graal .compiler .core .common .type .ArithmeticOpTable ;
32
40
import jdk .graal .compiler .core .common .type .FloatStamp ;
33
41
import jdk .graal .compiler .core .common .type .Stamp ;
34
42
import jdk .graal .compiler .core .common .type .StampFactory ;
43
+ import jdk .graal .compiler .core .test .GraalCompilerTest ;
35
44
import jdk .graal .compiler .graph .test .GraphTest ;
36
45
import jdk .graal .compiler .nodes .ConstantNode ;
37
46
import jdk .graal .compiler .nodes .NodeView ;
47
+ import jdk .vm .ci .meta .JavaKind ;
38
48
import org .junit .Assert ;
39
49
import org .junit .Test ;
40
50
41
- import jdk .vm .ci .meta .JavaKind ;
42
-
43
51
/**
44
52
* This class tests that float stamps are created correctly for constants.
45
53
*/
@@ -53,9 +61,9 @@ public class FloatStampTest extends GraphTest {
53
61
54
62
private static FloatStamp createFloatStamp (int bits , double value ) {
55
63
if (Double .isNaN (value )) {
56
- return new FloatStamp (bits , Double . NaN , Double . NaN , false );
64
+ return FloatStamp . createNaN (bits );
57
65
}
58
- return new FloatStamp (bits , value , value , true );
66
+ return FloatStamp . create (bits , value , value , true );
59
67
}
60
68
61
69
@ Test
@@ -152,8 +160,10 @@ public void testMeetJoin() {
152
160
153
161
@ Test
154
162
public void testIllegalJoin () {
155
- assertFalse (new FloatStamp (32 , 0 , Float .POSITIVE_INFINITY , true ).join (new FloatStamp (32 , Float .NEGATIVE_INFINITY , -Float .MIN_VALUE , true )).hasValues ());
156
- assertFalse (new FloatStamp (32 , Float .NaN , Float .NaN , false ).join (new FloatStamp (32 , 0 , 0 , true )).hasValues ());
163
+ assertFalse (FloatStamp .create (32 , 0 , Float .POSITIVE_INFINITY , true ).join (FloatStamp .create (32 , Float .NEGATIVE_INFINITY , -Float .MIN_VALUE , true )).hasValues ());
164
+ assertFalse (FloatStamp .create (32 , Float .NaN , Float .NaN , false ).join (FloatStamp .create (32 , 0 , 0 , true )).hasValues ());
165
+ assertTrue (((FloatStamp ) FloatStamp .create (32 , 0 , Float .POSITIVE_INFINITY , false ).join (FloatStamp .create (32 , Float .NEGATIVE_INFINITY , -Float .MIN_VALUE , false ))).isNaN ());
166
+ assertTrue (((FloatStamp ) FloatStamp .create (32 , Float .NaN , Float .NaN , false ).join (FloatStamp .create (32 , 0 , 0 , false ))).isNaN ());
157
167
}
158
168
159
169
@ Test
@@ -195,4 +205,218 @@ public void testUnaryOpFoldEmpty() {
195
205
}
196
206
}
197
207
}
208
+
209
+ @ Test
210
+ public void testFoldStamp () {
211
+ runFoldStamp (32 );
212
+ runFoldStamp (64 );
213
+ }
214
+
215
+ static void runFoldStamp (int bits ) {
216
+ Random random = GraalCompilerTest .getRandomInstance ();
217
+ ArrayList <FloatStamp > stamps = generateStamps (bits , random );
218
+ verify (bits , random , stamps );
219
+ }
220
+
221
+ private static ArrayList <FloatStamp > generateStamps (int bits , Random random ) {
222
+ double [] specialValues ;
223
+ if (bits == Float .SIZE ) {
224
+ specialValues = new double [floatNonNaNs .length ];
225
+ for (int i = 0 ; i < floatNonNaNs .length ; i ++) {
226
+ specialValues [i ] = floatNonNaNs [i ];
227
+ }
228
+ } else {
229
+ specialValues = doubleNonNaNs ;
230
+ }
231
+ ArrayList <FloatStamp > stamps = new ArrayList <>();
232
+ FloatStamp nan = FloatStamp .createNaN (bits );
233
+ stamps .add (nan );
234
+ stamps .add (nan .empty ());
235
+ for (int i = 0 ; i < specialValues .length ; i ++) {
236
+ double currentValue = specialValues [i ];
237
+ for (int j = i ; j < specialValues .length ; j ++) {
238
+ double otherValue = specialValues [i ];
239
+ if (Double .compare (currentValue , otherValue ) > 0 ) {
240
+ stamps .add (FloatStamp .create (bits , otherValue , currentValue , true ));
241
+ stamps .add (FloatStamp .create (bits , otherValue , currentValue , false ));
242
+ } else {
243
+ stamps .add (FloatStamp .create (bits , currentValue , otherValue , true ));
244
+ stamps .add (FloatStamp .create (bits , currentValue , otherValue , false ));
245
+ }
246
+ }
247
+
248
+ for (int j = 0 ; j < 10 ; j ++) {
249
+ double otherBound ;
250
+ if (bits == Float .SIZE ) {
251
+ otherBound = Float .intBitsToFloat (random .nextInt ());
252
+ } else {
253
+ otherBound = Double .longBitsToDouble (random .nextLong ());
254
+ }
255
+ if (Double .isNaN (otherBound )) {
256
+ continue ;
257
+ }
258
+
259
+ if (Double .compare (currentValue , otherBound ) < 0 ) {
260
+ stamps .add (FloatStamp .create (bits , currentValue , otherBound , true ));
261
+ stamps .add (FloatStamp .create (bits , currentValue , otherBound , false ));
262
+ } else {
263
+ stamps .add (FloatStamp .create (bits , otherBound , currentValue , true ));
264
+ stamps .add (FloatStamp .create (bits , otherBound , currentValue , false ));
265
+ }
266
+ }
267
+ }
268
+
269
+ for (int i = 0 ; i < 10 ; i ++) {
270
+ double first ;
271
+ double second ;
272
+ if (bits == Float .SIZE ) {
273
+ first = Float .intBitsToFloat (random .nextInt ());
274
+ second = Float .intBitsToFloat (random .nextInt ());
275
+ } else {
276
+ first = Double .longBitsToDouble (random .nextLong ());
277
+ second = Double .longBitsToDouble (random .nextLong ());
278
+ }
279
+ if (Double .isNaN (first ) || Double .isNaN (second )) {
280
+ continue ;
281
+ }
282
+
283
+ if (Double .compare (first , second ) > 0 ) {
284
+ double temp = first ;
285
+ first = second ;
286
+ second = temp ;
287
+ }
288
+
289
+ stamps .add (FloatStamp .create (bits , first , first , true ));
290
+ stamps .add (FloatStamp .create (bits , first , first , false ));
291
+ stamps .add (FloatStamp .create (bits , first , second , true ));
292
+ stamps .add (FloatStamp .create (bits , first , second , false ));
293
+ }
294
+
295
+ return stamps ;
296
+ }
297
+
298
+ private static HashSet <Double > sample (Random random , FloatStamp stamp ) {
299
+ HashSet <Double > samples = HashSet .newHashSet (20 );
300
+ if (stamp .isEmpty ()) {
301
+ return samples ;
302
+ }
303
+
304
+ if (!stamp .isNonNaN ()) {
305
+ samples .add (Double .NaN );
306
+ if (stamp .isNaN ()) {
307
+ return samples ;
308
+ }
309
+ }
310
+
311
+ samples .add (stamp .lowerBound ());
312
+ samples .add (stamp .upperBound ());
313
+ if (stamp .lowerBound () == stamp .upperBound ()) {
314
+ return samples ;
315
+ }
316
+
317
+ double neighbor = stamp .getBits () == Float .SIZE ? Math .nextUp ((float ) stamp .lowerBound ()) : Math .nextUp (stamp .lowerBound ());
318
+ samples .add (neighbor );
319
+ neighbor = stamp .getBits () == Float .SIZE ? Math .nextDown ((float ) stamp .upperBound ()) : Math .nextDown (stamp .upperBound ());
320
+ samples .add (neighbor );
321
+
322
+ if (stamp .getBits () == Float .SIZE ) {
323
+ for (double d : floatNonNaNs ) {
324
+ if (stamp .contains (d )) {
325
+ samples .add (d );
326
+ }
327
+ }
328
+ } else {
329
+ for (double d : doubleNonNaNs ) {
330
+ if (stamp .contains (d )) {
331
+ samples .add (d );
332
+ }
333
+ }
334
+ }
335
+
336
+ double lowerBound = stamp .lowerBound ();
337
+ double upperBound = stamp .upperBound ();
338
+ if (lowerBound == Double .NEGATIVE_INFINITY ) {
339
+ lowerBound = stamp .getBits () == Float .SIZE ? -Float .MAX_VALUE : -Double .MAX_VALUE ;
340
+ }
341
+ if (upperBound == Double .POSITIVE_INFINITY ) {
342
+ upperBound = stamp .getBits () == Float .SIZE ? Float .MAX_VALUE : Double .MAX_VALUE ;
343
+ }
344
+ for (int i = 0 ; i < 10 ; i ++) {
345
+ double current ;
346
+ if (stamp .getBits () == Float .SIZE ) {
347
+ current = random .nextFloat ((float ) lowerBound , (float ) upperBound );
348
+ } else {
349
+ current = random .nextDouble (lowerBound , upperBound );
350
+ }
351
+ samples .add (current );
352
+ }
353
+ return samples ;
354
+ }
355
+
356
+ private static void verify (int bits , Random random , ArrayList <FloatStamp > stamps ) {
357
+ ArrayList <double []> samples = new ArrayList <>(stamps .size ());
358
+ for (FloatStamp stamp : stamps ) {
359
+ HashSet <Double > sampleSet = sample (random , stamp );
360
+ double [] sampleArray = new double [sampleSet .size ()];
361
+ int i = 0 ;
362
+ for (double d : sampleSet ) {
363
+ sampleArray [i ] = d ;
364
+ i ++;
365
+ }
366
+ samples .add (sampleArray );
367
+ }
368
+
369
+ for (int i = 0 ; i < stamps .size (); i ++) {
370
+ FloatStamp stamp = stamps .get (i );
371
+ double [] sample = samples .get (i );
372
+ verifyUnary (stamp , sample , FloatStamp .OPS .getAbs ()::foldStamp , Math ::abs );
373
+ verifyUnary (stamp , sample , FloatStamp .OPS .getNeg ()::foldStamp , x -> -x );
374
+ verifyUnary (stamp , sample , FloatStamp .OPS .getSqrt ()::foldStamp , bits == Float .SIZE ? x -> (float ) Math .sqrt (x ) : Math ::sqrt );
375
+ }
376
+
377
+ for (int i = 0 ; i < stamps .size (); i ++) {
378
+ FloatStamp stamp1 = stamps .get (i );
379
+ double [] sample1 = samples .get (i );
380
+ for (int j = i ; j < stamps .size (); j ++) {
381
+ FloatStamp stamp2 = stamps .get (j );
382
+ double [] sample2 = samples .get (j );
383
+ verifyBinary (stamp1 , stamp2 , sample1 , sample2 , FloatStamp .OPS .getAdd ()::foldStamp , bits == Float .SIZE ? (x , y ) -> (float ) x + (float ) y : Double ::sum );
384
+ verifyBinary (stamp1 , stamp2 , sample1 , sample2 , FloatStamp .OPS .getDiv ()::foldStamp , bits == Float .SIZE ? (x , y ) -> (float ) x / (float ) y : (x , y ) -> x / y );
385
+ verifyBinary (stamp1 , stamp2 , sample1 , sample2 , FloatStamp .OPS .getMax ()::foldStamp , Math ::max );
386
+ verifyBinary (stamp1 , stamp2 , sample1 , sample2 , FloatStamp .OPS .getMin ()::foldStamp , Math ::min );
387
+ verifyBinary (stamp1 , stamp2 , sample1 , sample2 , FloatStamp .OPS .getMul ()::foldStamp , bits == Float .SIZE ? (x , y ) -> (float ) x * (float ) y : (x , y ) -> x * y );
388
+ verifyBinary (stamp1 , stamp2 , sample1 , sample2 , FloatStamp .OPS .getSub ()::foldStamp , bits == Float .SIZE ? (x , y ) -> (float ) x - (float ) y : (x , y ) -> x - y );
389
+ }
390
+ }
391
+ }
392
+
393
+ private static void verifyUnary (FloatStamp stamp , double [] samples , Function <FloatStamp , Stamp > compute , DoubleUnaryOperator op ) {
394
+ FloatStamp res = (FloatStamp ) compute .apply (stamp );
395
+ if (stamp .isEmpty ()) {
396
+ assertTrue (res .isEmpty ());
397
+ return ;
398
+ }
399
+
400
+ for (double x : samples ) {
401
+ double y = op .applyAsDouble (x );
402
+ assertTrue (stamp .getBits () == Double .SIZE || (Double .compare ((float ) x , x ) == 0 && Double .compare ((float ) y , y ) == 0 ));
403
+ assertTrue (res .contains (y ));
404
+ }
405
+ }
406
+
407
+ private static void verifyBinary (FloatStamp stamp1 , FloatStamp stamp2 , double [] sample1 , double [] sample2 , BiFunction <FloatStamp , FloatStamp , Stamp > compute , DoubleBinaryOperator op ) {
408
+ FloatStamp res = (FloatStamp ) compute .apply (stamp1 , stamp2 );
409
+ if (stamp1 .isEmpty () || stamp2 .isEmpty ()) {
410
+ assertTrue (res .isEmpty ());
411
+ return ;
412
+ }
413
+
414
+ for (double x1 : sample1 ) {
415
+ for (double x2 : sample2 ) {
416
+ double y = op .applyAsDouble (x1 , x2 );
417
+ assertTrue (stamp1 .getBits () == Double .SIZE || (Double .compare ((float ) x1 , x1 ) == 0 && Double .compare ((float ) x2 , x2 ) == 0 && Double .compare ((float ) y , y ) == 0 ));
418
+ assertTrue (res .contains (y ));
419
+ }
420
+ }
421
+ }
198
422
}
0 commit comments