1
1
2
2
module forward_diff_no_inf
3
- using Core. Compiler : SSAValue
3
+ using Core: SSAValue
4
4
const CC = Core. Compiler
5
5
6
6
using Diffractor, Test
@@ -20,7 +20,14 @@ module forward_diff_no_inf
20
20
mi. specTypes = Tuple{map (CC. widenconst, ir. argtypes)... }
21
21
mi. def = @__MODULE__
22
22
23
- for i in 1 : length (ir. stmts) # For testuing purposes we are going to refine everything
23
+ for i in 1 : length (ir. stmts)
24
+ inst = ir[SSAValue (i)][:inst ]
25
+ if Meta. isexpr (inst, :code_coverage_effect )
26
+ # delete these as CC._ir_abstract_constant_propagation doesn't work on them
27
+ ir[SSAValue (i)][:inst ] = nothing
28
+ ir[SSAValue (i)][:type ] = Nothing
29
+ end
30
+ # For testing purposes we are going to refine everything else
24
31
ir[SSAValue (i)][:flag ] |= CC. IR_FLAG_REFINED
25
32
end
26
33
@@ -39,14 +46,28 @@ module forward_diff_no_inf
39
46
typ = stmt[:type ]
40
47
! isa (typ, Type) && continue # If not a Type then something even more informed like a Const
41
48
if isabstracttype (typ) || typ <: Union || typ <: UnionAll
42
- # @error "Not fully inferred" inst typ
49
+ # @error "Not fully inferred" inst typ
43
50
return false
44
51
end
45
52
end
46
53
end
47
54
return true
48
55
end
49
56
57
+ function findfirst_ssa (predicate, ir)
58
+ for ii in 1 : length (ir. stmts)
59
+ try
60
+ inst = ir[SSAValue (ii)][:inst ]
61
+ if predicate (inst)
62
+ return SSAValue (ii)
63
+ end
64
+ catch
65
+ # ignore errors so predicate can be simple
66
+ end
67
+ end
68
+ return nothing
69
+ end
70
+
50
71
# ############################## Actual tests:
51
72
52
73
@testset " Constructors in forward_diff_no_inf!" begin
@@ -108,21 +129,22 @@ module forward_diff_no_inf
108
129
end
109
130
110
131
# only test this on new enough julia versions as exactly what infers can be fussy, as is running inference manually
111
- VERSION >= v " 1.12.0-DEV.283" && @testset " Eras mode: $eras_mode " for eras_mode in (false , true )
132
+ VERSION >= v " 1.12.0-DEV.283" && @testset " Eras mode: $eras_mode " for eras_mode in (false , true )
112
133
foo (x, y) = x* x + y* y
113
134
ir = first (only (Base. code_ircode (foo, Tuple{Any, Any})))
114
- Diffractor. forward_diff_no_inf! (ir, [SSAValue (1 )] .=> 1 ; transform! = identity_transform!, eras_mode)
135
+ mul1_ssa = findfirst_ssa (x-> x. args[1 ]. name== :* , ir)
136
+ Diffractor. forward_diff_no_inf! (ir, [mul1_ssa] .=> 1 ; transform! = identity_transform!, eras_mode)
115
137
ir = CC. compact! (ir)
116
138
ir. argtypes[2 : end ] .= Float64
117
139
ir = CC. compact! (ir)
118
140
infer_ir! (ir)
119
141
CC. verify_ir (ir)
120
142
@test isfully_inferred (ir) # passes with and without eras mode
121
-
122
- Diffractor. forward_diff_no_inf! (ir, [SSAValue (3 )] .=> 1 ; transform! = identity_transform!, eras_mode)
143
+
144
+ add_ssa = findfirst_ssa (x-> x. args[1 ]. name== :+ , ir)
145
+ Diffractor. forward_diff_no_inf! (ir, [add_ssa] .=> 1 ; transform! = identity_transform!, eras_mode)
123
146
ir = CC. compact! (ir)
124
147
infer_ir! (ir)
125
-
126
148
CC. verify_ir (ir)
127
149
if eras_mode
128
150
@test isfully_inferred (ir)
@@ -131,6 +153,5 @@ module forward_diff_no_inf
131
153
@assert ! isfully_inferred (ir)
132
154
end
133
155
end
134
-
135
156
end # module
136
157
0 commit comments