Skip to content

Commit 95ac3d1

Browse files
committed
[AArch64][GlobalISel] Add G_VECREDUCE fewerElements support for full scalarization.
For some reductions like G_VECREDUCE_OR on AArch64, we need to scalarize completely if the source is <= 64b. This change adds support for that in the legalizer. If the source has a pow-2 num elements, then we can do a tree reduction using the scalar operation in the individual elements. Otherwise, we just create a sequential chain of operations. For AArch64, we only need to scalarize if the input is <64b. If it's great than 64b then we can first do a fewElements step to 64b, taking advantage of vector instructions until we reach the point of scalarization. I also had to relax the verifier checks for reductions because the intrinsics support <1 x EltTy> types, which we lower to scalars for GlobalISel. Differential Revision: https://reviews.llvm.org/D108276
1 parent fbb8e77 commit 95ac3d1

File tree

8 files changed

+1137
-37
lines changed

8 files changed

+1137
-37
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h

+1
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ class LegalizerHelper {
403403
LegalizeResult lowerAbsToAddXor(MachineInstr &MI);
404404
LegalizeResult lowerAbsToMaxNeg(MachineInstr &MI);
405405
LegalizeResult lowerIsNaN(MachineInstr &MI);
406+
LegalizeResult lowerVectorReduction(MachineInstr &MI);
406407
};
407408

408409
/// Helper function that creates a libcall to the given \p Name using the given

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

+94-30
Original file line numberDiff line numberDiff line change
@@ -3489,6 +3489,8 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
34893489
return lowerRotate(MI);
34903490
case G_ISNAN:
34913491
return lowerIsNaN(MI);
3492+
GISEL_VECREDUCE_CASES_NONSEQ
3493+
return lowerVectorReduction(MI);
34923494
}
34933495
}
34943496

@@ -4637,35 +4639,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
46374639
return Legalized;
46384640
}
46394641

