Skip to content

Commit de79c48

Browse files
committed
[InstCombine] Combine ptrauth constant callee into bundle.
Try to optimize a call to a ptrauth constant, into its ptrauth bundle: call(ptrauth(f)), ["ptrauth"()] -> call f as long as the key/discriminator are the same in constant and bundle.
1 parent 0f286f8 commit de79c48

File tree

3 files changed

+126
-0
lines changed

3 files changed

+126
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3665,6 +3665,34 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
36653665
return nullptr;
36663666
}
36673667

3668+
Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) {
3669+
auto *CPA = dyn_cast<ConstantPtrAuth>(Call.getCalledOperand());
3670+
if (!CPA)
3671+
return nullptr;
3672+
3673+
auto *CalleeF = dyn_cast<Function>(CPA->getPointer()->stripPointerCasts());
3674+
// If the ptrauth constant isn't based on a function pointer, bail out.
3675+
if (!CalleeF)
3676+
return nullptr;
3677+
3678+
// Inspect the call ptrauth bundle to check it matches the ptrauth constant.
3679+
auto PAB = Call.getOperandBundle(LLVMContext::OB_ptrauth);
3680+
if (!PAB)
3681+
return nullptr;
3682+
3683+
auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
3684+
Value *Discriminator = PAB->Inputs[1];
3685+
3686+
// If the bundle doesn't match, this is probably going to fail to auth.
3687+
if (!CPA->isKnownCompatibleWith(Key, Discriminator, DL))
3688+
return nullptr;
3689+
3690+
// If the bundle matches the constant, proceed in making this a direct call.
3691+
auto *NewCall = CallBase::removeOperandBundle(&Call, LLVMContext::OB_ptrauth);
3692+
NewCall->setCalledOperand(CalleeF);
3693+
return NewCall;
3694+
}
3695+
36683696
bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
36693697
const TargetLibraryInfo *TLI) {
36703698
// Note: We only handle cases which can't be driven from generic attributes
@@ -3812,6 +3840,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
38123840
if (IntrinsicInst *II = findInitTrampoline(Callee))
38133841
return transformCallThroughTrampoline(Call, *II);
38143842

3843+
// Combine calls to ptrauth constants.
3844+
if (Instruction *NewCall = foldPtrAuthConstantCallee(Call))
3845+
return NewCall;
3846+
38153847
if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
38163848
InlineAsm *IA = cast<InlineAsm>(Callee);
38173849
if (!IA->canThrow()) {

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
282282
Instruction *transformCallThroughTrampoline(CallBase &Call,
283283
IntrinsicInst &Tramp);
284284

285+
/// Try to optimize a call to a ptrauth constant, into its ptrauth bundle:
286+
/// call(ptrauth(f)), ["ptrauth"()] -> call f
287+
/// as long as the key/discriminator are the same in constant and bundle.
288+
Instruction *foldPtrAuthConstantCallee(CallBase &Call);
289+
285290
// Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
286291
// Otherwise, return std::nullopt
287292
// Currently it matches:
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
4+
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
5+
6+
declare i64 @f(i32)
7+
declare ptr @f2(i32)
8+
9+
define i32 @test_ptrauth_call(i32 %a0) {
10+
; CHECK-LABEL: @test_ptrauth_call(
11+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
12+
; CHECK-NEXT: ret i32 [[V0]]
13+
;
14+
%v0 = call i32 ptrauth(ptr @f, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
15+
ret i32 %v0
16+
}
17+
18+
define i32 @test_ptrauth_call_disc(i32 %a0) {
19+
; CHECK-LABEL: @test_ptrauth_call_disc(
20+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
21+
; CHECK-NEXT: ret i32 [[V0]]
22+
;
23+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 5678) ]
24+
ret i32 %v0
25+
}
26+
27+
@f_addr_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)
28+
29+
define i32 @test_ptrauth_call_addr_disc(i32 %a0) {
30+
; CHECK-LABEL: @test_ptrauth_call_addr_disc(
31+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
32+
; CHECK-NEXT: ret i32 [[V0]]
33+
;
34+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 ptrtoint (ptr @f_addr_disc.ref to i64)) ]
35+
ret i32 %v0
36+
}
37+
38+
@f_both_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)
39+
40+
define i32 @test_ptrauth_call_blend(i32 %a0) {
41+
; CHECK-LABEL: @test_ptrauth_call_blend(
42+
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
43+
; CHECK-NEXT: ret i32 [[V0]]
44+
;
45+
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 1234)
46+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
47+
ret i32 %v0
48+
}
49+
50+
define i64 @test_ptrauth_call_cast(i32 %a0) {
51+
; CHECK-LABEL: @test_ptrauth_call_cast(
52+
; CHECK-NEXT: [[V0:%.*]] = call ptr @f2(i32 [[A0:%.*]])
53+
; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V0]] to i64
54+
; CHECK-NEXT: ret i64 [[TMP1]]
55+
;
56+
%v0 = call i64 ptrauth(ptr @f2, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
57+
ret i64 %v0
58+
}
59+
60+
define i32 @test_ptrauth_call_mismatch_key(i32 %a0) {
61+
; CHECK-LABEL: @test_ptrauth_call_mismatch_key(
62+
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 0, i64 5678) ]
63+
; CHECK-NEXT: ret i32 [[V0]]
64+
;
65+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 0, i64 5678) ]
66+
ret i32 %v0
67+
}
68+
69+
define i32 @test_ptrauth_call_mismatch_disc(i32 %a0) {
70+
; CHECK-LABEL: @test_ptrauth_call_mismatch_disc(
71+
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 0) ]
72+
; CHECK-NEXT: ret i32 [[V0]]
73+
;
74+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 0) ]
75+
ret i32 %v0
76+
}
77+
78+
define i32 @test_ptrauth_call_mismatch_blend(i32 %a0) {
79+
; CHECK-LABEL: @test_ptrauth_call_mismatch_blend(
80+
; CHECK-NEXT: [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
81+
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ]
82+
; CHECK-NEXT: ret i32 [[V0]]
83+
;
84+
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
85+
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
86+
ret i32 %v0
87+
}
88+
89+
declare i64 @llvm.ptrauth.blend(i64, i64)

0 commit comments

Comments
 (0)