Skip to content

Commit accd9f8

Browse files
committed
[GR-41901] Enforce injected branch probability invariant
PullRequest: graal/12963
2 parents ed47352 + 04b38a3 commit accd9f8

File tree

7 files changed

+214
-6
lines changed

7 files changed

+214
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation. Oracle designates this
8+
* particular file as subject to the "Classpath" exception as provided
9+
* by Oracle in the LICENSE file that accompanied this code.
10+
*
11+
* This code is distributed in the hope that it will be useful, but WITHOUT
12+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14+
* version 2 for more details (a copy is included in the LICENSE file that
15+
* accompanied this code).
16+
*
17+
* You should have received a copy of the GNU General Public License version
18+
* 2 along with this work; if not, write to the Free Software Foundation,
19+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20+
*
21+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22+
* or visit www.oracle.com if you need additional information or have any
23+
* questions.
24+
*/
25+
package org.graalvm.compiler.api.directives.test;
26+
27+
import static org.graalvm.compiler.api.directives.GraalDirectives.injectBranchProbability;
28+
import static org.graalvm.compiler.api.directives.GraalDirectives.sideEffect;
29+
30+
import org.graalvm.compiler.core.test.GraalCompilerTest;
31+
import org.graalvm.compiler.debug.GraalError;
32+
import org.graalvm.compiler.nodes.IfNode;
33+
import org.graalvm.compiler.nodes.ProfileData.ProfileSource;
34+
import org.graalvm.compiler.nodes.ShortCircuitOrNode;
35+
import org.graalvm.compiler.nodes.StructuredGraph;
36+
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
37+
38+
import org.junit.Assert;
39+
import org.junit.Test;
40+
41+
public class ProbabilityDirectiveShortCircuitTest extends GraalCompilerTest {
42+
43+
private void checkProfiles(String snippetName) {
44+
StructuredGraph graph = parseForCompile(getResolvedJavaMethod(snippetName));
45+
CanonicalizerPhase canonicalizer = createCanonicalizerPhase();
46+
createInliningPhase(canonicalizer).apply(graph, getDefaultHighTierContext());
47+
canonicalizer.apply(graph, getDefaultHighTierContext());
48+
for (IfNode ifNode : graph.getNodes(IfNode.TYPE)) {
49+
Assert.assertEquals(ifNode + " profile source", ProfileSource.INJECTED, ifNode.profileSource());
50+
}
51+
for (ShortCircuitOrNode shortCircuit : graph.getNodes(ShortCircuitOrNode.TYPE)) {
52+
Assert.assertEquals(shortCircuit + " profile source", ProfileSource.INJECTED, shortCircuit.getShortCircuitProbability().getProfileSource());
53+
}
54+
}
55+
56+
@Test
57+
public void andIf() {
58+
checkProfiles("andIfSnippet");
59+
}
60+
61+
public static boolean andIfSnippet(boolean a, int b, double c) {
62+
if (injectBranchProbability(0.125, a) && injectBranchProbability(0.125, b == 42) && injectBranchProbability(0.125, c > 0.0)) {
63+
sideEffect(); // prevent folding to a conditional
64+
return true;
65+
} else {
66+
return false;
67+
}
68+
}
69+
70+
@Test
71+
public void andConditional() {
72+
checkProfiles("andConditionalSnippet");
73+
}
74+
75+
public static boolean andConditionalSnippet(boolean a, int b, double c) {
76+
// This will fold to a conditional.
77+
if (injectBranchProbability(0.125, a) && injectBranchProbability(0.125, b == 42) && injectBranchProbability(0.125, c > 0.0)) {
78+
return true;
79+
} else {
80+
return false;
81+
}
82+
}
83+
84+
@Test(expected = GraalError.class)
85+
public void andReturn() {
86+
checkProfiles("andReturnSnippet");
87+
}
88+
89+
@BytecodeParserForceInline
90+
public static boolean andReturnSnippet(boolean a, int b, double c) {
91+
/*
92+
* This builds a graph shape with a BranchProbabilityNode used by a ValuePhi used by a
93+
* Return. That is not accepted by the simplification in BranchProbabilityNode that
94+
* propagates injected probabilities to the correct usage. Top-level snippets wanting to
95+
* inject probabilities this way have to use an explicit if statement (which will then be
96+
* folded correctly). However, inlining this pattern into a snippet works fine as long as
97+
* the call site also injects a probability.
98+
*/
99+
return injectBranchProbability(0.125, a) && injectBranchProbability(0.125, b == 42) && injectBranchProbability(0.125, c > 0.0);
100+
}
101+
102+
@Test
103+
public void andInlined() {
104+
checkProfiles("andInlinedSnippet");
105+
}
106+
107+
public static boolean andInlinedSnippet(boolean a, int b, double c) {
108+
if (injectBranchProbability(0.25, andReturnSnippet(a, b, c))) {
109+
sideEffect();
110+
return true;
111+
} else {
112+
return false;
113+
}
114+
}
115+
116+
@Test
117+
public void orInlined() {
118+
checkProfiles("orInlinedSnippet");
119+
}
120+
121+
@BytecodeParserForceInline
122+
public static boolean orHelper(Integer i) {
123+
if (injectBranchProbability(0.25, i == null) || injectBranchProbability(0.25, i < 42)) {
124+
return true;
125+
} else {
126+
return false;
127+
}
128+
}
129+
130+
public static boolean orInlinedSnippet(Integer i) {
131+
if (injectBranchProbability(0.25, orHelper(i))) {
132+
sideEffect();
133+
return true;
134+
} else {
135+
return false;
136+
}
137+
}
138+
}

