Skip to content

Commit 23c2182

Browse files
committed
Support generic expansion of ordered vector reduction (PR36732)
Without the fast math flags, the llvm.experimental.vector.reduce.fadd/fmul intrinsic expansions must be expanded in order. This patch scalarizes the reduction, applying the accumulator at the start of the sequence: ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ) ... + Scl[NumElts-1] Differential Revision: https://reviews.llvm.org/D45366 llvm-svn: 329585
1 parent bec8a66 commit 23c2182

File tree

4 files changed

+84
-14
lines changed

4 files changed

+84
-14
lines changed

Diff for: llvm/include/llvm/Transforms/Utils/LoopUtils.h

+7
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,13 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
509509
LoopSafetyInfo *SafetyInfo,
510510
OptimizationRemarkEmitter *ORE = nullptr);
511511

512+
/// Generates an ordered vector reduction using extracts to reduce the value.
513+
Value *
514+
getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src, unsigned Op,
515+
RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind =
516+
RecurrenceDescriptor::MRK_Invalid,
517+
ArrayRef<Value *> RedOps = None);
518+
512519
/// Generates a vector reduction using shufflevectors to reduce the value.
513520
Value *getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
514521
RecurrenceDescriptor::MinMaxRecurrenceKind

Diff for: llvm/lib/CodeGen/ExpandReductions.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,26 @@ RecurrenceDescriptor::MinMaxRecurrenceKind getMRK(Intrinsic::ID ID) {
7878

7979
bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
8080
bool Changed = false;
81-
SmallVector<IntrinsicInst*, 4> Worklist;
81+
SmallVector<IntrinsicInst *, 4> Worklist;
8282
for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
8383
if (auto II = dyn_cast<IntrinsicInst>(&*I))
8484
Worklist.push_back(II);
8585

8686
for (auto *II : Worklist) {
8787
IRBuilder<> Builder(II);
88+
bool IsOrdered = false;
89+
Value *Acc = nullptr;
8890
Value *Vec = nullptr;
8991
auto ID = II->getIntrinsicID();
9092
auto MRK = RecurrenceDescriptor::MRK_Invalid;
9193
switch (ID) {
9294
case Intrinsic::experimental_vector_reduce_fadd:
9395
case Intrinsic::experimental_vector_reduce_fmul:
9496
// FMFs must be attached to the call, otherwise it's an ordered reduction
95-
// and it can't be handled by generating this shuffle sequence.
96-
// TODO: Implement scalarization of ordered reductions here for targets
97-
// without native support.
97+
// and it can't be handled by generating a shuffle sequence.
9898
if (!II->getFastMathFlags().isFast())
99-
continue;
99+
IsOrdered = true;
100+
Acc = II->getArgOperand(0);
100101
Vec = II->getArgOperand(1);
101102
break;
102103
case Intrinsic::experimental_vector_reduce_add:
@@ -118,7 +119,9 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
118119
}
119120
if (!TTI->shouldExpandReduction(II))
120121
continue;
121-
auto Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), MRK);
122+
Value *Rdx =
123+
IsOrdered ? getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), MRK)
124+
: getShuffleReduction(Builder, Vec, getOpcode(ID), MRK);
122125
II->replaceAllUsesWith(Rdx);
123126
II->eraseFromParent();
124127
Changed = true;

Diff for: llvm/lib/Transforms/Utils/LoopUtils.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,38 @@ static Value *addFastMathFlag(Value *V) {
15261526
return V;
15271527
}
15281528

1529+
// Helper to generate an ordered reduction.
1530+
Value *
1531+
llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src,
1532+
unsigned Op,
1533+
RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
1534+
ArrayRef<Value *> RedOps) {
1535+
unsigned VF = Src->getType()->getVectorNumElements();
1536+
1537+
// Extract and apply reduction ops in ascending order:
1538+
// e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ) ... + Scl[VF-1]
1539+
Value *Result = Acc;
1540+
for (unsigned ExtractIdx = 0; ExtractIdx != VF; ++ExtractIdx) {
1541+
Value *Ext =
1542+
Builder.CreateExtractElement(Src, Builder.getInt32(ExtractIdx));
1543+
1544+
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
1545+
Result = Builder.CreateBinOp((Instruction::BinaryOps)Op, Result, Ext,
1546+
"bin.rdx");
1547+
} else {
1548+
assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid &&
1549+
"Invalid min/max");
1550+
Result = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, Result,
1551+
Ext);
1552+
}
1553+
1554+
if (!RedOps.empty())
1555+
propagateIRFlags(Result, RedOps);
1556+
}
1557+
1558+
return Result;
1559+
}
1560+
15291561
// Helper to generate a log2 shuffle reduction.
15301562
Value *
15311563
llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,

Diff for: llvm/test/CodeGen/Generic/expand-experimental-reductions.ll

