-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Prevent ABI changes affect EnzymeAD #142544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,23 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec | |
continue; | ||
} | ||
} | ||
|
||
let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty }; | ||
|
||
let layout = match tcx.layout_of(pci) { | ||
Ok(layout) => layout.layout, | ||
Err(_) => { | ||
bug!("failed to compute layout for type {:?}", ty); | ||
} | ||
}; | ||
|
||
match layout.backend_repr() { | ||
rustc_abi::BackendRepr::ScalarPair(_, _) => { | ||
new_activities.push(da[i].clone()); | ||
new_positions.push(i + 1); | ||
} | ||
_ => {} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. use an exhaustive match and either There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. i'll move it to an There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My intention was to document on the other arms why they need no adjustment or are unreachable. Because it isn't clear to me how nonscalar(pair) layouts are handled There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. afaik, other cases don't change the number of args, so they don't cause a mismatch between the function args and the diff activities, and they don't need to be adjusted (and slices were already handled above) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should I explicitly specify that in the code? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't understand enough here to have an educated opinion. @ZuseZ4 is this literally something only scalar pairs can ever hit, or can there be situations with aggregates or SIMD vectors? |
||
} | ||
} | ||
// now add the extra activities coming from slices | ||
// Reverse order to not invalidate the indices | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,331 @@ | ||
//@ revisions: debug release | ||
|
||
//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat | ||
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat | ||
//@ no-prefer-dynamic | ||
//@ needs-enzyme | ||
|
||
// This test checks that Rust types are lowered to LLVM-IR types in a way | ||
// we expect and Enzyme can handle. We explicitly check release mode to | ||
// ensure that LLVM's O3 pipeline doesn't rewrite function signatures | ||
// into forms that Enzyme can't process correctly. | ||
|
||
#![feature(autodiff)] | ||
|
||
use std::autodiff::{autodiff_forward, autodiff_reverse}; | ||
|
||
// Flat aggregate of two `f32` fields. It is passed by value to `f5`, whose | ||
// expected signatures take two separate float arguments in both revisions. | ||
#[derive(Copy, Clone)] | ||
struct Input { | ||
x: f32, | ||
y: f32, | ||
} | ||
|
||
// Single-field wrapper around an `f32`. Used as the `y` field of `NestedInput` | ||
// so that the test also covers a nested aggregate. | ||
#[derive(Copy, Clone)] | ||
struct Wrapper { | ||
z: f32, | ||
} | ||
|
||
// Aggregate with one plain `f32` field and one `Wrapper` field, i.e. two `f32` | ||
// values in total. It is passed by value to `f6`, whose expected signatures | ||
// take two separate float arguments in both revisions. | ||
#[derive(Copy, Clone)] | ||
struct NestedInput { | ||
x: f32, | ||
y: Wrapper, | ||
} | ||
|
||
// Returns `x * x`. `main` passes this function to `f2` and `df2` as the | ||
// function-pointer argument. | ||
fn square(x: f32) -> f32 { | ||
x * x | ||
} | ||
|
||
// CHECK: ; abi_handling::f1 | ||
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} | ||
// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E | ||
// debug-SAME: (ptr align 4 %x) | ||
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E | ||
// release-SAME: (float %x.0.val, float %x.4.val) | ||
// Adds the two elements of a fixed-size array taken by reference. | ||
// The directives above expect `x` to remain a single pointer argument in the | ||
// debug revision and to be split into two float arguments in the release | ||
// revision. The attribute names the derivative `df1`, which `main` calls with | ||
// a second array reference (presumably the tangent input; confirm against the | ||
// autodiff docs). | ||
#[autodiff_forward(df1, Dual, Dual)] | ||
fn f1(x: &[f32; 2]) -> f32 { | ||
x[0] + x[1] | ||
} | ||
|
||
// CHECK: ; abi_handling::f2 | ||
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} | ||
// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E | ||
// debug-SAME: (ptr %f, float %x) | ||
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E | ||
// release-SAME: (float noundef %x) | ||
// Applies the function pointer `f` to `x` and returns the result. | ||
// The directives above expect both `f` and `x` as arguments in the debug | ||
// revision, but only `x` in the release revision. The attribute names the | ||
// derivative `df2`, which `main` calls with one extra `f32` argument. | ||
#[autodiff_reverse(df2, Const, Active, Active)] | ||
fn f2(f: fn(f32) -> f32, x: f32) -> f32 { | ||
f(x) | ||
} | ||
|
||
// CHECK: ; abi_handling::f3 | ||
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} | ||
// debug-NEXT: define internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E | ||
// debug-SAME: (ptr align 4 %x, ptr align 4 %y) | ||
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E | ||
// release-SAME: (float %x.0.val) | ||
// Multiplies two `f32` values that are both taken by reference. | ||
// The directives above expect two pointer arguments in the debug revision and | ||
// a single float argument in the release revision. `main` passes a reference | ||
// to the static `Y` as the second operand. | ||
#[autodiff_forward(df3, Dual, Dual, Dual)] | ||
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 { | ||
*x * *y | ||
} | ||
|
||
// CHECK: ; abi_handling::f4 | ||
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} | ||
// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE | ||
// debug-SAME: (float %x.0, float %x.1) | ||
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE | ||
// release-SAME: (float noundef %x.0, float noundef %x.1) | ||
// Multiplies the two components of an `(f32, f32)` tuple passed by value. | ||
// The directives above expect the tuple to arrive as two separate float | ||
// arguments (`%x.0`, `%x.1`) in both revisions. | ||
#[autodiff_forward(df4, Dual, Dual)] | ||
fn f4(x: (f32, f32)) -> f32 { | ||
x.0 * x.1 | ||
} | ||
|
||
// CHECK: ; abi_handling::f5 | ||
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} | ||
// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E | ||
// debug-SAME: (float %i.0, float %i.1) | ||
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E | ||
// release-SAME: (float noundef %i.0, float noundef %i.1) | ||
// Adds the two fields of an `Input` struct passed by value. | ||
// The directives above expect the struct to arrive as two separate float | ||
// arguments (`%i.0`, `%i.1`) in both revisions. | ||
#[autodiff_forward(df5, Dual, Dual)] | ||
fn f5(i: Input) -> f32 { | ||
i.x + i.y | ||
} | ||
|
||
// CHECK: ; abi_handling::f6 | ||
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} | ||
// debug-NEXT: define internal float @_ZN12abi_handling2f617h5784b207bbb2483eE | ||
// debug-SAME: (float %i.0, float %i.1) | ||
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE | ||
// release-SAME: (float noundef %i.0, float noundef %i.1) | ||
// Computes `i.x + i.y.z * i.y.z` for a `NestedInput` passed by value. | ||
// The directives above expect the nested struct to arrive as two separate | ||
// float arguments (`%i.0`, `%i.1`) in both revisions. | ||
#[autodiff_forward(df6, Dual, Dual)] | ||
fn f6(i: NestedInput) -> f32 { | ||
i.x + i.y.z * i.y.z | ||
} | ||
|
||
// CHECK: ; abi_handling::f7 | ||
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} | ||
// debug-NEXT: define internal float @_ZN12abi_handling2f717h44e3cff234e3b2d5E | ||
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1) | ||
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f717h44e3cff234e3b2d5E | ||
// release-SAME: (float %x.0.0.val, float %x.1.0.val) | ||
// Multiplies the two `f32` values referenced by a tuple of references. | ||
// The directives above expect two pointer arguments in the debug revision and | ||
// two float arguments in the release revision. | ||
#[autodiff_forward(df7, Dual, Dual)] | ||
fn f7(x: (&f32, &f32)) -> f32 { | ||
x.0 * x.1 | ||
} | ||
|
||
// df1 | ||
// release: define internal fastcc { float, float } | ||
// release-SAME: @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E | ||
// release-SAME: (float %x.0.val, float %x.4.val) | ||
// release-NEXT: start: | ||
// release-NEXT: %_0 = fadd float %x.0.val, %x.4.val | ||
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0 | ||
// release-NEXT: %1 = insertvalue { float, float } %0, float 1.000000e+00, 1 | ||
// release-NEXT: ret { float, float } %1 | ||
// release-NEXT: } | ||
|
||
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E | ||
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'") | ||
// debug-NEXT: start: | ||
// debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0 | ||
// debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0 | ||
// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4 | ||
// debug-NEXT: %_2 = load float, ptr %0, align 4 | ||
// debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1 | ||
// debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1 | ||
// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4 | ||
// debug-NEXT: %_5 = load float, ptr %1, align 4 | ||
// debug-NEXT: %_0 = fadd float %_2, %_5 | ||
// debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl" | ||
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 | ||
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 | ||
// debug-NEXT: ret { float, float } %4 | ||
// debug-NEXT: } | ||
|
||
// df2 | ||
// release: define internal fastcc { float, float } | ||
// release-SAME: @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E | ||
// release-SAME: (float noundef %x) | ||
// release-NEXT: invertstart: | ||
// release-NEXT: %_0.i = fmul float %x, %x | ||
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0.i, 0 | ||
// release-NEXT: %1 = insertvalue { float, float } %0, float 0.000000e+00, 1 | ||
// release-NEXT: ret { float, float } %1 | ||
// release-NEXT: } | ||
|
||
// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E | ||
// debug-SAME: (ptr %f, float %x, float %differeturn) | ||
// debug-NEXT: start: | ||
// debug-NEXT: %"x'de" = alloca float, align 4 | ||
// debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4 | ||
// debug-NEXT: %toreturn = alloca float, align 4 | ||
// debug-NEXT: %_0 = call float %f(float %x) | ||
// debug-NEXT: store float %_0, ptr %toreturn, align 4 | ||
// debug-NEXT: br label %invertstart | ||
// debug-EMPTY: | ||
// debug-NEXT: invertstart: ; preds = %start | ||
// debug-NEXT: %retreload = load float, ptr %toreturn, align 4 | ||
// debug-NEXT: %0 = load float, ptr %"x'de", align 4 | ||
// debug-NEXT: %1 = insertvalue { float, float } undef, float %retreload, 0 | ||
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1 | ||
// debug-NEXT: ret { float, float } %2 | ||
// debug-NEXT: } | ||
|
||
// df3 | ||
// release: define internal fastcc { float, float } | ||
// release-SAME: @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E | ||
// release-SAME: (float %x.0.val) | ||
// release-NEXT: start: | ||
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.val, 0 | ||
// release-NEXT: %1 = insertvalue { float, float } %0, float 0x40099999A0000000, 1 | ||
// release-NEXT: ret { float, float } %1 | ||
// release-NEXT: } | ||
|
||
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E | ||
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'") | ||
// debug-NEXT: start: | ||
// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4 | ||
// debug-NEXT: %_3 = load float, ptr %x, align 4 | ||
// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4 | ||
// debug-NEXT: %_4 = load float, ptr %y, align 4 | ||
// debug-NEXT: %_0 = fmul float %_3, %_4 | ||
// debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4 | ||
// debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3 | ||
// debug-NEXT: %2 = fadd fast float %0, %1 | ||
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 | ||
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 | ||
// debug-NEXT: ret { float, float } %4 | ||
// debug-NEXT: } | ||
|
||
// df4 | ||
// release: define internal fastcc { float, float } | ||
// release-SAME: @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE | ||
// release-SAME: (float noundef %x.0, float %"x.0'") | ||
// release-NEXT: start: | ||
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0, 0 | ||
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'", 1 | ||
// release-NEXT: ret { float, float } %1 | ||
// release-NEXT: } | ||
|
||
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE | ||
// debug-SAME: (float %x.0, float %"x.0'", float %x.1, float %"x.1'") | ||
// debug-NEXT: start: | ||
// debug-NEXT: %_0 = fmul float %x.0, %x.1 | ||
// debug-NEXT: %0 = fmul fast float %"x.0'", %x.1 | ||
// debug-NEXT: %1 = fmul fast float %"x.1'", %x.0 | ||
// debug-NEXT: %2 = fadd fast float %0, %1 | ||
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 | ||
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 | ||
// debug-NEXT: ret { float, float } %4 | ||
// debug-NEXT: } | ||
|
||
// df5 | ||
// release: define internal fastcc { float, float } | ||
// release-SAME: @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E | ||
// release-SAME: (float noundef %i.0, float %"i.0'") | ||
// release-NEXT: start: | ||
// release-NEXT: %_0 = fadd float %i.0, 1.000000e+00 | ||
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0 | ||
// release-NEXT: %1 = insertvalue { float, float } %0, float %"i.0'", 1 | ||
// release-NEXT: ret { float, float } %1 | ||
// release-NEXT: } | ||
|
||
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E | ||
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'") | ||
// debug-NEXT: start: | ||
// debug-NEXT: %_0 = fadd float %i.0, %i.1 | ||
// debug-NEXT: %0 = fadd fast float %"i.0'", %"i.1'" | ||
// debug-NEXT: %1 = insertvalue { float, float } undef, float %_0, 0 | ||
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1 | ||
// debug-NEXT: ret { float, float } %2 | ||
// debug-NEXT: } | ||
|
||
// df6 | ||
// release: define internal fastcc { float, float } | ||
// release-SAME: @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE | ||
// release-SAME: (float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'") | ||
// release-NEXT: start: | ||
// release-NEXT: %_3 = fmul float %i.1, %i.1 | ||
// release-NEXT: %0 = fadd fast float %"i.1'", %"i.1'" | ||
// release-NEXT: %1 = fmul fast float %0, %i.1 | ||
// release-NEXT: %_0 = fadd float %i.0, %_3 | ||
// release-NEXT: %2 = fadd fast float %"i.0'", %1 | ||
// release-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 | ||
// release-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 | ||
// release-NEXT: ret { float, float } %4 | ||
// release-NEXT: } | ||
|
||
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE | ||
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'") | ||
// debug-NEXT: start: | ||
// debug-NEXT: %_3 = fmul float %i.1, %i.1 | ||
// debug-NEXT: %0 = fmul fast float %"i.1'", %i.1 | ||
// debug-NEXT: %1 = fmul fast float %"i.1'", %i.1 | ||
// debug-NEXT: %2 = fadd fast float %0, %1 | ||
// debug-NEXT: %_0 = fadd float %i.0, %_3 | ||
// debug-NEXT: %3 = fadd fast float %"i.0'", %2 | ||
// debug-NEXT: %4 = insertvalue { float, float } undef, float %_0, 0 | ||
// debug-NEXT: %5 = insertvalue { float, float } %4, float %3, 1 | ||
// debug-NEXT: ret { float, float } %5 | ||
// debug-NEXT: } | ||
|
||
// df7 | ||
// release: define internal fastcc { float, float } | ||
// release-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E | ||
// release-SAME: (float %x.0.0.val, float %"x.0'.0.val") | ||
// release-NEXT: start: | ||
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.0.val, 0 | ||
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'.0.val", 1 | ||
// release-NEXT: ret { float, float } %1 | ||
// release-NEXT: } | ||
|
||
// debug: define internal { float, float } | ||
// debug-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E | ||
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %"x.0'", ptr align 4 %x.1, ptr align 4 %"x.1'") | ||
// debug-NEXT: start: | ||
// debug-NEXT: %0 = call fast { float, float } @"fwddiffe_ZN49_{{.*}}" | ||
// debug-NEXT: %1 = extractvalue { float, float } %0, 0 | ||
// debug-NEXT: %2 = extractvalue { float, float } %0, 1 | ||
// debug-NEXT: %3 = insertvalue { float, float } undef, float %1, 0 | ||
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 | ||
// debug-NEXT: ret { float, float } %4 | ||
// debug-NEXT: } | ||
|
||
// Calls every primal function `f1`..`f7` and its generated derivative | ||
// `df1`..`df7`, printing each result with `dbg!`. Inputs go through | ||
// `std::hint::black_box`. | ||
fn main() { | ||
let x = std::hint::black_box(2.0); | ||
let y = std::hint::black_box(3.0); | ||
let z = std::hint::black_box(4.0); | ||
// A `static` rather than a local; `f3` and `df3` receive a reference to it. | ||
static Y: f32 = std::hint::black_box(3.2); | ||
|
||
let in_f1 = [x, y]; | ||
dbg!(f1(&in_f1)); | ||
let res_f1 = df1(&in_f1, &[1.0, 0.0]); | ||
dbg!(res_f1); | ||
|
||
dbg!(f2(square, x)); | ||
let res_f2 = df2(square, x, 1.0); | ||
dbg!(res_f2); | ||
|
||
dbg!(f3(&x, &Y)); | ||
let res_f3 = df3(&x, &Y, &1.0, &0.0); | ||
dbg!(res_f3); | ||
|
||
let in_f4 = (x, y); | ||
dbg!(f4(in_f4)); | ||
let res_f4 = df4(in_f4, (1.0, 0.0)); | ||
dbg!(res_f4); | ||
|
||
let in_f5 = Input { x, y }; | ||
dbg!(f5(in_f5)); | ||
let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 }); | ||
dbg!(res_f5); | ||
|
||
let in_f6 = NestedInput { x, y: Wrapper { z: y } }; | ||
dbg!(f6(in_f6)); | ||
// NOTE(review): unlike the other calls, the second argument here is built | ||
// from `x` and `z` rather than literal 1.0/0.0 values; confirm this is | ||
// intentional. | ||
let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } }); | ||
dbg!(res_f6); | ||
|
||
let in_f7 = (&x, &y); | ||
dbg!(f7(in_f7)); | ||
let res_f7 = df7(in_f7, (&1.0, &0.0)); | ||
dbg!(res_f7); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document why we're adding an entry here.
Is it intended that both scalar pair entries have the same diff activity?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's intended. If we changed the activity of an individual field with respect to the original field for some reason, this could potentially affect the return function signature.