Conversation

@wizardengineer (Contributor)

This patch implements architecture-specific lowering for ct.select on AArch64
using CSEL (conditional select) instructions for constant-time selection.

Implementation details:

  • Uses CSEL family of instructions for scalar integer types
  • Uses FCSEL for floating-point types (F16, BF16, F32, F64)
  • Post-RA MC lowering to convert pseudo-instructions to real CSEL/FCSEL
  • Handles vector types via promotion or expansion (there is no native vector CSEL)
  • Comprehensive test coverage for AArch64

The implementation includes:

  • ISelLowering: Custom lowering to CTSELECT pseudo-instructions
  • InstrInfo: Pseudo-instruction definitions and patterns
  • MCInstLower: Post-RA lowering of pseudo-instructions to actual CSEL/FCSEL
  • Proper handling of condition codes for constant-time guarantees
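
For illustration, here is a minimal sketch of the scalar integer path. The intrinsic shape below is an assumption (mirroring select's i1/T/T signature, as exercised by the new ctselect.ll test), and the exact instruction sequence may vary:

  ; Sketch only: the @llvm.ct.select.i64 name/signature is assumed,
  ; not quoted from this patch; ctselect.ll holds the real checks.
  define i64 @sel64(i1 %c, i64 %a, i64 %b) {
    %r = call i64 @llvm.ct.select.i64(i1 %c, i64 %a, i64 %b)
    ret i64 %r
  }
  declare i64 @llvm.ct.select.i64(i1, i64, i64)

  ; Expected codegen shape (roughly):
  ;   tst  w0, #0x1        ; materialize the condition into NZCV
  ;   csel x0, x1, x2, ne  ; branch-free select, constant time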

wizardengineer (Contributor, Author) commented Nov 6, 2025

Warning

This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.

This stack of pull requests is managed by Graphite.

@github-actions

github-actions bot commented Nov 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

wizardengineer changed the title from "[LLVM][AArch64] Add native ct.select support for ARM64" to "[ConstantTime] Native ct.select support for ARM64" on Nov 6, 2025
wizardengineer force-pushed the users/wizardengineer/ct-select-aarch64 branch from 071428b to 7de2b81 on November 6, 2025 17:28
wizardengineer force-pushed the users/wizardengineer/ct-select-clang branch from cbb5490 to 6ac8221 on November 6, 2025 17:28
wizardengineer marked this pull request as ready for review on November 6, 2025 20:36
@llvmbot (Member)

llvmbot commented Nov 6, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Julius Alexandre (wizardengineer)

Changes

(The PR description above is repeated here by llvmbot.)

Patch is 28.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166706.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+56)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+11)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+90-110)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+40)
  • (modified) llvm/lib/Target/AArch64/AArch64MCInstLower.cpp (+18)
  • (added) llvm/test/CodeGen/AArch64/ctselect.ll (+153)
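
As a quick illustration before the diff, the floating-point path should produce FCSEL rather than CSEL (a sketch under the same assumed intrinsic shape as above; ctselect.ll in the diff carries the authoritative checks):

  ; Sketch only: @llvm.ct.select.f32 is the assumed overload name.
  define float @self32(i1 %c, float %a, float %b) {
    %r = call float @llvm.ct.select.f32(i1 %c, float %a, float %b)
    ret float %r
  }
  declare float @llvm.ct.select.f32(i1, float, float)

  ; Expected codegen shape (roughly):
  ;   tst   w0, #0x1
  ;   fcsel s0, s0, s1, ne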
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..54d0ea168d0b6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -511,12 +511,36 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BR_CC, MVT::f64, Custom);
   setOperationAction(ISD::SELECT, MVT::i32, Custom);
   setOperationAction(ISD::SELECT, MVT::i64, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::i8, Promote);
+  setOperationAction(ISD::CTSELECT, MVT::i16, Promote);
+  setOperationAction(ISD::CTSELECT, MVT::i32, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::i64, Custom);
   if (Subtarget->hasFPARMv8()) {
     setOperationAction(ISD::SELECT, MVT::f16, Custom);
     setOperationAction(ISD::SELECT, MVT::bf16, Custom);
   }
+  if (Subtarget->hasFullFP16()) {
+    setOperationAction(ISD::CTSELECT, MVT::f16, Custom);
+    setOperationAction(ISD::CTSELECT, MVT::bf16, Custom);
+  } else {
+    setOperationAction(ISD::CTSELECT, MVT::f16, Promote);
+    setOperationAction(ISD::CTSELECT, MVT::bf16, Promote);
+  }
   setOperationAction(ISD::SELECT, MVT::f32, Custom);
   setOperationAction(ISD::SELECT, MVT::f64, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::f32, Custom);
+  setOperationAction(ISD::CTSELECT, MVT::f64, Custom);
+  for (MVT VT : MVT::vector_valuetypes()) {
+    MVT elemType = VT.getVectorElementType();
+    if (elemType == MVT::i8 || elemType == MVT::i16) {
+      setOperationAction(ISD::CTSELECT, VT, Promote);
+    } else if ((elemType == MVT::f16 || elemType == MVT::bf16) &&
+               !Subtarget->hasFullFP16()) {
+      setOperationAction(ISD::CTSELECT, VT, Promote);
+    } else {
+      setOperationAction(ISD::CTSELECT, VT, Expand);
+    }
+  }
   setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
   setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
@@ -3328,6 +3352,20 @@ void AArch64TargetLowering::fixupPtrauthDiscriminator(
   IntDiscOp.setImm(IntDisc);
 }
 
+MachineBasicBlock *AArch64TargetLowering::EmitCTSELECT(MachineInstr &MI,
+                                                       MachineBasicBlock *MBB,
+                                                       unsigned Opcode) const {
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  DebugLoc DL = MI.getDebugLoc();
+  MachineInstrBuilder Builder = BuildMI(*MBB, MI, DL, TII->get(Opcode));
+  for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
+    Builder.add(MI.getOperand(Idx));
+  }
+  Builder->setFlag(MachineInstr::NoMerge);
+  MBB->remove_instr(&MI);
+  return MBB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -7590,6 +7628,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerSELECT(Op, DAG);
   case ISD::SELECT_CC:
     return LowerSELECT_CC(Op, DAG);
+  case ISD::CTSELECT:
+    return LowerCTSELECT(Op, DAG);
   case ISD::JumpTable:
     return LowerJumpTable(Op, DAG);
   case ISD::BR_JT:
@@ -12149,6 +12189,22 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
   return Res;
 }
 
+SDValue AArch64TargetLowering::LowerCTSELECT(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDValue CCVal = Op->getOperand(0);
+  SDValue TVal = Op->getOperand(1);
+  SDValue FVal = Op->getOperand(2);
+  SDLoc DL(Op);
+
+  EVT VT = Op.getValueType();
+
+  SDValue Zero = DAG.getConstant(0, DL, CCVal.getValueType());
+  SDValue CC;
+  SDValue Cmp = getAArch64Cmp(CCVal, Zero, ISD::SETNE, CC, DAG, DL);
+
+  return DAG.getNode(AArch64ISD::CTSELECT, DL, VT, TVal, FVal, CC, Cmp);
+}
+
 SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
                                               SelectionDAG &DAG) const {
   // Jump table entries as PC relative offsets. No additional tweaking
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 2cb8ed29f252a..987377bc49023 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -23,6 +23,11 @@
 
 namespace llvm {
 
+namespace AArch64ISD {
+// Forward declare the enum from the generated file
+enum GenNodeType : unsigned;
+} // namespace AArch64ISD
+
 class AArch64TargetMachine;
 
 namespace AArch64 {
@@ -202,6 +207,9 @@ class AArch64TargetLowering : public TargetLowering {
                                  MachineOperand &AddrDiscOp,
                                  const TargetRegisterClass *AddrDiscRC) const;
 
+  MachineBasicBlock *EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *BB,
+                                  unsigned Opcode) const;
+
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,
                               MachineBasicBlock *MBB) const override;
@@ -684,6 +692,7 @@ class AArch64TargetLowering : public TargetLowering {
                          iterator_range<SDNode::user_iterator> Users,
                          SDNodeFlags Flags, const SDLoc &dl,
                          SelectionDAG &DAG) const;
+  SDValue LowerCTSELECT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
@@ -919,6 +928,8 @@ class AArch64TargetLowering : public TargetLowering {
   bool hasMultipleConditionRegisters(EVT VT) const override {
     return VT.isScalableVector();
   }
+
+  bool isSelectSupported(SelectSupportKind Kind) const override { return true; }
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index ccc8eb8a9706d..bab67f57ea6b6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -700,7 +700,7 @@ static unsigned removeCopies(const MachineRegisterInfo &MRI, unsigned VReg) {
 // csel instruction. If so, return the folded opcode, and the replacement
 // register.
 static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
-                                unsigned *NewReg = nullptr) {
+                                unsigned *NewVReg = nullptr) {
   VReg = removeCopies(MRI, VReg);
   if (!Register::isVirtualRegister(VReg))
     return 0;
@@ -708,37 +708,8 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
   bool Is64Bit = AArch64::GPR64allRegClass.hasSubClassEq(MRI.getRegClass(VReg));
   const MachineInstr *DefMI = MRI.getVRegDef(VReg);
   unsigned Opc = 0;
-  unsigned SrcReg = 0;
+  unsigned SrcOpNum = 0;
   switch (DefMI->getOpcode()) {
-  case AArch64::SUBREG_TO_REG:
-    // Check for the following way to define an 64-bit immediate:
-    //   %0:gpr32 = MOVi32imm 1
-    //   %1:gpr64 = SUBREG_TO_REG 0, %0:gpr32, %subreg.sub_32
-    if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 0)
-      return 0;
-    if (!DefMI->getOperand(2).isReg())
-      return 0;
-    if (!DefMI->getOperand(3).isImm() ||
-        DefMI->getOperand(3).getImm() != AArch64::sub_32)
-      return 0;
-    DefMI = MRI.getVRegDef(DefMI->getOperand(2).getReg());
-    if (DefMI->getOpcode() != AArch64::MOVi32imm)
-      return 0;
-    if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1)
-      return 0;
-    assert(Is64Bit);
-    SrcReg = AArch64::XZR;
-    Opc = AArch64::CSINCXr;
-    break;
-
-  case AArch64::MOVi32imm:
-  case AArch64::MOVi64imm:
-    if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1)
-      return 0;
-    SrcReg = Is64Bit ? AArch64::XZR : AArch64::WZR;
-    Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr;
-    break;
-
   case AArch64::ADDSXri:
   case AArch64::ADDSWri:
     // if NZCV is used, do not fold.
@@ -753,7 +724,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
     if (!DefMI->getOperand(2).isImm() || DefMI->getOperand(2).getImm() != 1 ||
         DefMI->getOperand(3).getImm() != 0)
       return 0;
-    SrcReg = DefMI->getOperand(1).getReg();
+    SrcOpNum = 1;
     Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr;
     break;
 
@@ -763,7 +734,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
     unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg());
     if (ZReg != AArch64::XZR && ZReg != AArch64::WZR)
       return 0;
-    SrcReg = DefMI->getOperand(2).getReg();
+    SrcOpNum = 2;
     Opc = Is64Bit ? AArch64::CSINVXr : AArch64::CSINVWr;
     break;
   }
@@ -782,17 +753,17 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
     unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg());
     if (ZReg != AArch64::XZR && ZReg != AArch64::WZR)
       return 0;
-    SrcReg = DefMI->getOperand(2).getReg();
+    SrcOpNum = 2;
     Opc = Is64Bit ? AArch64::CSNEGXr : AArch64::CSNEGWr;
     break;
   }
   default:
     return 0;
   }
-  assert(Opc && SrcReg && "Missing parameters");
+  assert(Opc && SrcOpNum && "Missing parameters");
 
-  if (NewReg)
-    *NewReg = SrcReg;
+  if (NewVReg)
+    *NewVReg = DefMI->getOperand(SrcOpNum).getReg();
   return Opc;
 }
 
@@ -993,34 +964,28 @@ void AArch64InstrInfo::insertSelect(MachineBasicBlock &MBB,
 
   // Try folding simple instructions into the csel.
   if (TryFold) {
-    unsigned NewReg = 0;
-    unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewReg);
+    unsigned NewVReg = 0;
+    unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewVReg);
     if (FoldedOpc) {
      // The folded opcodes csinc, csinv and csneg apply the operation to
       // FalseReg, so we need to invert the condition.
       CC = AArch64CC::getInvertedCondCode(CC);
       TrueReg = FalseReg;
     } else
-      FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewReg);
+      FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewVReg);
 
     // Fold the operation. Leave any dead instructions for DCE to clean up.
     if (FoldedOpc) {
-      FalseReg = NewReg;
+      FalseReg = NewVReg;
       Opc = FoldedOpc;
-      // Extend the live range of NewReg.
-      MRI.clearKillFlags(NewReg);
+      // This extends the live range of NewVReg.
+      MRI.clearKillFlags(NewVReg);
     }
   }
 
   // Pull all virtual register into the appropriate class.
   MRI.constrainRegClass(TrueReg, RC);
-  // FalseReg might be WZR or XZR if the folded operand is a literal 1.
-  assert(
-      (FalseReg.isVirtual() || FalseReg == AArch64::WZR ||
-       FalseReg == AArch64::XZR) &&
-      "FalseReg was folded into a non-virtual register other than WZR or XZR");
-  if (FalseReg.isVirtual())
-    MRI.constrainRegClass(FalseReg, RC);
+  MRI.constrainRegClass(FalseReg, RC);
 
   // Insert the csel.
   BuildMI(MBB, I, DL, get(Opc), DstReg)
@@ -2148,16 +2113,47 @@ bool AArch64InstrInfo::removeCmpToZeroOrOne(
   return true;
 }
 
-bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
-  if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
-      MI.getOpcode() != AArch64::CATCHRET)
-    return false;
+static inline void expandCtSelect(MachineBasicBlock &MBB, MachineInstr &MI,
+                                  DebugLoc &DL, const MCInstrDesc &MCID) {
+  MachineInstrBuilder Builder = BuildMI(MBB, MI, DL, MCID);
+  for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
+    Builder.add(MI.getOperand(Idx));
+  }
+  Builder->setFlag(MachineInstr::NoMerge);
+  MBB.remove_instr(&MI);
+}
 
+bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   MachineBasicBlock &MBB = *MI.getParent();
   auto &Subtarget = MBB.getParent()->getSubtarget<AArch64Subtarget>();
   auto TRI = Subtarget.getRegisterInfo();
   DebugLoc DL = MI.getDebugLoc();
 
+  switch (MI.getOpcode()) {
+  case AArch64::I32CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::CSELWr));
+    return true;
+  case AArch64::I64CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::CSELXr));
+    return true;
+  case AArch64::BF16CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr));
+    return true;
+  case AArch64::F16CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr));
+    return true;
+  case AArch64::F32CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::FCSELSrrr));
+    return true;
+  case AArch64::F64CTSELECT:
+    expandCtSelect(MBB, MI, DL, get(AArch64::FCSELDrrr));
+    return true;
+  }
+
+  if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
+      MI.getOpcode() != AArch64::CATCHRET)
+    return false;
+
   if (MI.getOpcode() == AArch64::CATCHRET) {
     // Skip to the first instruction before the epilog.
     const TargetInstrInfo *TII =
@@ -5098,7 +5094,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                                    bool RenamableDest,
                                    bool RenamableSrc) const {
   if (AArch64::GPR32spRegClass.contains(DestReg) &&
-      AArch64::GPR32spRegClass.contains(SrcReg)) {
+      (AArch64::GPR32spRegClass.contains(SrcReg) || SrcReg == AArch64::WZR)) {
     if (DestReg == AArch64::WSP || SrcReg == AArch64::WSP) {
       // If either operand is WSP, expand to ADD #0.
       if (Subtarget.hasZeroCycleRegMoveGPR64() &&
@@ -5123,14 +5119,30 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
             .addImm(0)
             .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
       }
+    } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR64() &&
+               !Subtarget.hasZeroCycleZeroingGPR32()) {
+      // Use 64-bit zeroing when available but 32-bit zeroing is not
+      MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
+                                                   &AArch64::GPR64spRegClass);
+      assert(DestRegX.isValid() && "Destination super-reg not valid");
+      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX)
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+    } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR32()) {
+      BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
     } else if (Subtarget.hasZeroCycleRegMoveGPR64() &&
                !Subtarget.hasZeroCycleRegMoveGPR32()) {
       // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
       MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
                                                    &AArch64::GPR64spRegClass);
       assert(DestRegX.isValid() && "Destination super-reg not valid");
-      MCRegister SrcRegX = RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
-                                                  &AArch64::GPR64spRegClass);
+      MCRegister SrcRegX =
+          SrcReg == AArch64::WZR
+              ? AArch64::XZR
+              : RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
+                                       &AArch64::GPR64spRegClass);
       assert(SrcRegX.isValid() && "Source super-reg not valid");
       // This instruction is reading and writing X registers.  This may upset
       // the register scavenger and machine verifier, so we need to indicate
@@ -5149,59 +5161,6 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     return;
   }
 
-  // GPR32 zeroing
-  if (AArch64::GPR32spRegClass.contains(DestReg) && SrcReg == AArch64::WZR) {
-    if (Subtarget.hasZeroCycleZeroingGPR64() &&
-        !Subtarget.hasZeroCycleZeroingGPR32()) {
-      MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
-                                                   &AArch64::GPR64spRegClass);
-      assert(DestRegX.isValid() && "Destination super-reg not valid");
-      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX)
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else if (Subtarget.hasZeroCycleZeroingGPR32()) {
-      BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else {
-      BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg)
-          .addReg(AArch64::WZR)
-          .addReg(AArch64::WZR);
-    }
-    return;
-  }
-
-  if (AArch64::GPR64spRegClass.contains(DestReg) &&
-      AArch64::GPR64spRegClass.contains(SrcReg)) {
-    if (DestReg == AArch64::SP || SrcReg == AArch64::SP) {
-      // If either operand is SP, expand to ADD #0.
-      BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg)
-          .addReg(SrcReg, getKillRegState(KillSrc))
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else {
-      // Otherwise, expand to ORR XZR.
-      BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
-          .addReg(AArch64::XZR)
-          .addReg(SrcReg, getKillRegState(KillSrc));
-    }
-    return;
-  }
-
-  // GPR64 zeroing
-  if (AArch64::GPR64spRegClass.contains(DestReg) && SrcReg == AArch64::XZR) {
-    if (Subtarget.hasZeroCycleZeroingGPR64()) {
-      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg)
-          .addImm(0)
-          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
-    } else {
-      BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
-          .addReg(AArch64::XZR)
-          .addReg(AArch64::XZR);
-    }
-    return;
-  }
-
   // Copy a Predicate register by ORRing with itself.
   if (AArch64::PPRRegClass.contains(DestReg) &&
       AArch64::PPRRegClass.contains(SrcReg)) {
@@ -5286,6 +5245,27 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     return;
   }
 
+  if (AArch64::GPR64spRegClass.contains(DestReg) &&
+      (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) {
+    if (DestReg == AArch64::SP || SrcReg == AArch64::SP) {
+      // If either operand is SP, expand to ADD #0.
+      BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg)
+          .addReg(SrcReg, getKillRegState(KillSrc))
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+    } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGPR64()) {
+      BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg)
+          .addImm(0)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+    } else {
+      // Otherwise, expand to ORR XZR.
+      BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
+          .addReg(AArch64::XZR)
+          .addReg(SrcReg, getKillRegState(KillSrc));
+    }
+    return;
+  }
+
   // Copy a DDDD register quad by copying the individual sub-registers.
   if (AArch64::DDDDRegClass.contains(DestReg) &&
       AArch64::DDDDRegClass.contains(SrcReg)) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 2871a20e28b65..5017a39789d08 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -476,6 +476,9 @@ def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>;
 def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>,
                                         SDTCisVT<2, OtherVT>]>;
 
+def SDT_AArch64CtSelect : SDTypeProfile<1, 4,
+                                        [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
+                                         SDTCisInt<3>, SDTCisVT<4, i32>]>;
 def SDT_AArch64CSel  : SDTypeProfile<1, 4,
                                    [SDTCisSameAs<0, 1>,
                                     SDTCisSameAs<0, 2>,
@@ -843,6 +846,7 @@ def AArch64tbz           : SDNode<"AArch64ISD::TBZ", SDT_AArch64tbz,
 def AArch64tbnz           : SDNode<"AArch64ISD::TBNZ", SDT_AArch64tbz,
                                 [SDNPHasChain]>;
 
+def AArch64ctselect : SDNode<"AArch64ISD::CTSELECT", SDT_AArch64CtSelect>;
 
 def AArch64csel          : SDNode<"AArch64ISD::CSEL", SDT_AArch64CSel>;
 // Conditional select invert.
@@ -5644,6 +5648,42 @@ def F128CSEL : Pseudo<(outs FPR128:$Rd),
   let hasNoSchedulingInfo = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Constant-time conditional selection instructions
+//===-------------------------------------------...
[truncated]

