[AArch64][PAC] Select auth+load into LDRAA/LDRAB/LDRA[pre]. #123769

Open · wants to merge 1 commit into main
114 changes: 114 additions & 0 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "AArch64ExpandImm.h"
#include "AArch64MCInstLower.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64RegisterInfo.h"
@@ -204,6 +205,9 @@ class AArch64AsmPrinter : public AsmPrinter {
// authenticating)
void LowerLOADgotAUTH(const MachineInstr &MI);

// Emit the sequence for LDRA (auth + load from authenticated base).
void LowerPtrauthAuthLoad(const MachineInstr &MI);

/// tblgen'erated driver function for lowering simple MI->MC
/// pseudo instructions.
bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst);
@@ -2159,6 +2163,111 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
EmitToStreamer(*OutStreamer, BRInst);
}

void AArch64AsmPrinter::LowerPtrauthAuthLoad(const MachineInstr &MI) {
const bool IsPreWB = MI.getOpcode() == AArch64::LDRApre;

const unsigned DstReg = MI.getOperand(0).getReg();
Contributor:

[nit] Register type could be used instead of just unsigned, as suggested by clang-tidy.

Suggested change:
- const unsigned DstReg = MI.getOperand(0).getReg();
+ const Register DstReg = MI.getOperand(0).getReg();

const int64_t Offset = MI.getOperand(1).getImm();
const auto Key = (AArch64PACKey::ID)MI.getOperand(2).getImm();
const uint64_t Disc = MI.getOperand(3).getImm();
const unsigned AddrDisc = MI.getOperand(4).getReg();
Contributor:

Suggested change:
- const unsigned AddrDisc = MI.getOperand(4).getReg();
+ const Register AddrDisc = MI.getOperand(4).getReg();

Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, AArch64::X17);

unsigned AUTOpc = getAUTOpcodeForKey(Key, DiscReg == AArch64::XZR);
auto MIB = MCInstBuilder(AUTOpc).addReg(AArch64::X16).addReg(AArch64::X16);
if (DiscReg != AArch64::XZR)
MIB.addReg(DiscReg);

EmitToStreamer(MIB);

// We have a few options for offset folding:
// - writeback, 0 offset (we already wrote the AUT result): LDRXui
// - no wb, uimm12s8 offset (including 0): LDRXui
if (!Offset || (!IsPreWB && isShiftedUInt<12, 3>(Offset))) {
EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
.addReg(DstReg)
.addReg(AArch64::X16)
.addImm(Offset / 8));
return;
}

// - no wb, simm9 offset: LDURXi
if (!IsPreWB && isInt<9>(Offset)) {
EmitToStreamer(MCInstBuilder(AArch64::LDURXi)
.addReg(DstReg)
.addReg(AArch64::X16)
.addImm(Offset));
return;
}

// - pre-indexed wb, simm9 offset: LDRXpre
if (IsPreWB && isInt<9>(Offset)) {
EmitToStreamer(MCInstBuilder(AArch64::LDRXpre)
.addReg(AArch64::X16)
.addReg(DstReg)
.addReg(AArch64::X16)
.addImm(Offset));
return;
}

// Finally, in the general case, we need a MOVimm either way.
SmallVector<AArch64_IMM::ImmInsnModel, 4> ImmInsns;
AArch64_IMM::expandMOVImm(Offset, 64, ImmInsns);

// X17 is dead at this point, use it as the offset register
for (auto &ImmI : ImmInsns) {
Contributor:

Suggested change:
- for (auto &ImmI : ImmInsns) {
+ for (const auto &ImmI : ImmInsns) {

switch (ImmI.Opcode) {
default:
llvm_unreachable("invalid ldra imm expansion opc!");
break;

case AArch64::ORRXri:
EmitToStreamer(MCInstBuilder(ImmI.Opcode)
.addReg(AArch64::X17)
.addReg(AArch64::XZR)
.addImm(ImmI.Op2));
break;
case AArch64::MOVNXi:
case AArch64::MOVZXi:
EmitToStreamer(MCInstBuilder(ImmI.Opcode)
.addReg(AArch64::X17)
.addImm(ImmI.Op1)
.addImm(ImmI.Op2));
break;
case AArch64::MOVKXi:
EmitToStreamer(MCInstBuilder(ImmI.Opcode)
.addReg(AArch64::X17)
.addReg(AArch64::X17)
.addImm(ImmI.Op1)
.addImm(ImmI.Op2));
break;
}
}

// - no wb, any offset: expanded MOVImm + LDRXroX
if (!IsPreWB) {
EmitToStreamer(MCInstBuilder(AArch64::LDRXroX)
.addReg(DstReg)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addImm(0)
.addImm(0));
return;
}

// - pre-indexed wb, any offset: expanded MOVImm + ADD + LDRXui
EmitToStreamer(MCInstBuilder(AArch64::ADDXrs)
.addReg(AArch64::X16)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addImm(0));
EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
.addReg(DstReg)
.addReg(AArch64::X16)
.addImm(0));
}

const MCExpr *
AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
MCContext &Ctx = OutContext;
@@ -2698,6 +2807,11 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
LowerLOADgotAUTH(*MI);
return;

case AArch64::LDRA:
case AArch64::LDRApre:
LowerPtrauthAuthLoad(*MI);
return;

case AArch64::BRA:
case AArch64::BLRA:
emitPtrauthBranch(MI);
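As a reading aid for the offset-folding cases in LowerPtrauthAuthLoad above, here is a minimal standalone sketch (not part of the patch; the enum and helper name are made up for illustration) of which load form a given offset falls into, using the same llvm::isShiftedUInt / llvm::isInt predicates:

```cpp
#include "llvm/Support/MathExtras.h"
#include <cstdint>

// Illustrative only: mirrors the offset checks in LowerPtrauthAuthLoad above.
enum class AuthLoadForm { LDRXui, LDURXi, LDRXpre, MovImmSequence };

AuthLoadForm classifyAuthLoadOffset(int64_t Offset, bool IsPreWB) {
  // Zero offset (any form), or a non-writeback offset that is a multiple of 8
  // in [0, 32760]: unsigned scaled 64-bit load (LDRXui).
  if (!Offset || (!IsPreWB && llvm::isShiftedUInt<12, 3>(Offset)))
    return AuthLoadForm::LDRXui;
  // Non-writeback offset in [-256, 255]: unscaled load (LDURXi).
  if (!IsPreWB && llvm::isInt<9>(Offset))
    return AuthLoadForm::LDURXi;
  // Pre-indexed writeback offset in [-256, 255]: LDRXpre.
  if (IsPreWB && llvm::isInt<9>(Offset))
    return AuthLoadForm::LDRXpre;
  // Anything else needs the expanded MOV-immediate sequence into X17 first.
  return AuthLoadForm::MovImmSequence;
}
```

For example, classifyAuthLoadOffset(32768, false) exceeds both the scaled-unsigned and the simm9 ranges, so it falls through to the MOV-immediate expansion path.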
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64Combine.td
@@ -255,6 +255,14 @@ def form_truncstore : GICombineRule<
(apply [{ applyFormTruncstore(*${root}, MRI, B, Observer, ${matchinfo}); }])
>;

def form_auth_load_matchdata : GIDefMatchData<"AuthLoadMatchInfo">;
def form_auth_load : GICombineRule<
(defs root:$root, form_auth_load_matchdata:$matchinfo),
(match (wip_match_opcode G_LOAD):$root,
[{ return matchFormAuthLoad(*${root}, MRI, Helper, ${matchinfo}); }]),
(apply [{ applyFormAuthLoad(*${root}, MRI, B, Helper, Observer, ${matchinfo}); }])
>;

def fold_merge_to_zext : GICombineRule<
(defs root:$d),
(match (wip_match_opcode G_MERGE_VALUES):$d,
@@ -315,6 +323,7 @@ def AArch64PostLegalizerLowering
[shuffle_vector_lowering, vashr_vlshr_imm,
icmp_lowering, build_vector_lowering,
lower_vector_fcmp, form_truncstore,
form_auth_load,
vector_sext_inreg_to_shift,
unmerge_ext_to_unmerge, lower_mull,
vector_unmerge_lowering, insertelt_nonconst]> {
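For readers less familiar with GICombineRule, the form_auth_load rule above binds its match/apply fragments to C++ hooks in AArch64PostLegalizerLowering.cpp. A rough, hypothetical skeleton of the shape those hooks take (the real AuthLoadMatchInfo fields and function bodies live in the patch and may differ) is:

```cpp
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"

using namespace llvm;

// Placeholder match data; the fields here are illustrative only.
struct AuthLoadMatchInfo {
  MachineInstr *AuthInst = nullptr; // G_INTRINSIC(ptrauth_auth) feeding the load
  int64_t Offset = 0;               // constant offset folded from a G_PTR_ADD, if any
};

// Match hook: recognize a G_LOAD whose pointer operand comes (directly, or via
// a constant G_PTR_ADD) from llvm.ptrauth.auth, and record what was found.
bool matchFormAuthLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
                       CombinerHelper &Helper, AuthLoadMatchInfo &Info) {
  // Pattern recognition elided in this sketch.
  return false;
}

// Apply hook: rewrite the matched G_LOAD into G_LDRA / G_LDRApre.
void applyFormAuthLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
                       MachineIRBuilder &B, CombinerHelper &Helper,
                       GISelChangeObserver &Observer, AuthLoadMatchInfo &Info) {
  // Builds the target-specific auth-load and erases MI in the real code.
}
```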
165 changes: 163 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -361,6 +361,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {

bool tryIndexedLoad(SDNode *N);

bool tryAuthLoad(SDNode *N);

void SelectPtrauthAuth(SDNode *N);
void SelectPtrauthResign(SDNode *N);

@@ -1671,6 +1673,163 @@ bool AArch64DAGToDAGISel::tryIndexedLoad(SDNode *N) {
return true;
}

bool AArch64DAGToDAGISel::tryAuthLoad(SDNode *N) {
LoadSDNode *LD = cast<LoadSDNode>(N);
Contributor:

Nit:

Suggested change:
- LoadSDNode *LD = cast<LoadSDNode>(N);
+ const LoadSDNode *LD = cast<LoadSDNode>(N);

EVT VT = LD->getMemoryVT();
if (VT != MVT::i64)
return false;

assert(LD->getExtensionType() == ISD::NON_EXTLOAD && "invalid 64bit extload");

ISD::MemIndexedMode AM = LD->getAddressingMode();
if (AM != ISD::PRE_INC && AM != ISD::UNINDEXED)
return false;
bool IsPre = AM == ISD::PRE_INC;

SDValue Chain = LD->getChain();
SDValue Ptr = LD->getBasePtr();
Contributor:

I think the Ptr variable deserves a const modifier and a comment explaining that it is the value that is replaced by the written-back pre-incremented address. This would answer the question "When writing the address back merely to store it in authenticated form, how is the offset handled?": no adjustment is needed, the offset was already there.

SDValue Base = Ptr;

int64_t OffsetVal = 0;
if (IsPre) {
OffsetVal = cast<ConstantSDNode>(LD->getOffset())->getSExtValue();
} else if (CurDAG->isBaseWithConstantOffset(Base)) {
Contributor:
Maybe it would be better for readability to refer to Ptr instead of Base here?

// We support both 'base' and 'base + constant offset' modes.
ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Base.getOperand(1));
if (!RHS)
return false;
Comment on lines +1699 to +1701
Contributor:

IIUC isBaseWithConstantOffset(Base) being true ensures isa<ConstantSDNode>(Op.getOperand(1)).

Suggested change:
- ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Base.getOperand(1));
- if (!RHS)
-   return false;
+ ConstantSDNode *RHS = cast<ConstantSDNode>(Base.getOperand(1));

OffsetVal = RHS->getSExtValue();
Base = Base.getOperand(0);
}

// The base must be of the form:
// (int_ptrauth_auth <signedbase>, <key>, <disc>)
// with disc being either a constant int, or:
// (int_ptrauth_blend <addrdisc>, <const int disc>)
Comment on lines +1706 to +1709
Contributor:
Is disc being an opaque register (such as a non-blended address discriminator) supported as well?

if (Base.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
return false;

unsigned IntID = cast<ConstantSDNode>(Base.getOperand(0))->getZExtValue();
if (IntID != Intrinsic::ptrauth_auth)
return false;

unsigned KeyC = cast<ConstantSDNode>(Base.getOperand(2))->getZExtValue();
bool IsDKey = KeyC == AArch64PACKey::DA || KeyC == AArch64PACKey::DB;
SDValue Disc = Base.getOperand(3);

Base = Base.getOperand(1);
Contributor:

It would probably be easier to read this function if we define a SignedBase variable here (as per your code comment above) and use it from now on.

Suggested change:
- Base = Base.getOperand(1);
+ SDValue SignedBase = Base.getOperand(1);


bool ZeroDisc = isNullConstant(Disc);
SDValue IntDisc, AddrDisc;
std::tie(IntDisc, AddrDisc) = extractPtrauthBlendDiscriminators(Disc, CurDAG);

// If this is an indexed pre-inc load, we obviously need the writeback form.
bool needsWriteback = IsPre;
Contributor:

[nit] Incorrect capitalization of the variable name (LLVM style would be NeedsWriteback).

// If not, but the base authenticated pointer has any other use, it's
// beneficial to use the writeback form, to "writeback" the auth, even if
// there is no base+offset addition.
if (!Ptr.hasOneUse()) {
needsWriteback = true;

// However, we can only do that if we don't introduce cycles between the
// load node and any other user of the pointer computation nodes. That can
// happen if the load node uses any of said other users.
// In other words: we can only do this transformation if none of the other
// uses of the pointer computation to be folded are predecessors of the load
// we're folding into.
//
// Visited is a cache containing nodes that are known predecessors of N.
// Worklist is the set of nodes we're looking for predecessors of.
// For the first lookup, that only contains the load node N. Each call to
// hasPredecessorHelper adds any of the potential predecessors of N to the
// Worklist.
SmallPtrSet<const SDNode *, 32> Visited;
SmallVector<const SDNode *, 16> Worklist;
Worklist.push_back(N);
for (SDNode *U : Ptr.getNode()->users())
Contributor:

Nit:

Suggested change:
- for (SDNode *U : Ptr.getNode()->users())
+ for (const SDNode *U : Ptr.getNode()->users())

if (SDNode::hasPredecessorHelper(U, Visited, Worklist, /*Max=*/32,
/*TopologicalPrune=*/true))
return false;
}

// We have 2 main isel alternatives:
// - LDRAA/LDRAB, writeback or indexed. Zero disc, small offsets, D key.
// - LDRA/LDRApre. Pointer needs to be in X16.
SDLoc DL(N);
MachineSDNode *Res = nullptr;
SDValue Writeback, ResVal, OutChain;

// If the discriminator is zero and the offset fits, we can use LDRAA/LDRAB.
// Do that here to avoid needlessly constraining regalloc into using X16.
if (ZeroDisc && isShiftedInt<10, 3>(OffsetVal) && IsDKey) {
unsigned Opc = 0;
switch (KeyC) {
case AArch64PACKey::DA:
Opc = needsWriteback ? AArch64::LDRAAwriteback : AArch64::LDRAAindexed;
break;
case AArch64PACKey::DB:
Opc = needsWriteback ? AArch64::LDRABwriteback : AArch64::LDRABindexed;
break;
default:
llvm_unreachable("Invalid key for LDRAA/LDRAB");
}
// The offset is encoded as scaled, for an element size of 8 bytes.
SDValue Offset = CurDAG->getTargetConstant(OffsetVal / 8, DL, MVT::i64);
SDValue Ops[] = {Base, Offset, Chain};
Res = needsWriteback
? CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::i64, MVT::Other,
Ops)
: CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, Ops);
if (needsWriteback) {
Writeback = SDValue(Res, 0);
ResVal = SDValue(Res, 1);
OutChain = SDValue(Res, 2);
} else {
ResVal = SDValue(Res, 0);
OutChain = SDValue(Res, 1);
}
} else {
// Otherwise, use the generalized LDRA pseudos.
unsigned Opc = needsWriteback ? AArch64::LDRApre : AArch64::LDRA;

SDValue X16Copy =
CurDAG->getCopyToReg(Chain, DL, AArch64::X16, Base, SDValue());
SDValue Offset = CurDAG->getTargetConstant(OffsetVal, DL, MVT::i64);
SDValue Key = CurDAG->getTargetConstant(KeyC, DL, MVT::i32);
SDValue Ops[] = {Offset, Key, IntDisc, AddrDisc, X16Copy.getValue(1)};
Res = CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, MVT::Glue, Ops);
if (needsWriteback)
Writeback = CurDAG->getCopyFromReg(SDValue(Res, 1), DL, AArch64::X16,
MVT::i64, SDValue(Res, 2));
ResVal = SDValue(Res, 0);
OutChain = SDValue(Res, 1);
Comment on lines +1802 to +1806
Contributor:
Do I understand correctly: strictly speaking, OutChain should be SDValue(Writeback, 1) if needsWriteback is set, but Writeback is scheduled "immediately after LDRApre" and all other successors are simply "after LDRApre", so everything is scheduled correctly without further complicating the computations of OutChain?

}

if (IsPre) {
// If the original load was pre-inc, the resulting LDRA is writeback.
assert(needsWriteback && "preinc loads can't be selected into non-wb ldra");
ReplaceUses(SDValue(N, 1), Writeback); // writeback
ReplaceUses(SDValue(N, 0), ResVal); // loaded value
ReplaceUses(SDValue(N, 2), OutChain); // chain
} else if (needsWriteback) {
// If the original load was unindexed, but we emitted a writeback form,
// we need to replace the uses of the original auth(signedbase)[+offset]
// computation.
ReplaceUses(Ptr, Writeback); // writeback
ReplaceUses(SDValue(N, 0), ResVal); // loaded value
ReplaceUses(SDValue(N, 1), OutChain); // chain
} else {
// Otherwise, we selected a simple load to a simple non-wb ldra.
assert(Ptr.hasOneUse() && "reused auth ptr should be folded into ldra");
ReplaceUses(SDValue(N, 0), ResVal); // loaded value
ReplaceUses(SDValue(N, 1), OutChain); // chain
}

CurDAG->RemoveDeadNode(N);
return true;
}

void AArch64DAGToDAGISel::SelectLoad(SDNode *N, unsigned NumVecs, unsigned Opc,
unsigned SubRegIdx) {
SDLoc dl(N);
@@ -4643,8 +4802,10 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
break;

case ISD::LOAD: {
// Try to select as an indexed load. Fall through to normal processing
// if we can't.
// Try to select as an indexed or authenticating load. Fall through to
// normal processing if we can't.
if (tryAuthLoad(Node))
return;
if (tryIndexedLoad(Node))
return;
break;
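To illustrate the kind of source pattern tryAuthLoad targets, here is a hypothetical example (not taken from the patch's tests, and assuming clang's <ptrauth.h> and pointer-authentication builtins are available for the target):

```cpp
#include <ptrauth.h>
#include <cstdint>

// Authenticate a signed pointer with the DA key and a zero discriminator,
// then load through it at a small constant offset. With this patch, the
// separate AUTDA + LDR pair is expected to fold into a single auth-load.
uint64_t load_through_signed_ptr(void *SignedPtr) {
  auto *P = static_cast<uint64_t *>(
      __builtin_ptrauth_auth(SignedPtr, ptrauth_key_asda, 0));
  return P[1]; // authenticated base plus constant offset 8
}
```

Here the zero discriminator, D key, and isShiftedInt<10, 3> offset satisfy the LDRAA/LDRAB conditions, so the expected selection is LDRAA rather than the X16-constrained LDRA pseudo.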
17 changes: 17 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -25,6 +25,23 @@ def G_ADD_LOW : AArch64GenericInstruction {
let hasSideEffects = 0;
}

// Represents an auth-load instruction. Produced post-legalization from
// G_LOADs of ptrauth_auth intrinsics, with variants for keys/discriminators.
def G_LDRA : AArch64GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
let hasSideEffects = 0;
Contributor:

Spotted that in GISel, G_LDRA and G_LDRApre both set hasSideEffects to false. In DAGISel, on the other hand, both LDRA and LDRApre set hasSideEffects to true. Is that intended?

let mayLoad = 1;
}

// Represents a pre-inc writeback auth-load instruction. Similar to G_LDRA.
def G_LDRApre : AArch64GenericInstruction {
let OutOperandList = (outs type0:$dst, ptype1:$newaddr);
let InOperandList = (ins ptype1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
let hasSideEffects = 0;
let mayLoad = 1;
}

// Pseudo for a rev16 instruction. Produced post-legalization from
// G_SHUFFLE_VECTORs with appropriate masks.
def G_REV16 : AArch64GenericInstruction {