+36-8
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,15 @@ entry:
117117
define float @fadd_f32_strict(<4 x float> %vec) {
118118
; CHECK-LABEL: @fadd_f32_strict(
119119
; CHECK-NEXT: entry:
120-
; CHECK-NEXT: [[R:%.*]] = call float @llvm.experimental.vector.reduce.fadd.f32.f32.v4f32(float undef, <4 x float> [[VEC:%.*]])
121-
; CHECK-NEXT: ret float [[R]]
120+
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x float> [[VEC:%.*]], i32 0
121+
; CHECK-NEXT: [[BIN_RDX:%.*]] = fadd float undef, [[TMP0]]
122+
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[VEC]], i32 1
123+
; CHECK-NEXT: [[BIN_RDX1:%.*]] = fadd float [[BIN_RDX]], [[TMP1]]
124+
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[VEC]], i32 2
125+
; CHECK-NEXT: [[BIN_RDX2:%.*]] = fadd float [[BIN_RDX1]], [[TMP2]]
126+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[VEC]], i32 3
127+
; CHECK-NEXT: [[BIN_RDX3:%.*]] = fadd float [[BIN_RDX2]], [[TMP3]]
128+
; CHECK-NEXT: ret float [[BIN_RDX3]]
122129
;
123130
entry:
124131
%r = call float @llvm.experimental.vector.reduce.fadd.f32.v4f32(float undef, <4 x float> %vec)
@@ -128,8 +135,15 @@ entry:
128135
define float @fadd_f32_strict_accum(float %accum, <4 x float> %vec) {
129136
; CHECK-LABEL: @fadd_f32_strict_accum(
130137
; CHECK-NEXT: entry:
131-
; CHECK-NEXT: [[R:%.*]] = call float @llvm.experimental.vector.reduce.fadd.f32.f32.v4f32(float [[ACCUM:%.*]], <4 x float> [[VEC:%.*]])
132-
; CHECK-NEXT: ret float [[R]]
138+
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x float> [[VEC:%.*]], i32 0
139+
; CHECK-NEXT: [[BIN_RDX:%.*]] = fadd float [[ACCUM:%.*]], [[TMP0]]
140+
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[VEC]], i32 1
141+
; CHECK-NEXT: [[BIN_RDX1:%.*]] = fadd float [[BIN_RDX]], [[TMP1]]
142+
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[VEC]], i32 2
143+
; CHECK-NEXT: [[BIN_RDX2:%.*]] = fadd float [[BIN_RDX1]], [[TMP2]]
144+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[VEC]], i32 3
145+
; CHECK-NEXT: [[BIN_RDX3:%.*]] = fadd float [[BIN_RDX2]], [[TMP3]]
146+
; CHECK-NEXT: ret float [[BIN_RDX3]]
133147
;
134148
entry:
135149
%r = call float @llvm.experimental.vector.reduce.fadd.f32.v4f32(float %accum, <4 x float> %vec)
@@ -169,8 +183,15 @@ entry:
169183
define float @fmul_f32_strict(<4 x float> %vec) {
170184
; CHECK-LABEL: @fmul_f32_strict(
171185
; CHECK-NEXT: entry:
172-
; CHECK-NEXT: [[R:%.*]] = call float @llvm.experimental.vector.reduce.fmul.f32.f32.v4f32(float undef, <4 x float> [[VEC:%.*]])
173-
; CHECK-NEXT: ret float [[R]]
186+
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x float> [[VEC:%.*]], i32 0
187+
; CHECK-NEXT: [[BIN_RDX:%.*]] = fmul float undef, [[TMP0]]
188+
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[VEC]], i32 1
189+
; CHECK-NEXT: [[BIN_RDX1:%.*]] = fmul float [[BIN_RDX]], [[TMP1]]
190+
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[VEC]], i32 2
191+
; CHECK-NEXT: [[BIN_RDX2:%.*]] = fmul float [[BIN_RDX1]], [[TMP2]]
192+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[VEC]], i32 3
193+
; CHECK-NEXT: [[BIN_RDX3:%.*]] = fmul float [[BIN_RDX2]], [[TMP3]]
194+
; CHECK-NEXT: ret float [[BIN_RDX3]]
174195
;
175196
entry:
176197
%r = call float @llvm.experimental.vector.reduce.fmul.f32.v4f32(float undef, <4 x float> %vec)
@@ -180,8 +201,15 @@ entry:
180201
define float @fmul_f32_strict_accum(float %accum, <4 x float> %vec) {
181202
; CHECK-LABEL: @fmul_f32_strict_accum(
182203
; CHECK-NEXT: entry:
183-
; CHECK-NEXT: [[R:%.*]] = call float @llvm.experimental.vector.reduce.fmul.f32.f32.v4f32(float [[ACCUM:%.*]], <4 x float> [[VEC:%.*]])
184-
; CHECK-NEXT: ret float [[R]]
204+
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x float> [[VEC:%.*]], i32 0
205+
; CHECK-NEXT: [[BIN_RDX:%.*]] = fmul float [[ACCUM:%.*]], [[TMP0]]
206+
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[VEC]], i32 1
207+
; CHECK-NEXT: [[BIN_RDX1:%.*]] = fmul float [[BIN_RDX]], [[TMP1]]
208+
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[VEC]], i32 2
209+
; CHECK-NEXT: [[BIN_RDX2:%.*]] = fmul float [[BIN_RDX1]], [[TMP2]]
210+
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[VEC]], i32 3
211+
; CHECK-NEXT: [[BIN_RDX3:%.*]] = fmul float [[BIN_RDX2]], [[TMP3]]
212+
; CHECK-NEXT: ret float [[BIN_RDX3]]
185213
;
186214
entry:
187215
%r = call float @llvm.experimental.vector.reduce.fmul.f32.v4f32(float %accum, <4 x float> %vec)

0 commit comments

Comments
 (0)