[ConstantTime] Native ct.select support for ARM64 #166706
Base branch: users/wizardengineer/ct-select-clang
Warning: this pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite. This stack of pull requests is managed by Graphite.
✅ With the latest revision this PR passed the C/C++ code formatter. |
This patch implements architecture-specific lowering for ct.select on AArch64, using CSEL (conditional select) instructions for constant-time selection.

Implementation details:
- Uses the CSEL family of instructions for scalar integer types
- Uses FCSEL for floating-point types (F16, BF16, F32, F64)
- Post-RA MC lowering converts the pseudo-instructions to real CSEL/FCSEL
- Handles vector types appropriately
- Comprehensive test coverage for AArch64

The implementation includes:
- ISelLowering: custom lowering to CTSELECT pseudo-instructions
- InstrInfo: pseudo-instruction definitions and patterns
- MCInstLower: post-RA lowering of pseudo-instructions to actual CSEL/FCSEL
- Proper handling of condition codes for constant-time guarantees
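For readers new to the constant-time work, here is a minimal sketch (not part of this patch) of the selection pattern the backend is expected to emit: the lowering compares the condition against zero and then uses CSEL, so the data-dependent choice never goes through a branch. The helper below writes that pattern by hand with inline asm, assuming an AArch64 target and a GCC/Clang-style compiler; the intrinsic introduced earlier in this stack is meant to give the same guarantee without hand-written asm.

```c
#include <stdint.h>

/* Hand-rolled constant-time 64-bit select. This mirrors the "compare
 * against zero + CSEL on NE" sequence produced by LowerCTSELECT and the
 * post-RA pseudo expansion in this patch. AArch64 only; illustrative. */
static inline uint64_t ct_select_u64(uint64_t cond, uint64_t t, uint64_t f) {
  uint64_t r;
  __asm__("cmp  %[c], #0\n\t"          /* set NZCV from the condition   */
          "csel %[r], %[t], %[f], ne"  /* r = (cond != 0) ? t : f       */
          : [r] "=r"(r)
          : [c] "r"(cond), [t] "r"(t), [f] "r"(f)
          : "cc");
  return r;
}
```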
@llvm/pr-subscribers-backend-aarch64

Author: Julius Alexandre (wizardengineer)

Changes: see the patch description above (CSEL/FCSEL-based lowering of ct.select for AArch64).

Patch is 28.02 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/166706.diff

6 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..54d0ea168d0b6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -511,12 +511,36 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BR_CC, MVT::f64, Custom);
setOperationAction(ISD::SELECT, MVT::i32, Custom);
setOperationAction(ISD::SELECT, MVT::i64, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::i8, Promote);
+ setOperationAction(ISD::CTSELECT, MVT::i16, Promote);
+ setOperationAction(ISD::CTSELECT, MVT::i32, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::i64, Custom);
if (Subtarget->hasFPARMv8()) {
setOperationAction(ISD::SELECT, MVT::f16, Custom);
setOperationAction(ISD::SELECT, MVT::bf16, Custom);
}
+ if (Subtarget->hasFullFP16()) {
+ setOperationAction(ISD::CTSELECT, MVT::f16, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::bf16, Custom);
+ } else {
+ setOperationAction(ISD::CTSELECT, MVT::f16, Promote);
+ setOperationAction(ISD::CTSELECT, MVT::bf16, Promote);
+ }
setOperationAction(ISD::SELECT, MVT::f32, Custom);
setOperationAction(ISD::SELECT, MVT::f64, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::f32, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::f64, Custom);
+ for (MVT VT : MVT::vector_valuetypes()) {
+ MVT elemType = VT.getVectorElementType();
+ if (elemType == MVT::i8 || elemType == MVT::i16) {
+ setOperationAction(ISD::CTSELECT, VT, Promote);
+ } else if ((elemType == MVT::f16 || elemType == MVT::bf16) &&
+ !Subtarget->hasFullFP16()) {
+ setOperationAction(ISD::CTSELECT, VT, Promote);
+ } else {
+ setOperationAction(ISD::CTSELECT, VT, Expand);
+ }
+ }
setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
@@ -3328,6 +3352,20 @@ void AArch64TargetLowering::fixupPtrauthDiscriminator(
IntDiscOp.setImm(IntDisc);
}
+MachineBasicBlock *AArch64TargetLowering::EmitCTSELECT(MachineInstr &MI,
+ MachineBasicBlock *MBB,
+ unsigned Opcode) const {
+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+ DebugLoc DL = MI.getDebugLoc();
+ MachineInstrBuilder Builder = BuildMI(*MBB, MI, DL, TII->get(Opcode));
+ for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
+ Builder.add(MI.getOperand(Idx));
+ }
+ Builder->setFlag(MachineInstr::NoMerge);
+ MBB->remove_instr(&MI);
+ return MBB;
+}
+
MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {
@@ -7590,6 +7628,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerSELECT(Op, DAG);
case ISD::SELECT_CC:
return LowerSELECT_CC(Op, DAG);
+ case ISD::CTSELECT:
+ return LowerCTSELECT(Op, DAG);
case ISD::JumpTable:
return LowerJumpTable(Op, DAG);
case ISD::BR_JT:
@@ -12149,6 +12189,22 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
return Res;
}
+SDValue AArch64TargetLowering::LowerCTSELECT(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue CCVal = Op->getOperand(0);
+ SDValue TVal = Op->getOperand(1);
+ SDValue FVal = Op->getOperand(2);
+ SDLoc DL(Op);
+
+ EVT VT = Op.getValueType();
+
+ SDValue Zero = DAG.getConstant(0, DL, CCVal.getValueType());
+ SDValue CC;
+ SDValue Cmp = getAArch64Cmp(CCVal, Zero, ISD::SETNE, CC, DAG, DL);
+
+ return DAG.getNode(AArch64ISD::CTSELECT, DL, VT, TVal, FVal, CC, Cmp);
+}
+
SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
SelectionDAG &DAG) const {
// Jump table entries as PC relative offsets. No additional tweaking
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 2cb8ed29f252a..987377bc49023 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -23,6 +23,11 @@
namespace llvm {
+namespace AArch64ISD {
+// Forward declare the enum from the generated file
+enum GenNodeType : unsigned;
+} // namespace AArch64ISD
+
class AArch64TargetMachine;
namespace AArch64 {
@@ -202,6 +207,9 @@ class AArch64TargetLowering : public TargetLowering {
MachineOperand &AddrDiscOp,
const TargetRegisterClass *AddrDiscRC) const;
+ MachineBasicBlock *EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *BB,
+ unsigned Opcode) const;
+
MachineBasicBlock *
EmitInstrWithCustomInserter(MachineInstr &MI,
MachineBasicBlock *MBB) const override;
@@ -684,6 +692,7 @@ class AArch64TargetLowering : public TargetLowering {
iterator_range<SDNode::user_iterator> Users,
SDNodeFlags Flags, const SDLoc &dl,
SelectionDAG &DAG) const;
+ SDValue LowerCTSELECT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
@@ -919,6 +928,8 @@ class AArch64TargetLowering : public TargetLowering {
bool hasMultipleConditionRegisters(EVT VT) const override {
return VT.isScalableVector();
}
+
+ bool isSelectSupported(SelectSupportKind Kind) const override { return true; }
};
namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index ccc8eb8a9706d..bab67f57ea6b6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -700,7 +700,7 @@ static unsigned removeCopies(const MachineRegisterInfo &MRI, unsigned VReg) {
// csel instruction. If so, return the folded opcode, and the replacement
// register.
static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
- unsigned *NewReg = nullptr) {
+ unsigned *NewVReg = nullptr) {
VReg = removeCopies(MRI, VReg);
if (!Register::isVirtualRegister(VReg))
return 0;
@@ -708,37 +708,8 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
bool Is64Bit = AArch64::GPR64allRegClass.hasSubClassEq(MRI.getRegClass(VReg));
const MachineInstr *DefMI = MRI.getVRegDef(VReg);
unsigned Opc = 0;
- unsigned SrcReg = 0;
+ unsigned SrcOpNum = 0;
switch (DefMI->getOpcode()) {
- case AArch64::SUBREG_TO_REG:
- // Check for the following way to define an 64-bit immediate:
- // %0:gpr32 = MOVi32imm 1
- // %1:gpr64 = SUBREG_TO_REG 0, %0:gpr32, %subreg.sub_32
- if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 0)
- return 0;
- if (!DefMI->getOperand(2).isReg())
- return 0;
- if (!DefMI->getOperand(3).isImm() ||
- DefMI->getOperand(3).getImm() != AArch64::sub_32)
- return 0;
- DefMI = MRI.getVRegDef(DefMI->getOperand(2).getReg());
- if (DefMI->getOpcode() != AArch64::MOVi32imm)
- return 0;
- if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1)
- return 0;
- assert(Is64Bit);
- SrcReg = AArch64::XZR;
- Opc = AArch64::CSINCXr;
- break;
-
- case AArch64::MOVi32imm:
- case AArch64::MOVi64imm:
- if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1)
- return 0;
- SrcReg = Is64Bit ? AArch64::XZR : AArch64::WZR;
- Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr;
- break;
-
case AArch64::ADDSXri:
case AArch64::ADDSWri:
// if NZCV is used, do not fold.
@@ -753,7 +724,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
if (!DefMI->getOperand(2).isImm() || DefMI->getOperand(2).getImm() != 1 ||
DefMI->getOperand(3).getImm() != 0)
return 0;
- SrcReg = DefMI->getOperand(1).getReg();
+ SrcOpNum = 1;
Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr;
break;
@@ -763,7 +734,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg());
if (ZReg != AArch64::XZR && ZReg != AArch64::WZR)
return 0;
- SrcReg = DefMI->getOperand(2).getReg();
+ SrcOpNum = 2;
Opc = Is64Bit ? AArch64::CSINVXr : AArch64::CSINVWr;
break;
}
@@ -782,17 +753,17 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg,
unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg());
if (ZReg != AArch64::XZR && ZReg != AArch64::WZR)
return 0;
- SrcReg = DefMI->getOperand(2).getReg();
+ SrcOpNum = 2;
Opc = Is64Bit ? AArch64::CSNEGXr : AArch64::CSNEGWr;
break;
}
default:
return 0;
}
- assert(Opc && SrcReg && "Missing parameters");
+ assert(Opc && SrcOpNum && "Missing parameters");
- if (NewReg)
- *NewReg = SrcReg;
+ if (NewVReg)
+ *NewVReg = DefMI->getOperand(SrcOpNum).getReg();
return Opc;
}
@@ -993,34 +964,28 @@ void AArch64InstrInfo::insertSelect(MachineBasicBlock &MBB,
// Try folding simple instructions into the csel.
if (TryFold) {
- unsigned NewReg = 0;
- unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewReg);
+ unsigned NewVReg = 0;
+ unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewVReg);
if (FoldedOpc) {
// The folded opcodes csinc, csinc and csneg apply the operation to
// FalseReg, so we need to invert the condition.
CC = AArch64CC::getInvertedCondCode(CC);
TrueReg = FalseReg;
} else
- FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewReg);
+ FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewVReg);
// Fold the operation. Leave any dead instructions for DCE to clean up.
if (FoldedOpc) {
- FalseReg = NewReg;
+ FalseReg = NewVReg;
Opc = FoldedOpc;
- // Extend the live range of NewReg.
- MRI.clearKillFlags(NewReg);
+ // This extends the live range of NewVReg.
+ MRI.clearKillFlags(NewVReg);
}
}
// Pull all virtual register into the appropriate class.
MRI.constrainRegClass(TrueReg, RC);
- // FalseReg might be WZR or XZR if the folded operand is a literal 1.
- assert(
- (FalseReg.isVirtual() || FalseReg == AArch64::WZR ||
- FalseReg == AArch64::XZR) &&
- "FalseReg was folded into a non-virtual register other than WZR or XZR");
- if (FalseReg.isVirtual())
- MRI.constrainRegClass(FalseReg, RC);
+ MRI.constrainRegClass(FalseReg, RC);
// Insert the csel.
BuildMI(MBB, I, DL, get(Opc), DstReg)
@@ -2148,16 +2113,47 @@ bool AArch64InstrInfo::removeCmpToZeroOrOne(
return true;
}
-bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
- if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
- MI.getOpcode() != AArch64::CATCHRET)
- return false;
+static inline void expandCtSelect(MachineBasicBlock &MBB, MachineInstr &MI,
+ DebugLoc &DL, const MCInstrDesc &MCID) {
+ MachineInstrBuilder Builder = BuildMI(MBB, MI, DL, MCID);
+ for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
+ Builder.add(MI.getOperand(Idx));
+ }
+ Builder->setFlag(MachineInstr::NoMerge);
+ MBB.remove_instr(&MI);
+}
+bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
MachineBasicBlock &MBB = *MI.getParent();
auto &Subtarget = MBB.getParent()->getSubtarget<AArch64Subtarget>();
auto TRI = Subtarget.getRegisterInfo();
DebugLoc DL = MI.getDebugLoc();
+ switch (MI.getOpcode()) {
+ case AArch64::I32CTSELECT:
+ expandCtSelect(MBB, MI, DL, get(AArch64::CSELWr));
+ return true;
+ case AArch64::I64CTSELECT:
+ expandCtSelect(MBB, MI, DL, get(AArch64::CSELXr));
+ return true;
+ case AArch64::BF16CTSELECT:
+ expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr));
+ return true;
+ case AArch64::F16CTSELECT:
+ expandCtSelect(MBB, MI, DL, get(AArch64::FCSELHrrr));
+ return true;
+ case AArch64::F32CTSELECT:
+ expandCtSelect(MBB, MI, DL, get(AArch64::FCSELSrrr));
+ return true;
+ case AArch64::F64CTSELECT:
+ expandCtSelect(MBB, MI, DL, get(AArch64::FCSELDrrr));
+ return true;
+ }
+
+ if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
+ MI.getOpcode() != AArch64::CATCHRET)
+ return false;
+
if (MI.getOpcode() == AArch64::CATCHRET) {
// Skip to the first instruction before the epilog.
const TargetInstrInfo *TII =
@@ -5098,7 +5094,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
bool RenamableDest,
bool RenamableSrc) const {
if (AArch64::GPR32spRegClass.contains(DestReg) &&
- AArch64::GPR32spRegClass.contains(SrcReg)) {
+ (AArch64::GPR32spRegClass.contains(SrcReg) || SrcReg == AArch64::WZR)) {
if (DestReg == AArch64::WSP || SrcReg == AArch64::WSP) {
// If either operand is WSP, expand to ADD #0.
if (Subtarget.hasZeroCycleRegMoveGPR64() &&
@@ -5123,14 +5119,30 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
}
+ } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR64() &&
+ !Subtarget.hasZeroCycleZeroingGPR32()) {
+ // Use 64-bit zeroing when available but 32-bit zeroing is not
+ MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
+ assert(DestRegX.isValid() && "Destination super-reg not valid");
+ BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX)
+ .addImm(0)
+ .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+ } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR32()) {
+ BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
+ .addImm(0)
+ .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
} else if (Subtarget.hasZeroCycleRegMoveGPR64() &&
!Subtarget.hasZeroCycleRegMoveGPR32()) {
// Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
&AArch64::GPR64spRegClass);
assert(DestRegX.isValid() && "Destination super-reg not valid");
- MCRegister SrcRegX = RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
- &AArch64::GPR64spRegClass);
+ MCRegister SrcRegX =
+ SrcReg == AArch64::WZR
+ ? AArch64::XZR
+ : RI.getMatchingSuperReg(SrcReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
assert(SrcRegX.isValid() && "Source super-reg not valid");
// This instruction is reading and writing X registers. This may upset
// the register scavenger and machine verifier, so we need to indicate
@@ -5149,59 +5161,6 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
return;
}
- // GPR32 zeroing
- if (AArch64::GPR32spRegClass.contains(DestReg) && SrcReg == AArch64::WZR) {
- if (Subtarget.hasZeroCycleZeroingGPR64() &&
- !Subtarget.hasZeroCycleZeroingGPR32()) {
- MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32,
- &AArch64::GPR64spRegClass);
- assert(DestRegX.isValid() && "Destination super-reg not valid");
- BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX)
- .addImm(0)
- .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
- } else if (Subtarget.hasZeroCycleZeroingGPR32()) {
- BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
- .addImm(0)
- .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
- } else {
- BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg)
- .addReg(AArch64::WZR)
- .addReg(AArch64::WZR);
- }
- return;
- }
-
- if (AArch64::GPR64spRegClass.contains(DestReg) &&
- AArch64::GPR64spRegClass.contains(SrcReg)) {
- if (DestReg == AArch64::SP || SrcReg == AArch64::SP) {
- // If either operand is SP, expand to ADD #0.
- BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg)
- .addReg(SrcReg, getKillRegState(KillSrc))
- .addImm(0)
- .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
- } else {
- // Otherwise, expand to ORR XZR.
- BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
- .addReg(AArch64::XZR)
- .addReg(SrcReg, getKillRegState(KillSrc));
- }
- return;
- }
-
- // GPR64 zeroing
- if (AArch64::GPR64spRegClass.contains(DestReg) && SrcReg == AArch64::XZR) {
- if (Subtarget.hasZeroCycleZeroingGPR64()) {
- BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg)
- .addImm(0)
- .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
- } else {
- BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
- .addReg(AArch64::XZR)
- .addReg(AArch64::XZR);
- }
- return;
- }
-
// Copy a Predicate register by ORRing with itself.
if (AArch64::PPRRegClass.contains(DestReg) &&
AArch64::PPRRegClass.contains(SrcReg)) {
@@ -5286,6 +5245,27 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
return;
}
+ if (AArch64::GPR64spRegClass.contains(DestReg) &&
+ (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) {
+ if (DestReg == AArch64::SP || SrcReg == AArch64::SP) {
+ // If either operand is SP, expand to ADD #0.
+ BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg)
+ .addReg(SrcReg, getKillRegState(KillSrc))
+ .addImm(0)
+ .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+ } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGPR64()) {
+ BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg)
+ .addImm(0)
+ .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+ } else {
+ // Otherwise, expand to ORR XZR.
+ BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg)
+ .addReg(AArch64::XZR)
+ .addReg(SrcReg, getKillRegState(KillSrc));
+ }
+ return;
+ }
+
// Copy a DDDD register quad by copying the individual sub-registers.
if (AArch64::DDDDRegClass.contains(DestReg) &&
AArch64::DDDDRegClass.contains(SrcReg)) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 2871a20e28b65..5017a39789d08 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -476,6 +476,9 @@ def SDT_AArch64cbz : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisVT<1, OtherVT>]>;
def SDT_AArch64tbz : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>,
SDTCisVT<2, OtherVT>]>;
+def SDT_AArch64CtSelect : SDTypeProfile<1, 4,
+ [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
+ SDTCisInt<3>, SDTCisVT<4, i32>]>;
def SDT_AArch64CSel : SDTypeProfile<1, 4,
[SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>,
@@ -843,6 +846,7 @@ def AArch64tbz : SDNode<"AArch64ISD::TBZ", SDT_AArch64tbz,
def AArch64tbnz : SDNode<"AArch64ISD::TBNZ", SDT_AArch64tbz,
[SDNPHasChain]>;
+def AArch64ctselect : SDNode<"AArch64ISD::CTSELECT", SDT_AArch64CtSelect>;
def AArch64csel : SDNode<"AArch64ISD::CSEL", SDT_AArch64CSel>;
// Conditional select invert.
@@ -5644,6 +5648,42 @@ def F128CSEL : Pseudo<(outs FPR128:$Rd),
let hasNoSchedulingInfo = 1;
}
+//===----------------------------------------------------------------------===//
+// Constant-time conditional selection instructions
+//===-------------------------------------------...
[truncated]
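As context for why a dedicated ct.select is wanted at all, here is a small sketch, unrelated to the patch itself, of the mask-based constant-time selection idiom that cryptographic code commonly uses today. The concern with this source-level style is that an optimizer may legally rewrite it back into a branch on the secret condition; an intrinsic with guaranteed CSEL/FCSEL lowering, as implemented in this PR, removes that risk.

```c
#include <stdint.h>

/* Mask-based constant-time select: cond must be exactly 0 or 1.
 * This is the source-level idiom that ct.select is meant to make robust. */
static uint64_t ct_select(uint64_t cond, uint64_t t, uint64_t f) {
  uint64_t mask = 0 - cond;            /* all ones if cond == 1, else 0 */
  return (t & mask) | (f & ~mask);
}

/* Constant-time conditional swap, e.g. the inner step of a Montgomery
 * ladder: both outputs are recomputed regardless of 'swap'. */
static void ct_cswap(uint64_t swap, uint64_t *a, uint64_t *b) {
  uint64_t x = *a, y = *b;
  *a = ct_select(swap, y, x);
  *b = ct_select(swap, x, y);
}
```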