diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 436cdbff75669..069f638cd0e45 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -3665,6 +3665,34 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { return nullptr; } +Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) { + auto *CPA = dyn_cast(Call.getCalledOperand()); + if (!CPA) + return nullptr; + + auto *CalleeF = dyn_cast(CPA->getPointer()); + // If the ptrauth constant isn't based on a function pointer, bail out. + if (!CalleeF) + return nullptr; + + // Inspect the call ptrauth bundle to check it matches the ptrauth constant. + auto PAB = Call.getOperandBundle(LLVMContext::OB_ptrauth); + if (!PAB) + return nullptr; + + auto *Key = cast(PAB->Inputs[0]); + Value *Discriminator = PAB->Inputs[1]; + + // If the bundle doesn't match, this is probably going to fail to auth. + if (!CPA->isKnownCompatibleWith(Key, Discriminator, DL)) + return nullptr; + + // If the bundle matches the constant, proceed in making this a direct call. + auto *NewCall = CallBase::removeOperandBundle(&Call, LLVMContext::OB_ptrauth); + NewCall->setCalledOperand(CalleeF); + return NewCall; +} + bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { // Note: We only handle cases which can't be driven from generic attributes @@ -3812,6 +3840,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { if (IntrinsicInst *II = findInitTrampoline(Callee)) return transformCallThroughTrampoline(Call, *II); + // Combine calls to ptrauth constants. + if (Instruction *NewCall = foldPtrAuthConstantCallee(Call)) + return NewCall; + if (isa(Callee) && !Call.doesNotThrow()) { InlineAsm *IA = cast(Callee); if (!IA->canThrow()) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 984f02bcccad7..9268cbe594d90 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -282,6 +282,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *transformCallThroughTrampoline(CallBase &Call, IntrinsicInst &Tramp); + /// Try to optimize a call to a ptrauth constant, into its ptrauth bundle: + /// call(ptrauth(f)), ["ptrauth"()] -> call f + /// as long as the key/discriminator are the same in constant and bundle. + Instruction *foldPtrAuthConstantCallee(CallBase &Call); + // Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a). // Otherwise, return std::nullopt // Currently it matches: diff --git a/llvm/test/Transforms/InstCombine/ptrauth-call.ll b/llvm/test/Transforms/InstCombine/ptrauth-call.ll new file mode 100644 index 0000000000000..b4363b528d4e2 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/ptrauth-call.ll @@ -0,0 +1,89 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128" + +declare i64 @f(i32) +declare ptr @f2(i32) + +define i32 @test_ptrauth_call(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call( +; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]]) +; CHECK-NEXT: ret i32 [[V0]] +; + %v0 = call i32 ptrauth(ptr @f, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ] + ret i32 %v0 +} + +define i32 @test_ptrauth_call_disc(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call_disc( +; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]]) +; CHECK-NEXT: ret i32 [[V0]] +; + %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 5678) ] + ret i32 %v0 +} + +@f_addr_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref) + +define i32 @test_ptrauth_call_addr_disc(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call_addr_disc( +; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]]) +; CHECK-NEXT: ret i32 [[V0]] +; + %v0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 ptrtoint (ptr @f_addr_disc.ref to i64)) ] + ret i32 %v0 +} + +@f_both_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref) + +define i32 @test_ptrauth_call_blend(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call_blend( +; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]]) +; CHECK-NEXT: ret i32 [[V0]] +; + %v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 1234) + %v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ] + ret i32 %v0 +} + +define i64 @test_ptrauth_call_cast(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call_cast( +; CHECK-NEXT: [[V0:%.*]] = call ptr @f2(i32 [[A0:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V0]] to i64 +; CHECK-NEXT: ret i64 [[TMP1]] +; + %v0 = call i64 ptrauth(ptr @f2, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ] + ret i64 %v0 +} + +define i32 @test_ptrauth_call_mismatch_key(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call_mismatch_key( +; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 0, i64 5678) ] +; CHECK-NEXT: ret i32 [[V0]] +; + %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 0, i64 5678) ] + ret i32 %v0 +} + +define i32 @test_ptrauth_call_mismatch_disc(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call_mismatch_disc( +; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 0) ] +; CHECK-NEXT: ret i32 [[V0]] +; + %v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 0) ] + ret i32 %v0 +} + +define i32 @test_ptrauth_call_mismatch_blend(i32 %a0) { +; CHECK-LABEL: @test_ptrauth_call_mismatch_blend( +; CHECK-NEXT: [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0) +; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ] +; CHECK-NEXT: ret i32 [[V0]] +; + %v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0) + %v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ] + ret i32 %v0 +} + +declare i64 @llvm.ptrauth.blend(i64, i64)