4640-
LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
4641-
MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4642-
unsigned Opc = MI.getOpcode();
4643-
assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
4644-
Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
4645-
"Sequential reductions not expected");
4646-
4647-
if (TypeIdx != 1)
4648-
return UnableToLegalize;
4649-
4650-
// The semantics of the normal non-sequential reductions allow us to freely
4651-
// re-associate the operation.
4652-
Register SrcReg = MI.getOperand(1).getReg();
4653-
LLT SrcTy = MRI.getType(SrcReg);
4654-
Register DstReg = MI.getOperand(0).getReg();
4655-
LLT DstTy = MRI.getType(DstReg);
4656-
4657-
if (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0)
4658-
return UnableToLegalize;
4659-
4660-
SmallVector<Register> SplitSrcs;
4661-
const unsigned NumParts = SrcTy.getNumElements() / NarrowTy.getNumElements();
4662-
extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
4663-
SmallVector<Register> PartialReductions;
4664-
for (unsigned Part = 0; Part < NumParts; ++Part) {
4665-
PartialReductions.push_back(
4666-
MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
4667-
}
4668-
4642+
static unsigned getScalarOpcForReduction(unsigned Opc) {
46694643
unsigned ScalarOpc;
46704644
switch (Opc) {
46714645
case TargetOpcode::G_VECREDUCE_FADD:
@@ -4708,10 +4682,81 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
47084682
ScalarOpc = TargetOpcode::G_UMIN;
47094683
break;
47104684
default:
4711-
LLVM_DEBUG(dbgs() << "Can't legalize: unknown reduction kind.\n");
4685+
llvm_unreachable("Unhandled reduction");
4686+
}
4687+
return ScalarOpc;
4688+
}
4689+
4690+
LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
4691+
MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4692+
unsigned Opc = MI.getOpcode();
4693+
assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
4694+
Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
4695+
"Sequential reductions not expected");
4696+
4697+
if (TypeIdx != 1)
47124698
return UnableToLegalize;
4699+
4700+
// The semantics of the normal non-sequential reductions allow us to freely
4701+
// re-associate the operation.
4702+
Register SrcReg = MI.getOperand(1).getReg();
4703+
LLT SrcTy = MRI.getType(SrcReg);
4704+
Register DstReg = MI.getOperand(0).getReg();
4705+
LLT DstTy = MRI.getType(DstReg);
4706+
4707+
if (NarrowTy.isVector() &&
4708+
(SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
4709+
return UnableToLegalize;
4710+
4711+
unsigned ScalarOpc = getScalarOpcForReduction(Opc);
4712+
SmallVector<Register> SplitSrcs;
4713+
// If NarrowTy is a scalar then we're being asked to scalarize.
4714+
const unsigned NumParts =
4715+
NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
4716+
: SrcTy.getNumElements();
4717+
4718+
extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
4719+
if (NarrowTy.isScalar()) {
4720+
if (DstTy != NarrowTy)
4721+
return UnableToLegalize; // FIXME: handle implicit extensions.
4722+
4723+
if (isPowerOf2_32(NumParts)) {
4724+
// Generate a tree of scalar operations to reduce the critical path.
4725+
SmallVector<Register> PartialResults;
4726+
unsigned NumPartsLeft = NumParts;
4727+
while (NumPartsLeft > 1) {
4728+
for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) {
4729+
PartialResults.emplace_back(
4730+
MIRBuilder
4731+
.buildInstr(ScalarOpc, {NarrowTy},
4732+
{SplitSrcs[Idx], SplitSrcs[Idx + 1]})
4733+
.getReg(0));
4734+
}
4735+
SplitSrcs = PartialResults;
4736+
PartialResults.clear();
4737+
NumPartsLeft = SplitSrcs.size();
4738+
}
4739+
assert(SplitSrcs.size() == 1);
4740+
MIRBuilder.buildCopy(DstReg, SplitSrcs[0]);
4741+
MI.eraseFromParent();
4742+
return Legalized;
4743+
}
4744+
// If we can't generate a tree, then just do sequential operations.
4745+
Register Acc = SplitSrcs[0];
4746+
for (unsigned Idx = 1; Idx < NumParts; ++Idx)
4747+
Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]})
4748+
.getReg(0);
4749+
MIRBuilder.buildCopy(DstReg, Acc);
4750+
MI.eraseFromParent();
4751+
return Legalized;
4752+
}
4753+
SmallVector<Register> PartialReductions;
4754+
for (unsigned Part = 0; Part < NumParts; ++Part) {
4755+
PartialReductions.push_back(
4756+
MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
47134757
}
47144758

4759+
47154760
// If the types involved are powers of 2, we can generate intermediate vector
47164761
// ops, before generating a final reduction operation.
47174762
if (isPowerOf2_32(SrcTy.getNumElements()) &&
@@ -7389,3 +7434,22 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerIsNaN(MachineInstr &MI) {
73897434
MI.eraseFromParent();
73907435
return Legalized;
73917436
}
7437+
7438+
LegalizerHelper::LegalizeResult
7439+
LegalizerHelper::lowerVectorReduction(MachineInstr &MI) {
7440+
Register SrcReg = MI.getOperand(1).getReg();
7441+
LLT SrcTy = MRI.getType(SrcReg);
7442+
LLT DstTy = MRI.getType(SrcReg);
7443+
7444+
// The source could be a scalar if the IR type was <1 x sN>.
7445+
if (SrcTy.isScalar()) {
7446+
if (DstTy.getSizeInBits() > SrcTy.getSizeInBits())
7447+
return UnableToLegalize; // FIXME: handle extension.
7448+
// This can be just a plain copy.
7449+
Observer.changingInstr(MI);
7450+
MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY));
7451+
Observer.changedInstr(MI);
7452+
return Legalized;
7453+
}
7454+
return UnableToLegalize;;
7455+
}

llvm/lib/CodeGen/MachineVerifier.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -1589,11 +1589,8 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
15891589
case TargetOpcode::G_VECREDUCE_UMAX:
15901590
case TargetOpcode::G_VECREDUCE_UMIN: {
15911591
LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
1592-
LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());
15931592
if (!DstTy.isScalar())
15941593
report("Vector reduction requires a scalar destination type", MI);
1595-
if (!SrcTy.isVector())
1596-
report("Vector reduction requires vector source=", MI);
15971594
break;
15981595
}
15991596

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,27 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
691691
.clampMaxNumElements(1, s32, 4)
692692
.lower();
693693

694+
getActionDefinitionsBuilder(G_VECREDUCE_OR)
695+
// Try to break down into smaller vectors as long as they're at least 64
696+
// bits. This lets us use vector operations for some parts of the
697+
// reduction.
698+
.fewerElementsIf(
699+
[=](const LegalityQuery &Q) {
700+
LLT SrcTy = Q.Types[1];
701+
if (SrcTy.isScalar())
702+
return false;
703+
if (!isPowerOf2_32(SrcTy.getNumElements()))
704+
return false;
705+
// We can usually perform 64b vector operations.
706+
return SrcTy.getSizeInBits() > 64;
707+
},
708+
[=](const LegalityQuery &Q) {
709+
LLT SrcTy = Q.Types[1];
710+
return std::make_pair(1, SrcTy.divide(2));
711+
})
712+
.scalarize(1)
713+
.lower();
714+
694715
getActionDefinitionsBuilder({G_UADDSAT, G_USUBSAT})
695716
.lowerIf([=](const LegalityQuery &Q) { return Q.Types[0].isScalar(); });
696717

0 commit comments

Comments
 (0)