[AArch64][PAC] Select auth+load into LDRAA/LDRAB/LDRA[pre]. #123769

Open

ahmedbougacha wants to merge 1 commit into main

Conversation

ahmedbougacha
Member

This can lower loads of a ptrauth.auth base into a fixed sequence that doesn't allow the raw intermediate value to be exposed.

It's based on the AArch64 LDRAA/LDRAB instructions, but as those have limited encodings (in particular, small immediate offsets, and only zero discriminators), it generalizes them with a LDRA pseudo.

It handles arbitrary ptrauth schemas on the authentication, materializing the integer constant discriminator and blending it with an address discriminator if needed.

It handles arbitrary offsets (applied after the authentication).

It also handles pre-indexing with writeback, either writing back the authentication result alone if the offset is 0, or both authentication and offset addition otherwise.
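
For illustration, here is a minimal IR sketch (hypothetical function and values, chosen only to exercise the pattern) of the shape this targets: a load whose base is the result of ptrauth.auth, with a blended discriminator and a fixed offset applied after authentication.

  define i64 @auth_load(i64 %signed.base, i64 %addr.disc) {
    ; Blend an address discriminator with a small constant discriminator.
    %disc = call i64 @llvm.ptrauth.blend(i64 %addr.disc, i64 1234)
    ; Authenticate the signed base with the DA key (key 2).
    %auth = call i64 @llvm.ptrauth.auth(i64 %signed.base, i32 2, i64 %disc)
    %ptr = inttoptr i64 %auth to ptr
    ; The offset is applied to the authenticated value, never to the raw one.
    %elt = getelementptr i8, ptr %ptr, i64 48
    %val = load i64, ptr %elt
    ret i64 %val
  }

  declare i64 @llvm.ptrauth.blend(i64, i64)
  declare i64 @llvm.ptrauth.auth(i64, i32, i64)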

At ISel time, the real LDRAA family of instructions is selected when possible, to avoid needlessly constraining regalloc with X16/X17. After ISel, the LDRA pseudos are expanded in AsmPrinter into one of the following (a rough sketch of one such expansion is shown after the list):

  • writeback, 0 offset (we already wrote the AUT result): LDRXui
  • no wb, uimm12s8 offset (including 0): LDRXui
  • no wb, simm9 offset: LDURXi
  • pre-indexed wb, simm9 offset: LDRXpre
  • no wb, any offset: expanded MOVImm + LDRXroX
  • pre-indexed wb, any offset: expanded MOVImm + ADD + LDRXui
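
As a rough, hypothetical sketch of one of these expansions (function name and exact registers are illustrative only): a constant, non-zero discriminator rules out LDRAA/LDRAB, so the LDRA pseudo is used and expanded in AsmPrinter roughly as the comments below show.

  define i64 @ldra_const_disc(i64 %signed.base) {
  ; Roughly, for arm64e (DA key, constant discriminator 1234, offset 16):
  ;   mov   x16, x0          ; the signed base is pinned to x16 for the pseudo
  ;   mov   x17, #1234       ; materialize the constant discriminator
  ;   autda x16, x17         ; authenticate in place
  ;   ldr   x0, [x16, #16]   ; offset fits uimm12s8, so a plain LDRXui
    %auth = call i64 @llvm.ptrauth.auth(i64 %signed.base, i32 2, i64 1234)
    %ptr = inttoptr i64 %auth to ptr
    %elt = getelementptr i8, ptr %ptr, i64 16
    %val = load i64, ptr %elt
    ret i64 %val
  }

  declare i64 @llvm.ptrauth.auth(i64, i32, i64)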

The main intended optimization target, though, is vtable-like codegen, where the base vtable pointer is signed, and so are its entries, at small fixed offsets. That case does benefit from writeback, hence the ISel complexity to support it; it's otherwise unlikely to be worthwhile.
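
To make the writeback benefit concrete, here is a hedged sketch (hypothetical function; a real arm64e vtable sequence would typically also authenticate the loaded entry, using the slot address as its discriminator). The point is that the authenticated-and-offset address has a second use, which the pre-indexed writeback form keeps in a register for free.

  define i64 @vtable_like(ptr %obj) {
    ; Load the signed vtable pointer and authenticate it (zero discriminator
    ; here for brevity; with a D key and a small offset this can select
    ; LDRAAwriteback/LDRABwriteback directly).
    %vt.signed = load i64, ptr %obj
    %vt = call i64 @llvm.ptrauth.auth(i64 %vt.signed, i32 2, i64 0)
    %vt.ptr = inttoptr i64 %vt to ptr
    ; The entry sits at a small fixed offset; its slot address is also needed
    ; afterwards, e.g. as the discriminator when checking the loaded entry.
    %slot = getelementptr i8, ptr %vt.ptr, i64 16
    %entry = load i64, ptr %slot
    %slot.addr = ptrtoint ptr %slot to i64
    %use = add i64 %entry, %slot.addr
    ret i64 %use
  }

  declare i64 @llvm.ptrauth.auth(i64, i32, i64)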

GlobalISel would benefit from further optimization, as this lowering conflicts with the generic indexed lowering there.

I did a pass to refresh the old patch, but it's been a while, so please let me know if I missed a spot!

@llvmbot
Member

llvmbot commented Jan 21, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Ahmed Bougacha (ahmedbougacha)

Changes

Patch is 47.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123769.diff

8 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+114)
  • (modified) llvm/lib/Target/AArch64/AArch64Combine.td (+9)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp (+163-2)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrGISel.td (+17)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+36)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+71)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp (+115)
  • (added) llvm/test/CodeGen/AArch64/ptrauth-load.ll (+716)
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 27e65d60122fd7..b3876ff4862e1d 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "AArch64.h"
+#include "AArch64ExpandImm.h"
 #include "AArch64MCInstLower.h"
 #include "AArch64MachineFunctionInfo.h"
 #include "AArch64RegisterInfo.h"
@@ -204,6 +205,9 @@ class AArch64AsmPrinter : public AsmPrinter {
   // authenticating)
   void LowerLOADgotAUTH(const MachineInstr &MI);
 
+  // Emit the sequence for LDRA (auth + load from authenticated base).
+  void LowerPtrauthAuthLoad(const MachineInstr &MI);
+
   /// tblgen'erated driver function for lowering simple MI->MC
   /// pseudo instructions.
   bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst);
@@ -2159,6 +2163,111 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
   EmitToStreamer(*OutStreamer, BRInst);
 }
 
+void AArch64AsmPrinter::LowerPtrauthAuthLoad(const MachineInstr &MI) {
+  const bool IsPreWB = MI.getOpcode() == AArch64::LDRApre;
+
+  const unsigned DstReg = MI.getOperand(0).getReg();
+  const int64_t Offset = MI.getOperand(1).getImm();
+  const auto Key = (AArch64PACKey::ID)MI.getOperand(2).getImm();
+  const uint64_t Disc = MI.getOperand(3).getImm();
+  const unsigned AddrDisc = MI.getOperand(4).getReg();
+
+  Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, AArch64::X17);
+
+  unsigned AUTOpc = getAUTOpcodeForKey(Key, DiscReg == AArch64::XZR);
+  auto MIB = MCInstBuilder(AUTOpc).addReg(AArch64::X16).addReg(AArch64::X16);
+  if (DiscReg != AArch64::XZR)
+    MIB.addReg(DiscReg);
+
+  EmitToStreamer(MIB);
+
+  // We have a few options for offset folding:
+  // - writeback, 0 offset (we already wrote the AUT result): LDRXui
+  // - no wb, uimm12s8 offset (including 0): LDRXui
+  if (!Offset || (!IsPreWB && isShiftedUInt<12, 3>(Offset))) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addImm(Offset / 8));
+    return;
+  }
+
+  // - no wb, simm9 offset: LDURXi
+  if (!IsPreWB && isInt<9>(Offset)) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDURXi)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addImm(Offset));
+    return;
+  }
+
+  // - pre-indexed wb, simm9 offset: LDRXpre
+  if (IsPreWB && isInt<9>(Offset)) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDRXpre)
+                       .addReg(AArch64::X16)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addImm(Offset));
+    return;
+  }
+
+  // Finally, in the general case, we need a MOVimm either way.
+  SmallVector<AArch64_IMM::ImmInsnModel, 4> ImmInsns;
+  AArch64_IMM::expandMOVImm(Offset, 64, ImmInsns);
+
+  // X17 is dead at this point, use it as the offset register
+  for (auto &ImmI : ImmInsns) {
+    switch (ImmI.Opcode) {
+    default:
+      llvm_unreachable("invalid ldra imm expansion opc!");
+      break;
+
+    case AArch64::ORRXri:
+      EmitToStreamer(MCInstBuilder(ImmI.Opcode)
+                         .addReg(AArch64::X17)
+                         .addReg(AArch64::XZR)
+                         .addImm(ImmI.Op2));
+      break;
+    case AArch64::MOVNXi:
+    case AArch64::MOVZXi:
+      EmitToStreamer(MCInstBuilder(ImmI.Opcode)
+                         .addReg(AArch64::X17)
+                         .addImm(ImmI.Op1)
+                         .addImm(ImmI.Op2));
+      break;
+    case AArch64::MOVKXi:
+      EmitToStreamer(MCInstBuilder(ImmI.Opcode)
+                         .addReg(AArch64::X17)
+                         .addReg(AArch64::X17)
+                         .addImm(ImmI.Op1)
+                         .addImm(ImmI.Op2));
+      break;
+    }
+  }
+
+  // - no wb, any offset: expanded MOVImm + LDRXroX
+  if (!IsPreWB) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDRXroX)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addReg(AArch64::X17)
+                       .addImm(0)
+                       .addImm(0));
+    return;
+  }
+
+  // - pre-indexed wb, any offset: expanded MOVImm + ADD + LDRXui
+  EmitToStreamer(MCInstBuilder(AArch64::ADDXrs)
+                     .addReg(AArch64::X16)
+                     .addReg(AArch64::X16)
+                     .addReg(AArch64::X17)
+                     .addImm(0));
+  EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
+                     .addReg(DstReg)
+                     .addReg(AArch64::X16)
+                     .addImm(0));
+}
+
 const MCExpr *
 AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
   MCContext &Ctx = OutContext;
@@ -2698,6 +2807,11 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     LowerLOADgotAUTH(*MI);
     return;
 
+  case AArch64::LDRA:
+  case AArch64::LDRApre:
+    LowerPtrauthAuthLoad(*MI);
+    return;
+
   case AArch64::BRA:
   case AArch64::BLRA:
     emitPtrauthBranch(MI);
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index ce1980697abbbb..f1853f2b8cc5ad 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -255,6 +255,14 @@ def form_truncstore : GICombineRule<
   (apply [{ applyFormTruncstore(*${root}, MRI, B, Observer, ${matchinfo}); }])
 >;
 
+def form_auth_load_matchdata : GIDefMatchData<"AuthLoadMatchInfo">;
+def form_auth_load : GICombineRule<
+  (defs root:$root, form_auth_load_matchdata:$matchinfo),
+  (match (wip_match_opcode G_LOAD):$root,
+         [{ return matchFormAuthLoad(*${root}, MRI, Helper, ${matchinfo}); }]),
+  (apply [{ applyFormAuthLoad(*${root}, MRI, B, Helper, Observer, ${matchinfo}); }])
+>;
+
 def fold_merge_to_zext : GICombineRule<
   (defs root:$d),
   (match (wip_match_opcode G_MERGE_VALUES):$d,
@@ -315,6 +323,7 @@ def AArch64PostLegalizerLowering
                        [shuffle_vector_lowering, vashr_vlshr_imm,
                         icmp_lowering, build_vector_lowering,
                         lower_vector_fcmp, form_truncstore,
+                        form_auth_load,
                         vector_sext_inreg_to_shift,
                         unmerge_ext_to_unmerge, lower_mull,
                         vector_unmerge_lowering, insertelt_nonconst]> {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 6aa8cd4f0232ac..8660b2d0bc8e6f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -361,6 +361,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
 
   bool tryIndexedLoad(SDNode *N);
 
+  bool tryAuthLoad(SDNode *N);
+
   void SelectPtrauthAuth(SDNode *N);
   void SelectPtrauthResign(SDNode *N);
 
@@ -1671,6 +1673,163 @@ bool AArch64DAGToDAGISel::tryIndexedLoad(SDNode *N) {
   return true;
 }
 
+bool AArch64DAGToDAGISel::tryAuthLoad(SDNode *N) {
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  EVT VT = LD->getMemoryVT();
+  if (VT != MVT::i64)
+    return false;
+
+  assert(LD->getExtensionType() == ISD::NON_EXTLOAD && "invalid 64bit extload");
+
+  ISD::MemIndexedMode AM = LD->getAddressingMode();
+  if (AM != ISD::PRE_INC && AM != ISD::UNINDEXED)
+    return false;
+  bool IsPre = AM == ISD::PRE_INC;
+
+  SDValue Chain = LD->getChain();
+  SDValue Ptr = LD->getBasePtr();
+
+  SDValue Base = Ptr;
+
+  int64_t OffsetVal = 0;
+  if (IsPre) {
+    OffsetVal = cast<ConstantSDNode>(LD->getOffset())->getSExtValue();
+  } else if (CurDAG->isBaseWithConstantOffset(Base)) {
+    // We support both 'base' and 'base + constant offset' modes.
+    ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Base.getOperand(1));
+    if (!RHS)
+      return false;
+    OffsetVal = RHS->getSExtValue();
+    Base = Base.getOperand(0);
+  }
+
+  // The base must be of the form:
+  //   (int_ptrauth_auth <signedbase>, <key>, <disc>)
+  // with disc being either a constant int, or:
+  //   (int_ptrauth_blend <addrdisc>, <const int disc>)
+  if (Base.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return false;
+
+  unsigned IntID = cast<ConstantSDNode>(Base.getOperand(0))->getZExtValue();
+  if (IntID != Intrinsic::ptrauth_auth)
+    return false;
+
+  unsigned KeyC = cast<ConstantSDNode>(Base.getOperand(2))->getZExtValue();
+  bool IsDKey = KeyC == AArch64PACKey::DA || KeyC == AArch64PACKey::DB;
+  SDValue Disc = Base.getOperand(3);
+
+  Base = Base.getOperand(1);
+
+  bool ZeroDisc = isNullConstant(Disc);
+  SDValue IntDisc, AddrDisc;
+  std::tie(IntDisc, AddrDisc) = extractPtrauthBlendDiscriminators(Disc, CurDAG);
+
+  // If this is an indexed pre-inc load, we obviously need the writeback form.
+  bool needsWriteback = IsPre;
+  // If not, but the base authenticated pointer has any other use, it's
+  // beneficial to use the writeback form, to "writeback" the auth, even if
+  // there is no base+offset addition.
+  if (!Ptr.hasOneUse()) {
+    needsWriteback = true;
+
+    // However, we can only do that if we don't introduce cycles between the
+    // load node and any other user of the pointer computation nodes.  That can
+    // happen if the load node uses any of said other users.
+    // In other words: we can only do this transformation if none of the other
+    // uses of the pointer computation to be folded are predecessors of the load
+    // we're folding into.
+    //
+    // Visited is a cache containing nodes that are known predecessors of N.
+    // Worklist is the set of nodes we're looking for predecessors of.
+    // For the first lookup, that only contains the load node N.  Each call to
+    // hasPredecessorHelper adds any of the potential predecessors of N to the
+    // Worklist.
+    SmallPtrSet<const SDNode *, 32> Visited;
+    SmallVector<const SDNode *, 16> Worklist;
+    Worklist.push_back(N);
+    for (SDNode *U : Ptr.getNode()->users())
+      if (SDNode::hasPredecessorHelper(U, Visited, Worklist, /*Max=*/32,
+                                       /*TopologicalPrune=*/true))
+        return false;
+  }
+
+  // We have 2 main isel alternatives:
+  // - LDRAA/LDRAB, writeback or indexed.  Zero disc, small offsets, D key.
+  // - LDRA/LDRApre.  Pointer needs to be in X16.
+  SDLoc DL(N);
+  MachineSDNode *Res = nullptr;
+  SDValue Writeback, ResVal, OutChain;
+
+  // If the discriminator is zero and the offset fits, we can use LDRAA/LDRAB.
+  // Do that here to avoid needlessly constraining regalloc into using X16.
+  if (ZeroDisc && isShiftedInt<10, 3>(OffsetVal) && IsDKey) {
+    unsigned Opc = 0;
+    switch (KeyC) {
+    case AArch64PACKey::DA:
+      Opc = needsWriteback ? AArch64::LDRAAwriteback : AArch64::LDRAAindexed;
+      break;
+    case AArch64PACKey::DB:
+      Opc = needsWriteback ? AArch64::LDRABwriteback : AArch64::LDRABindexed;
+      break;
+    default:
+      llvm_unreachable("Invalid key for LDRAA/LDRAB");
+    }
+    // The offset is encoded as scaled, for an element size of 8 bytes.
+    SDValue Offset = CurDAG->getTargetConstant(OffsetVal / 8, DL, MVT::i64);
+    SDValue Ops[] = {Base, Offset, Chain};
+    Res = needsWriteback
+              ? CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::i64, MVT::Other,
+                                       Ops)
+              : CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, Ops);
+    if (needsWriteback) {
+      Writeback = SDValue(Res, 0);
+      ResVal = SDValue(Res, 1);
+      OutChain = SDValue(Res, 2);
+    } else {
+      ResVal = SDValue(Res, 0);
+      OutChain = SDValue(Res, 1);
+    }
+  } else {
+    // Otherwise, use the generalized LDRA pseudos.
+    unsigned Opc = needsWriteback ? AArch64::LDRApre : AArch64::LDRA;
+
+    SDValue X16Copy =
+        CurDAG->getCopyToReg(Chain, DL, AArch64::X16, Base, SDValue());
+    SDValue Offset = CurDAG->getTargetConstant(OffsetVal, DL, MVT::i64);
+    SDValue Key = CurDAG->getTargetConstant(KeyC, DL, MVT::i32);
+    SDValue Ops[] = {Offset, Key, IntDisc, AddrDisc, X16Copy.getValue(1)};
+    Res = CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, MVT::Glue, Ops);
+    if (needsWriteback)
+      Writeback = CurDAG->getCopyFromReg(SDValue(Res, 1), DL, AArch64::X16,
+                                         MVT::i64, SDValue(Res, 2));
+    ResVal = SDValue(Res, 0);
+    OutChain = SDValue(Res, 1);
+  }
+
+  if (IsPre) {
+    // If the original load was pre-inc, the resulting LDRA is writeback.
+    assert(needsWriteback && "preinc loads can't be selected into non-wb ldra");
+    ReplaceUses(SDValue(N, 1), Writeback); // writeback
+    ReplaceUses(SDValue(N, 0), ResVal);    // loaded value
+    ReplaceUses(SDValue(N, 2), OutChain);  // chain
+  } else if (needsWriteback) {
+    // If the original load was unindexed, but we emitted a writeback form,
+    // we need to replace the uses of the original auth(signedbase)[+offset]
+    // computation.
+    ReplaceUses(Ptr, Writeback);          // writeback
+    ReplaceUses(SDValue(N, 0), ResVal);   // loaded value
+    ReplaceUses(SDValue(N, 1), OutChain); // chain
+  } else {
+    // Otherwise, we selected a simple load to a simple non-wb ldra.
+    assert(Ptr.hasOneUse() && "reused auth ptr should be folded into ldra");
+    ReplaceUses(SDValue(N, 0), ResVal);   // loaded value
+    ReplaceUses(SDValue(N, 1), OutChain); // chain
+  }
+
+  CurDAG->RemoveDeadNode(N);
+  return true;
+}
+
 void AArch64DAGToDAGISel::SelectLoad(SDNode *N, unsigned NumVecs, unsigned Opc,
                                      unsigned SubRegIdx) {
   SDLoc dl(N);
@@ -4643,8 +4802,10 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
     break;
 
   case ISD::LOAD: {
-    // Try to select as an indexed load. Fall through to normal processing
-    // if we can't.
+    // Try to select as an indexed or authenticating load. Fall through to
+    // normal processing if we can't.
+    if (tryAuthLoad(Node))
+      return;
     if (tryIndexedLoad(Node))
       return;
     break;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
index 2d2b2bee99ec41..1b544c4f8c19a6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -25,6 +25,23 @@ def G_ADD_LOW : AArch64GenericInstruction {
   let hasSideEffects = 0;
 }
 
+// Represents an auth-load instruction.  Produced post-legalization from
+// G_LOADs of ptrauth_auth intrinsics, with variants for keys/discriminators.
+def G_LDRA : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
+  let hasSideEffects = 0;
+  let mayLoad = 1;
+}
+
+// Represents a pre-inc writeback auth-load instruction.  Similar to G_LDRA.
+def G_LDRApre : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst, ptype1:$newaddr);
+  let InOperandList = (ins ptype1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
+  let hasSideEffects = 0;
+  let mayLoad = 1;
+}
+
 // Pseudo for a rev16 instruction. Produced post-legalization from
 // G_SHUFFLE_VECTORs with appropriate masks.
 def G_REV16 : AArch64GenericInstruction {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 8e575abf83d449..44be2fe00f0aa7 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1973,6 +1973,42 @@ let Predicates = [HasPAuth] in {
     let Size = 8;
   }
 
+  // LDRA pseudo: generalized LDRAA/Bindexed, allowing arbitrary discriminators,
+  // and wider offsets.
+  // This directly manipulates x16/x17, which are the only registers the OS
+  // guarantees are safe to use for sensitive operations.
+  // The loaded value is in $Rt.  The signed pointer is in X16.
+  // $Rt could be GPR64 but is GPR64noip to help out regalloc: we imp-def 2/3rds
+  // of the difference between the two, and the 3rd reg (LR) is often reserved.
+  def LDRA : Pseudo<(outs GPR64noip:$Rt),
+                    (ins i64imm:$Offset, i32imm:$Key, i64imm:$Disc,
+                         GPR64noip:$AddrDisc),
+                    []>, Sched<[]> {
+    let isCodeGenOnly = 1;
+    let hasSideEffects = 1;
+    let mayStore = 0;
+    let mayLoad = 1;
+    let Size = 48;
+    let Defs = [X16,X17];
+    let Uses = [X16];
+  }
+
+  // Pre-indexed + writeback variant of LDRA.
+  // The signed pointer is in X16, and is written back, after being
+  // authenticated and offset, into X16.
+  def LDRApre : Pseudo<(outs GPR64noip:$Rt),
+                       (ins i64imm:$Offset, i32imm:$Key, i64imm:$Disc,
+                            GPR64noip:$AddrDisc),
+                    []>, Sched<[]> {
+    let isCodeGenOnly = 1;
+    let hasSideEffects = 1;
+    let mayStore = 0;
+    let mayLoad = 1;
+    let Size = 48;
+    let Defs = [X16,X17];
+    let Uses = [X16];
+  }
+
   // Size 16: 4 fixed + 8 variable, to compute discriminator.
   // The size returned by getInstSizeInBytes() is incremented according
   // to the variant of LR check.
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 07f03644336cdd..87158df0b75c2a 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -225,6 +225,7 @@ class AArch64InstructionSelector : public InstructionSelector {
   bool selectTLSGlobalValue(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectPtrAuthGlobalValue(MachineInstr &I,
                                 MachineRegisterInfo &MRI) const;
+  bool selectAuthLoad(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectReduction(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectMOPS(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectUSMovFromExtend(MachineInstr &I, MachineRegisterInfo &MRI);
@@ -2992,6 +2993,10 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
   case TargetOpcode::G_PTRAUTH_GLOBAL_VALUE:
     return selectPtrAuthGlobalValue(I, MRI);
 
+  case AArch64::G_LDRA:
+  case AArch64::G_LDRApre:
+    return selectAuthLoad(I, MRI);
+
   case TargetOpcode::G_ZEXTLOAD:
   case TargetOpcode::G_LOAD:
   case TargetOpcode::G_STORE: {
@@ -6976,6 +6981,72 @@ bool AArch64InstructionSelector::selectPtrAuthGlobalValue(
   return true;
 }
 
+bool AArch64InstructionSelector::selectAuthLoad(MachineInstr &I,
+                                                MachineRegisterInfo &MRI) {
+  bool Writeback = I.getOpcode() == AArch64::G_LDRApre;
+
+  Register ValReg = I.getOperand(0).getReg();
+  Register PtrReg = I.getOperand(1 + Writeback).getReg();
+  int64_t Offset = I.getOperand(2 + Writeback).getImm();
+  auto Key =
+      static_cast<AArch64PACKey::ID>(I.getOperand(3 + Writeback).getImm());
+  uint64_t DiscImm = I.getOperand(4 + Writeback).getImm();
+  Register AddrDisc = I.getOperand(5 + Writeback).getReg();
+
+  bool IsDKey = Key == AArch64PACKey::DA || Key == AArch64PACKey::DB;
+  bool ZeroDisc = AddrDisc == AArch64::NoRegister && !DiscImm;
+
+  // If the discriminator is zero and the offset fits, we can use LDRAA/LDRAB.
+  // Do that here to avoid needlessly constraining regalloc into using X16.
+  if (ZeroDisc && isShiftedInt<10, 3>(Offset) && IsDKey) {
+    unsigned Opc = 0;
+    switch (Key) {
+    case AArch64PACKey::DA:
+      Opc = Writeback ? AArch64::LDRAAwriteback : AArch64::LDRAAindexed;
+      break;
+    case AArch64PACKey::DB:
+      Opc = Writeback ? AArch64::LDRABwriteback : AArch64::LDRABindexed;
+      break;
+    default:
+      llvm_unreachable("Invalid key for LDRAA/LDRAB");
+    }
+    // The LDRAA/LDRAB offset immediate is scaled.
+    Offset /= 8;
+    if (Writeback) {
+      MIB.buildInstr(Opc, {I.getOperand(1).getReg(), ValReg}, {PtrReg, Offset})
+          .constrainAllUses(TII, TRI, RBI);
+      RBI.constrainGenericRegister(I.getOperand(1).getReg(),
+                                   AArch64::GPR64spRegClass, MRI);
+    } else {
+      MIB.buildInstr(Opc, {ValReg}, {PtrReg, Offset})
+          .constrainAllUses(TII, TRI, RBI);
+    }
+    I.eraseFromParent();
+    return...
[truncated]

@ahmedbougacha self-assigned this Jan 21, 2025
@asl requested a review from atrosinenko on January 22, 2025
void AArch64AsmPrinter::LowerPtrauthAuthLoad(const MachineInstr &MI) {
const bool IsPreWB = MI.getOpcode() == AArch64::LDRApre;

const unsigned DstReg = MI.getOperand(0).getReg();
Contributor

[nit] Register type could be used instead of just unsigned, as suggested by clang-tidy.

Suggested change
const unsigned DstReg = MI.getOperand(0).getReg();
const Register DstReg = MI.getOperand(0).getReg();

const int64_t Offset = MI.getOperand(1).getImm();
const auto Key = (AArch64PACKey::ID)MI.getOperand(2).getImm();
const uint64_t Disc = MI.getOperand(3).getImm();
const unsigned AddrDisc = MI.getOperand(4).getReg();
Contributor

Suggested change
const unsigned AddrDisc = MI.getOperand(4).getReg();
const Register AddrDisc = MI.getOperand(4).getReg();

Comment on lines +1699 to +1701
ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Base.getOperand(1));
if (!RHS)
return false;
Contributor

IIUC isBaseWithConstantOffset(Base) being true ensures isa<ConstantSDNode>(Op.getOperand(1)).

Suggested change
ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Base.getOperand(1));
if (!RHS)
return false;
ConstantSDNode *RHS = cast<ConstantSDNode>(Base.getOperand(1));

std::tie(IntDisc, AddrDisc) = extractPtrauthBlendDiscriminators(Disc, CurDAG);

// If this is an indexed pre-inc load, we obviously need the writeback form.
bool needsWriteback = IsPre;
Contributor

[nit] Invalid casing of the variable name (LLVM style would be NeedsWriteback).

int64_t OffsetVal = 0;
if (IsPre) {
OffsetVal = cast<ConstantSDNode>(LD->getOffset())->getSExtValue();
} else if (CurDAG->isBaseWithConstantOffset(Base)) {
Contributor

Maybe it would be better for readability to refer to Ptr instead of Base here?

bool IsPre = AM == ISD::PRE_INC;

SDValue Chain = LD->getChain();
SDValue Ptr = LD->getBasePtr();
Contributor

I think the Ptr variable deserves a const modifier and a comment explaining that it is the value that is replaced by the written-back pre-incremented address. That would answer the question "When writing the address back merely to store it in authenticated form, how is the offset handled?" with "No adjustment is needed; the offset was already there".

Comment on lines +1706 to +1709
// The base must be of the form:
// (int_ptrauth_auth <signedbase>, <key>, <disc>)
// with disc being either a constant int, or:
// (int_ptrauth_blend <addrdisc>, <const int disc>)
Contributor

Is disc being an opaque register (such as a non-blended address discriminator) supported as well?

bool IsDKey = KeyC == AArch64PACKey::DA || KeyC == AArch64PACKey::DB;
SDValue Disc = Base.getOperand(3);

Base = Base.getOperand(1);
Contributor

It would probably be easier to read this function if we defined a SignedBase variable here (as per your code comment above) and used it from then on.

Suggested change
Base = Base.getOperand(1);
SDValue SignedBase = Base.getOperand(1);

Comment on lines +1802 to +1806
if (needsWriteback)
Writeback = CurDAG->getCopyFromReg(SDValue(Res, 1), DL, AArch64::X16,
MVT::i64, SDValue(Res, 2));
ResVal = SDValue(Res, 0);
OutChain = SDValue(Res, 1);
Contributor

Do I understand correctly: strictly speaking, OutChain should be SDValue(Writeback, 1) if needsWriteback is set, but Writeback is scheduled "immediately after LDRApre" and all other successors are simply "after LDRApre", so everything is scheduled correctly without further complicating the computations of OutChain?

def G_LDRA : AArch64GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
let hasSideEffects = 0;
Contributor

Spotted that in GISel, G_LDRA and G_LDRApre both define hasSideEffects to false. In DAGISel, on the other hand, both LDRA and LDRApre define hasSideEffects to true. Is that intended?

@@ -1671,6 +1673,163 @@ bool AArch64DAGToDAGISel::tryIndexedLoad(SDNode *N) {
return true;
}

bool AArch64DAGToDAGISel::tryAuthLoad(SDNode *N) {
LoadSDNode *LD = cast<LoadSDNode>(N);
Contributor

Nit

Suggested change
LoadSDNode *LD = cast<LoadSDNode>(N);
const LoadSDNode *LD = cast<LoadSDNode>(N);

AArch64_IMM::expandMOVImm(Offset, 64, ImmInsns);

// X17 is dead at this point, use it as the offset register
for (auto &ImmI : ImmInsns) {
Contributor

Suggested change
for (auto &ImmI : ImmInsns) {
for (const auto &ImmI : ImmInsns) {

SmallPtrSet<const SDNode *, 32> Visited;
SmallVector<const SDNode *, 16> Worklist;
Worklist.push_back(N);
for (SDNode *U : Ptr.getNode()->users())
Contributor

Nit

Suggested change
for (SDNode *U : Ptr.getNode()->users())
for (const SDNode *U : Ptr.getNode()->users())

@@ -0,0 +1,716 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple arm64e-apple-darwin -verify-machineinstrs -global-isel=0 | FileCheck %s
Contributor

I've verified that the test passes for Linux as well, so it would be nice if you added identical RUN lines with -mtriple aarch64 -mattr=+pauth. The only required change is replacing "; %bb" with "%bb" everywhere, since // is used instead of ; on Linux.

@asl added this to the LLVM 21.x Release milestone on Jun 2, 2025