Prevent ABI changes affect EnzymeAD #142544

Open
wants to merge 3 commits into
base: master
Changes from 2 commits
17 changes: 17 additions & 0 deletions compiler/rustc_monomorphize/src/partitioning/autodiff.rs
@@ -70,6 +70,23 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
                continue;
            }
        }

        let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };

        let layout = match tcx.layout_of(pci) {
            Ok(layout) => layout.layout,
            Err(_) => {
                bug!("failed to compute layout for type {:?}", ty);
            }
        };

        match layout.backend_repr() {
            rustc_abi::BackendRepr::ScalarPair(_, _) => {
                new_activities.push(da[i].clone());
                new_positions.push(i + 1);
Contributor

Document why we're adding an entry here.

Is it intended that both scalar pair entries have the same diff activity?

Contributor Author

Yes, it's intended. If we changed the activity of an individual field with respect to the original field for some reason, this could potentially affect the return function signature.
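
The duplication discussed in this thread can be sketched outside the compiler. This is a minimal illustration, assuming hypothetical `Activity` and `Repr` enums that stand in for the real `DiffActivity` and `BackendRepr` types; it is not the rustc implementation, only the idea that a scalar-pair argument occupies two lowered slots and both inherit the activity of the source-level argument.

```rust
// Hypothetical stand-ins for the compiler types; NOT the rustc API.
#[derive(Clone, Debug, PartialEq)]
pub enum Activity {
    Const,
    Dual,
    Active,
}

#[derive(Clone, Copy, Debug)]
pub enum Repr {
    Scalar,
    ScalarPair,
    Memory,
}

/// Expand one activity per source-level argument into one activity per
/// lowered argument. A scalar pair is passed as two arguments, so its
/// activity is emitted twice; every other representation keeps one slot.
pub fn expand_activities(args: &[(Repr, Activity)]) -> Vec<Activity> {
    let mut out = Vec::with_capacity(args.len());
    for (repr, activity) in args {
        out.push(activity.clone());
        if let Repr::ScalarPair = repr {
            // Second half of the pair gets the same activity as the first.
            out.push(activity.clone());
        }
    }
    out
}
```

Under these assumptions, calling `expand_activities` on `[(Scalar, Const), (ScalarPair, Dual), (Memory, Active)]` yields four activities, with `Dual` appearing twice in a row.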

            }
            _ => {}
Contributor

Use an exhaustive match and either span_bug! on the other variants, report an error, or explain why this one is OK.

Contributor Author

I'll change it to an if let, since we only need to apply corrections to ScalarPair args. That way I don't leave unmatched variants in the match.

Contributor

My intention was to document on the other arms why they need no adjustment or are unreachable, because it isn't clear to me how non-scalar(pair) layouts are handled.

Contributor Author

@Sa4dUs Jul 15, 2025

AFAIK, the other cases don't change the number of args, so they don't cause a mismatch between the function args and the diff activities and don't need to be adjusted (and slices were already handled above).

Contributor Author

@Sa4dUs Jul 15, 2025

da: &mut Vec<DiffActivity> already contains one diff activity per source-code arg, so in adjust_activity_to_abi we are only adjusting activities to prevent errors during codegen.

Contributor Author

Should I state that explicitly in the code?

Contributor

I don't understand enough here to have an educated opinion. @ZuseZ4 is this literally something only scalar pairs can ever hit, or can there be situations with aggregates or SIMD vectors?

        }
    }
    // now add the extra activities coming from slices
    // Reverse order to not invalidate the indices
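
The "reverse order" remark at the end of this hunk can be illustrated with a small standalone sketch. The function names below are hypothetical and the code is not taken from the PR; it only assumes that the recorded positions are ascending indices into the original, unmodified list, which matches how `i + 1` is pushed in the hunk above.

```rust
/// Insert `items[k]` at `positions[k]`, walking from the back.
/// Assumes `positions` is sorted ascending and every position is an index
/// into the ORIGINAL list. Inserting back to front only shifts elements
/// after the insertion point, so the earlier positions remain valid.
pub fn insert_in_reverse<T>(list: &mut Vec<T>, positions: &[usize], items: Vec<T>) {
    assert_eq!(positions.len(), items.len());
    for (pos, item) in positions.iter().zip(items).rev() {
        list.insert(*pos, item);
    }
}

/// Same insertions front to back, shown only for contrast: every insertion
/// shifts the elements behind it, so later positions point at the wrong slot.
pub fn insert_forward<T>(list: &mut Vec<T>, positions: &[usize], items: Vec<T>) {
    assert_eq!(positions.len(), items.len());
    for (pos, item) in positions.iter().zip(items) {
        list.insert(*pos, item);
    }
}
```

With the list `["a", "b", "c"]`, positions `[1, 3]` and items `["a2", "c2"]`, the reverse variant places each new entry directly after its original element, while the forward variant puts `"c2"` before `"c"`.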
331 changes: 331 additions & 0 deletions tests/codegen/autodiff/abi_handling.rs
@@ -0,0 +1,331 @@
//@ revisions: debug release

//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme

// This test checks that Rust types are lowered to LLVM-IR types in a way
// we expect and Enzyme can handle. We explicitly check release mode to
// ensure that LLVM's O3 pipeline doesn't rewrite function signatures
// into forms that Enzyme can't process correctly.

#![feature(autodiff)]

use std::autodiff::{autodiff_forward, autodiff_reverse};

#[derive(Copy, Clone)]
struct Input {
    x: f32,
    y: f32,
}

#[derive(Copy, Clone)]
struct Wrapper {
    z: f32,
}

#[derive(Copy, Clone)]
struct NestedInput {
    x: f32,
    y: Wrapper,
}

fn square(x: f32) -> f32 {
    x * x
}

// CHECK: ; abi_handling::f1
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E
// debug-SAME: (ptr align 4 %x)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E
// release-SAME: (float %x.0.val, float %x.4.val)
#[autodiff_forward(df1, Dual, Dual)]
fn f1(x: &[f32; 2]) -> f32 {
    x[0] + x[1]
}

// CHECK: ; abi_handling::f2
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E
// debug-SAME: (ptr %f, float %x)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E
// release-SAME: (float noundef %x)
#[autodiff_reverse(df2, Const, Active, Active)]
fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
    f(x)
}

// CHECK: ; abi_handling::f3
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E
// debug-SAME: (ptr align 4 %x, ptr align 4 %y)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E
// release-SAME: (float %x.0.val)
#[autodiff_forward(df3, Dual, Dual, Dual)]
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
    *x * *y
}

// CHECK: ; abi_handling::f4
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// debug-SAME: (float %x.0, float %x.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// release-SAME: (float noundef %x.0, float noundef %x.1)
#[autodiff_forward(df4, Dual, Dual)]
fn f4(x: (f32, f32)) -> f32 {
    x.0 * x.1
}

// CHECK: ; abi_handling::f5
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// debug-SAME: (float %i.0, float %i.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// release-SAME: (float noundef %i.0, float noundef %i.1)
#[autodiff_forward(df5, Dual, Dual)]
fn f5(i: Input) -> f32 {
    i.x + i.y
}

// CHECK: ; abi_handling::f6
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f617h5784b207bbb2483eE
// debug-SAME: (float %i.0, float %i.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE
// release-SAME: (float noundef %i.0, float noundef %i.1)
#[autodiff_forward(df6, Dual, Dual)]
fn f6(i: NestedInput) -> f32 {
    i.x + i.y.z * i.y.z
}

// CHECK: ; abi_handling::f7
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
#[autodiff_forward(df7, Dual, Dual)]
fn f7(x: (&f32, &f32)) -> f32 {
    x.0 * x.1
}

// df1
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E
// release-SAME: (float %x.0.val, float %x.4.val)
// release-NEXT: start:
// release-NEXT: %_0 = fadd float %x.0.val, %x.4.val
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float 1.000000e+00, 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'")
// debug-NEXT: start:
// debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0
// debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0
// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4
// debug-NEXT: %_2 = load float, ptr %0, align 4
// debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1
// debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1
// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4
// debug-NEXT: %_5 = load float, ptr %1, align 4
// debug-NEXT: %_0 = fadd float %_2, %_5
// debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl"
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

// df2
// release: define internal fastcc { float, float }
// release-SAME: @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E
// release-SAME: (float noundef %x)
// release-NEXT: invertstart:
// release-NEXT: %_0.i = fmul float %x, %x
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0.i, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float 0.000000e+00, 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E
// debug-SAME: (ptr %f, float %x, float %differeturn)
// debug-NEXT: start:
// debug-NEXT: %"x'de" = alloca float, align 4
// debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4
// debug-NEXT: %toreturn = alloca float, align 4
// debug-NEXT: %_0 = call float %f(float %x)
// debug-NEXT: store float %_0, ptr %toreturn, align 4
// debug-NEXT: br label %invertstart
// debug-EMPTY:
// debug-NEXT: invertstart: ; preds = %start
// debug-NEXT: %retreload = load float, ptr %toreturn, align 4
// debug-NEXT: %0 = load float, ptr %"x'de", align 4
// debug-NEXT: %1 = insertvalue { float, float } undef, float %retreload, 0
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
// debug-NEXT: ret { float, float } %2
// debug-NEXT: }

// df3
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E
// release-SAME: (float %x.0.val)
// release-NEXT: start:
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.val, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float 0x40099999A0000000, 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'")
// debug-NEXT: start:
// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4
// debug-NEXT: %_3 = load float, ptr %x, align 4
// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4
// debug-NEXT: %_4 = load float, ptr %y, align 4
// debug-NEXT: %_0 = fmul float %_3, %_4
// debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4
// debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3
// debug-NEXT: %2 = fadd fast float %0, %1
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

// df4
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// release-SAME: (float noundef %x.0, float %"x.0'")
// release-NEXT: start:
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'", 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// debug-SAME: (float %x.0, float %"x.0'", float %x.1, float %"x.1'")
// debug-NEXT: start:
// debug-NEXT: %_0 = fmul float %x.0, %x.1
// debug-NEXT: %0 = fmul fast float %"x.0'", %x.1
// debug-NEXT: %1 = fmul fast float %"x.1'", %x.0
// debug-NEXT: %2 = fadd fast float %0, %1
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

// df5
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// release-SAME: (float noundef %i.0, float %"i.0'")
// release-NEXT: start:
// release-NEXT: %_0 = fadd float %i.0, 1.000000e+00
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float %"i.0'", 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'")
// debug-NEXT: start:
// debug-NEXT: %_0 = fadd float %i.0, %i.1
// debug-NEXT: %0 = fadd fast float %"i.0'", %"i.1'"
// debug-NEXT: %1 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
// debug-NEXT: ret { float, float } %2
// debug-NEXT: }

// df6
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE
// release-SAME: (float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'")
// release-NEXT: start:
// release-NEXT: %_3 = fmul float %i.1, %i.1
// release-NEXT: %0 = fadd fast float %"i.1'", %"i.1'"
// release-NEXT: %1 = fmul fast float %0, %i.1
// release-NEXT: %_0 = fadd float %i.0, %_3
// release-NEXT: %2 = fadd fast float %"i.0'", %1
// release-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// release-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// release-NEXT: ret { float, float } %4
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'")
// debug-NEXT: start:
// debug-NEXT: %_3 = fmul float %i.1, %i.1
// debug-NEXT: %0 = fmul fast float %"i.1'", %i.1
// debug-NEXT: %1 = fmul fast float %"i.1'", %i.1
// debug-NEXT: %2 = fadd fast float %0, %1
// debug-NEXT: %_0 = fadd float %i.0, %_3
// debug-NEXT: %3 = fadd fast float %"i.0'", %2
// debug-NEXT: %4 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %5 = insertvalue { float, float } %4, float %3, 1
// debug-NEXT: ret { float, float } %5
// debug-NEXT: }

// df7
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
// release-SAME: (float %x.0.0.val, float %"x.0'.0.val")
// release-NEXT: start:
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.0.val, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'.0.val", 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float }
// debug-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %"x.0'", ptr align 4 %x.1, ptr align 4 %"x.1'")
// debug-NEXT: start:
// debug-NEXT: %0 = call fast { float, float } @"fwddiffe_ZN49_{{.*}}"
// debug-NEXT: %1 = extractvalue { float, float } %0, 0
// debug-NEXT: %2 = extractvalue { float, float } %0, 1
// debug-NEXT: %3 = insertvalue { float, float } undef, float %1, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

fn main() {
    let x = std::hint::black_box(2.0);
    let y = std::hint::black_box(3.0);
    let z = std::hint::black_box(4.0);
    static Y: f32 = std::hint::black_box(3.2);

    let in_f1 = [x, y];
    dbg!(f1(&in_f1));
    let res_f1 = df1(&in_f1, &[1.0, 0.0]);
    dbg!(res_f1);

    dbg!(f2(square, x));
    let res_f2 = df2(square, x, 1.0);
    dbg!(res_f2);

    dbg!(f3(&x, &Y));
    let res_f3 = df3(&x, &Y, &1.0, &0.0);
    dbg!(res_f3);

    let in_f4 = (x, y);
    dbg!(f4(in_f4));
    let res_f4 = df4(in_f4, (1.0, 0.0));
    dbg!(res_f4);

    let in_f5 = Input { x, y };
    dbg!(f5(in_f5));
    let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
    dbg!(res_f5);

    let in_f6 = NestedInput { x, y: Wrapper { z: y } };
    dbg!(f6(in_f6));
    let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
    dbg!(res_f6);

    let in_f7 = (&x, &y);
    dbg!(f7(in_f7));
    let res_f7 = df7(in_f7, (&1.0, &0.0));
    dbg!(res_f7);
}
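
As a reading aid for the FileCheck patterns above, the forward-mode results for `f4`, `f5` and `f6` can be reproduced by hand on stable Rust. The sketch below does not use `std::autodiff` or Enzyme. The `*_dual` helpers are hypothetical names that mirror the arithmetic visible in the debug-mode IR, returning the primal value first and the tangent second, like the `{ float, float }` results in the checks. The struct arguments are flattened into plain floats for brevity.

```rust
/// Hand-written forward-mode counterpart of `f4(x) = x.0 * x.1`.
/// Returns (value, tangent) using the product rule.
pub fn f4_dual(x: (f32, f32), dx: (f32, f32)) -> (f32, f32) {
    (x.0 * x.1, dx.0 * x.1 + dx.1 * x.0)
}

/// Hand-written forward-mode counterpart of `f5(i) = i.x + i.y`,
/// with the two struct fields passed as a tuple.
pub fn f5_dual(i: (f32, f32), di: (f32, f32)) -> (f32, f32) {
    (i.0 + i.1, di.0 + di.1)
}

/// Hand-written forward-mode counterpart of `f6(i) = i.x + i.y.z * i.y.z`,
/// with `x` and `z` passed as separate floats.
pub fn f6_dual(x: f32, z: f32, dx: f32, dz: f32) -> (f32, f32) {
    let square = z * z;
    // d(z * z) = dz * z + dz * z, as in the debug IR for df6.
    let dsquare = dz * z + dz * z;
    (x + square, dx + dsquare)
}
```

With the inputs from `main` (`x = 2.0`, `y = 3.0`, `z = 4.0`), `f4_dual((2.0, 3.0), (1.0, 0.0))` gives `(6.0, 3.0)`, `f5_dual((2.0, 3.0), (1.0, 0.0))` gives `(5.0, 1.0)`, and `f6_dual(2.0, 3.0, 2.0, 4.0)` gives `(11.0, 26.0)`.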