Skip to content

Commit 39d1efc

Browse files
committed
Update activity adjustment logic
Rename `count_scalar_fields` to `count_leaf_fields` and extend it to support more types. Add extra activities only if `count_leaf_fields` returns at most 2. The logic can be optimized if needed. Removed metadata-specific fields from the test to avoid future failures.
1 parent 56dbc6a commit 39d1efc

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,15 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
8484

8585
let is_product = |t: Ty<'tcx>| matches!(t.kind(), ty::Tuple(_) | ty::Adt(_, _));
8686

87+
// NOTE: When an ADT (Algebraic Data Type) has at most two leaf fields and a total size of at most pointer_size * 2,
88+
// LLVM will pass its fields separately instead of as a single aggregate.
8789
if layout.size() <= pointer_size * 2 && is_product(*ty) {
88-
let n_scalars = count_scalar_fields(tcx, *ty);
89-
for _ in 0..n_scalars.saturating_sub(1) {
90-
new_activities.push(da[i].clone());
91-
new_positions.push(i + 1);
90+
let n_scalars = count_leaf_fields(tcx, *ty);
91+
if n_scalars <= 2 {
92+
for _ in 0..n_scalars.saturating_sub(1) {
93+
new_activities.push(da[i].clone());
94+
new_positions.push(i + 1);
95+
}
9296
}
9397
}
9498
}
@@ -101,16 +105,25 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
101105
}
102106
}
103107

104-
fn count_scalar_fields<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> usize {
108+
fn count_leaf_fields<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> usize {
105109
match ty.kind() {
106-
ty::Float(_) | ty::Int(_) | ty::Uint(_) => 1,
110+
ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::FnPtr(_, _) => 1,
111+
ty::RawPtr(ty, _) => count_leaf_fields(tcx, *ty),
112+
ty::Ref(_, ty, _) => count_leaf_fields(tcx, *ty),
113+
ty::Array(ty, len) => {
114+
if let Some(len) = len.try_to_target_usize(tcx) {
115+
count_leaf_fields(tcx, *ty) * len as usize
116+
} else {
117+
1 // Array length could not be evaluated to a target usize; not sure how to handle this case, so count it as one leaf
118+
}
119+
}
107120
ty::Adt(def, substs) if def.is_struct() => def
108121
.non_enum_variant()
109122
.fields
110123
.iter()
111-
.map(|f| count_scalar_fields(tcx, f.ty(tcx, substs)))
124+
.map(|f| count_leaf_fields(tcx, f.ty(tcx, substs)))
112125
.sum(),
113-
ty::Tuple(substs) => substs.iter().map(|t| count_scalar_fields(tcx, t)).sum(),
126+
ty::Tuple(substs) => substs.iter().map(|t| count_leaf_fields(tcx, t)).sum(),
114127
_ => 0,
115128
}
116129
}

tests/codegen/autodiff/abi_handling.rs

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,17 @@ fn f6(i: NestedInput) -> f32 {
101101
i.x + i.y.z * i.y.z
102102
}
103103

104+
// CHECK: ; abi_handling::f7
105+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
106+
// debug-NEXT: define internal float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
107+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
108+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
109+
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
110+
#[autodiff_forward(df7, Dual, Dual)]
111+
fn f7(x: (&f32, &f32)) -> f32 {
112+
x.0 * x.1
113+
}
114+
104115
// df1
105116
// release: define internal fastcc { float, float }
106117
// release-SAME: @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E
@@ -117,12 +128,12 @@ fn f6(i: NestedInput) -> f32 {
117128
// debug-NEXT: start:
118129
// debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0
119130
// debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0
120-
// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4, !alias.scope !4, !noalias !7
121-
// debug-NEXT: %_2 = load float, ptr %0, align 4, !alias.scope !7, !noalias !4
131+
// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4
132+
// debug-NEXT: %_2 = load float, ptr %0, align 4
122133
// debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1
123134
// debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1
124-
// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4, !alias.scope !4, !noalias !7
125-
// debug-NEXT: %_5 = load float, ptr %1, align 4, !alias.scope !7, !noalias !4
135+
// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4
136+
// debug-NEXT: %_5 = load float, ptr %1, align 4
126137
// debug-NEXT: %_0 = fadd float %_2, %_5
127138
// debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl"
128139
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
@@ -147,7 +158,7 @@ fn f6(i: NestedInput) -> f32 {
147158
// debug-NEXT: %"x'de" = alloca float, align 4
148159
// debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4
149160
// debug-NEXT: %toreturn = alloca float, align 4
150-
// debug-NEXT: %_0 = call float %f(float %x) #12
161+
// debug-NEXT: %_0 = call float %f(float %x)
151162
// debug-NEXT: store float %_0, ptr %toreturn, align 4
152163
// debug-NEXT: br label %invertstart
153164
// debug-EMPTY:
@@ -172,10 +183,10 @@ fn f6(i: NestedInput) -> f32 {
172183
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E
173184
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'")
174185
// debug-NEXT: start:
175-
// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4, !alias.scope !9, !noalias !12
176-
// debug-NEXT: %_3 = load float, ptr %x, align 4, !alias.scope !12, !noalias !9
177-
// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4, !alias.scope !14, !noalias !17
178-
// debug-NEXT: %_4 = load float, ptr %y, align 4, !alias.scope !17, !noalias !14
186+
// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4
187+
// debug-NEXT: %_3 = load float, ptr %x, align 4
188+
// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4
189+
// debug-NEXT: %_4 = load float, ptr %y, align 4
179190
// debug-NEXT: %_0 = fmul float %_3, %_4
180191
// debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4
181192
// debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3
@@ -257,6 +268,28 @@ fn f6(i: NestedInput) -> f32 {
257268
// debug-NEXT: ret { float, float } %5
258269
// debug-NEXT: }
259270

271+
// df7
272+
// release: define internal fastcc { float, float }
273+
// release-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
274+
// release-SAME: (float %x.0.0.val, float %"x.0'.0.val")
275+
// release-NEXT: start:
276+
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.0.val, 0
277+
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'.0.val", 1
278+
// release-NEXT: ret { float, float } %1
279+
// release-NEXT: }
280+
281+
// debug: define internal { float, float }
282+
// debug-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
283+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %"x.0'", ptr align 4 %x.1, ptr align 4 %"x.1'")
284+
// debug-NEXT: start:
285+
// debug-NEXT: %0 = call fast { float, float } @"fwddiffe_ZN49_{{.*}}"
286+
// debug-NEXT: %1 = extractvalue { float, float } %0, 0
287+
// debug-NEXT: %2 = extractvalue { float, float } %0, 1
288+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %1, 0
289+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
290+
// debug-NEXT: ret { float, float } %4
291+
// debug-NEXT: }
292+
260293
fn main() {
261294
let x = std::hint::black_box(2.0);
262295
let y = std::hint::black_box(3.0);
@@ -290,4 +323,9 @@ fn main() {
290323
dbg!(f6(in_f6));
291324
let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
292325
dbg!(res_f6);
326+
327+
let in_f7 = (&x, &y);
328+
dbg!(f7(in_f7));
329+
let res_f7 = df7(in_f7, (&1.0, &0.0));
330+
dbg!(res_f7);
293331
}

0 commit comments

Comments
 (0)