compiler/src/org.graalvm.compiler.api.directives/src/org/graalvm/compiler/api/directives/GraalDirectives.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ public static long sideEffect(long a) {
162162

163163
/**
164164
* Injects a probability for the given condition into the profiling information of a branch
165-
* instruction. The probability must be a value between 0.0 and 1.0 (inclusive).
165+
* instruction. The probability must be a value between 0.0 and 1.0 (inclusive). This directive
166+
* should only be used for the condition of an if statement. The parameter condition should also
167+
* only denote a simple condition and not a combined condition involving &amp;&amp; or ||
168+
* operators.
166169
*
167170
* Example usage (it specifies that the likelihood for a to be greater than b is 90%):
168171
*

compiler/src/org.graalvm.compiler.hotspot/src/org/graalvm/compiler/hotspot/replacements/HotSpotAllocationSnippets.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ private static Class<?> validateNewInstanceClass(Class<?> type, Class<?> classCl
210210
}
211211
Class<?> nonNullType = PiNode.piCastNonNullClass(type, SnippetAnchorNode.anchor());
212212
if (probability(DEOPT_PROBABILITY,
213-
DynamicNewInstanceNode.throwsInstantiationException(nonNullType, classClass))) {
213+
DynamicNewInstanceNode.throwsInstantiationExceptionInjectedProbability(DEOPT_PROBABILITY, nonNullType, classClass))) {
214214
DeoptimizeNode.deopt(None, RuntimeConstraint);
215215
}
216216
return nonNullType;

compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/extended/BranchProbabilityNode.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,15 @@
3737
import org.graalvm.compiler.nodeinfo.NodeInfo;
3838
import org.graalvm.compiler.nodes.ConstantNode;
3939
import org.graalvm.compiler.nodes.FixedGuardNode;
40+
import org.graalvm.compiler.nodes.FrameState;
4041
import org.graalvm.compiler.nodes.IfNode;
4142
import org.graalvm.compiler.nodes.NodeView;
4243
import org.graalvm.compiler.nodes.ProfileData.BranchProbabilityData;
4344
import org.graalvm.compiler.nodes.ProfileData.ProfileSource;
4445
import org.graalvm.compiler.nodes.ReturnNode;
46+
import org.graalvm.compiler.nodes.ShortCircuitOrNode;
4547
import org.graalvm.compiler.nodes.ValueNode;
48+
import org.graalvm.compiler.nodes.ValuePhiNode;
4649
import org.graalvm.compiler.nodes.calc.ConditionalNode;
4750
import org.graalvm.compiler.nodes.calc.FloatingNode;
4851
import org.graalvm.compiler.nodes.calc.IntegerEqualsNode;
@@ -113,6 +116,9 @@ public BranchProbabilityNode(ValueNode probability, ValueNode condition) {
113116
super(TYPE, StampFactory.forKind(JavaKind.Boolean));
114117
this.probability = probability;
115118
this.condition = condition;
119+
120+
GraalError.guarantee(!(condition instanceof ShortCircuitOrNode),
121+
"Branch probabilities must be injected on simple conditions, not short-circuiting && or ||: %s", condition);
116122
}
117123

118124
public BranchProbabilityNode(ValueNode condition) {
@@ -180,6 +186,9 @@ public void simplify(SimplifierTool tool) {
180186
}
181187
}
182188
}
189+
if (!usageFound) {
190+
usageFound = hasValidPhiUsage();
191+
}
183192
if (usageFound) {
184193
ValueNode currentCondition = condition;
185194
IntegerStamp currentStamp = (IntegerStamp) currentCondition.stamp(NodeView.DEFAULT);
@@ -241,6 +250,53 @@ private boolean isSubstitutionGraph() {
241250
return hasExactlyOneUsage() && usages().first() instanceof ReturnNode;
242251
}
243252

253+
/**
254+
* Normally a branch probability should be consumed directly as a condition, but in some cases
255+
* it can be used as a value itself. For example:
256+
*
257+
* <pre>
258+
* boolean helper() {
259+
* if (probability(a, ...) || probability(b, ...) || probability(c, condition)) {
260+
* return true;
261+
* } else {
262+
* return false;
263+
* }
264+
* }
265+
*
266+
* ...
267+
* if (probability(d, helper()) {
268+
* ...
269+
* }
270+
* </pre>
271+
*
272+
* After inlining the helper, {@code probability(c, condition)} can be represented as a branch
273+
* probability node that feeds into a phi which is then used in an {@code if} condition. This is
274+
* benign if that {@code if} has an injected branch probability itself.
275+
*/
276+
private boolean hasValidPhiUsage() {
277+
for (Node usage : this.usages()) {
278+
if (usage instanceof ValuePhiNode && !((ValuePhiNode) usage).isLoopPhi()) {
279+
Node phi = usage;
280+
// We want exactly one non-state usage, and it must be a branch probability node.
281+
Node uniquePhiUsage = null;
282+
for (Node phiUsage : phi.usages()) {
283+
if (phiUsage instanceof FrameState) {
284+
continue;
285+
} else if (uniquePhiUsage == null) {
286+
uniquePhiUsage = phiUsage;
287+
} else if (phiUsage != uniquePhiUsage) {
288+
uniquePhiUsage = null;
289+
break;
290+
}
291+
}
292+
if (uniquePhiUsage instanceof BranchProbabilityNode) {
293+
return true;
294+
}
295+
}
296+
}
297+
return false;
298+
}
299+
244300
/**
245301
* This intrinsic should only be used for the condition of an if statement. The parameter
246302
* condition should also only denote a simple condition and not a combined condition involving

compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/java/DynamicNewInstanceNode.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
*/
2525
package org.graalvm.compiler.nodes.java;
2626

27+
import static org.graalvm.compiler.nodes.extended.BranchProbabilityNode.probability;
28+
2729
import java.lang.reflect.Modifier;
2830

2931
import org.graalvm.compiler.core.common.type.ObjectStamp;
@@ -87,8 +89,16 @@ public Node canonical(CanonicalizerTool tool) {
8789
return this;
8890
}
8991

90-
public static boolean throwsInstantiationException(Class<?> type, Class<?> classClass) {
91-
return type.isPrimitive() || type.isArray() || type.isInterface() || Modifier.isAbstract(type.getModifiers()) || type == classClass;
92+
public static boolean throwsInstantiationExceptionInjectedProbability(double probability, Class<?> type, Class<?> classClass) {
93+
/*
94+
* This method is for use in a snippet and therefore injects probabilities for each
95+
* disjunct.
96+
*/
97+
return probability(probability, type.isPrimitive()) ||
98+
probability(probability, type.isArray()) ||
99+
probability(probability, type.isInterface()) ||
100+
probability(probability, Modifier.isAbstract(type.getModifiers())) ||
101+
probability(probability, type == classClass);
92102
}
93103

94104
public static boolean throwsInstantiationException(ResolvedJavaType type, MetaAccessProvider metaAccess) {

compiler/src/org.graalvm.compiler.replacements/src/org/graalvm/compiler/replacements/arraycopy/ArrayCopySnippets.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import static jdk.vm.ci.services.Services.IS_BUILDING_NATIVE_IMAGE;
2828
import static org.graalvm.compiler.nodes.extended.BranchProbabilityNode.DEOPT_PROBABILITY;
2929
import static org.graalvm.compiler.nodes.extended.BranchProbabilityNode.FAST_PATH_PROBABILITY;
30+
import static org.graalvm.compiler.nodes.extended.BranchProbabilityNode.FREQUENT_PROBABILITY;
3031
import static org.graalvm.compiler.nodes.extended.BranchProbabilityNode.NOT_FREQUENT_PROBABILITY;
3132
import static org.graalvm.compiler.nodes.extended.BranchProbabilityNode.probability;
3233

@@ -306,7 +307,7 @@ protected void doExactArraycopyWithExpandedLoopSnippet(Object src, int srcPos, O
306307
long destOffset = arrayBaseOffset + (long) destPos * scale;
307308

308309
GuardingNode anchor = SnippetAnchorNode.anchor();
309-
if (probability(NOT_FREQUENT_PROBABILITY, src == dest && srcPos < destPos)) {
310+
if (probability(FREQUENT_PROBABILITY, src == dest) && probability(NOT_FREQUENT_PROBABILITY, srcPos < destPos)) {
310311
// bad aliased case so we need to copy the array from back to front
311312
for (int position = length - 1; probability(FAST_PATH_PROBABILITY, position >= 0); position--) {
312313
Object value = GuardedUnsafeLoadNode.guardedLoad(src, sourceOffset + ((long) position) * scale, elementKind, arrayLocation, anchor);

compiler/src/org.graalvm.compiler.replacements/src/org/graalvm/compiler/replacements/gc/G1WriteBarrierSnippets.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ public void g1ArrayRangePreWriteBarrier(Address address, long length, @ConstantP
261261
Word thread = getThread();
262262
byte markingValue = thread.readByte(satbQueueMarkingActiveOffset(), SATB_QUEUE_MARKING_ACTIVE_LOCATION);
263263
// If the concurrent marker is not enabled or the vector length is zero, return.
264-
if (probability(FREQUENT_PROBABILITY, markingValue == (byte) 0 || length == 0)) {
264+
if (probability(FREQUENT_PROBABILITY, markingValue == (byte) 0) || probability(NOT_FREQUENT_PROBABILITY, length == 0)) {
265265
return;
266266
}
267267

0 commit comments

Comments
 (0)