Skip to content

Commit 93f8cc1

Browse files
committed
Relax constraints for reduction vectorization
Summary: Gating vectorizing reductions on *all* fastmath flags seems unnecessary; `reassoc` should be sufficient. Reviewers: tvvikram, mkuper, kristof.beyls, sdesmalen, Ayal Reviewed By: sdesmalen Subscribers: dcaballe, huntergr, jmolloy, mcrosier, jlebar, bixia, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D57728 llvm-svn: 355868
1 parent b6d322b commit 93f8cc1

File tree

9 files changed

+180
-34
lines changed

9 files changed

+180
-34
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

+11-4
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,12 @@ class RecurrenceDescriptor {
8989
RecurrenceDescriptor() = default;
9090

9191
RecurrenceDescriptor(Value *Start, Instruction *Exit, RecurrenceKind K,
92-
MinMaxRecurrenceKind MK, Instruction *UAI, Type *RT,
93-
bool Signed, SmallPtrSetImpl<Instruction *> &CI)
94-
: StartValue(Start), LoopExitInstr(Exit), Kind(K), MinMaxKind(MK),
95-
UnsafeAlgebraInst(UAI), RecurrenceType(RT), IsSigned(Signed) {
92+
FastMathFlags FMF, MinMaxRecurrenceKind MK,
93+
Instruction *UAI, Type *RT, bool Signed,
94+
SmallPtrSetImpl<Instruction *> &CI)
95+
: StartValue(Start), LoopExitInstr(Exit), Kind(K), FMF(FMF),
96+
MinMaxKind(MK), UnsafeAlgebraInst(UAI), RecurrenceType(RT),
97+
IsSigned(Signed) {
9698
CastInsts.insert(CI.begin(), CI.end());
9799
}
98100

@@ -198,6 +200,8 @@ class RecurrenceDescriptor {
198200

199201
MinMaxRecurrenceKind getMinMaxRecurrenceKind() { return MinMaxKind; }
200202

203+
FastMathFlags getFastMathFlags() { return FMF; }
204+
201205
TrackingVH<Value> getRecurrenceStartValue() { return StartValue; }
202206

203207
Instruction *getLoopExitInstr() { return LoopExitInstr; }
@@ -237,6 +241,9 @@ class RecurrenceDescriptor {
237241
Instruction *LoopExitInstr = nullptr;
238242
// The kind of the recurrence.
239243
RecurrenceKind Kind = RK_NoRecurrence;
244+
// The fast-math flags on the recurrent instructions. We propagate these
245+
// fast-math flags into the vectorized FP instructions we generate.
246+
FastMathFlags FMF;
240247
// If this is a min/max recurrence, the kind of min/max recurrence.
241248
MinMaxRecurrenceKind MinMaxKind = MRK_Invalid;
242249
// First occurrence of unsafe algebra in the PHI's use-chain.

llvm/include/llvm/IR/Operator.h

+6
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,12 @@ class FastMathFlags {
187187

188188
FastMathFlags() = default;
189189

190+
static FastMathFlags getFast() {
191+
FastMathFlags FMF;
192+
FMF.setFast();
193+
return FMF;
194+
}
195+
190196
bool any() const { return Flags != 0; }
191197
bool none() const { return Flags == 0; }
192198
bool all() const { return Flags == ~0U; }

llvm/include/llvm/Transforms/Utils/LoopUtils.h

+2
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src, unsigned Op,
296296
Value *getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
297297
RecurrenceDescriptor::MinMaxRecurrenceKind
298298
MinMaxKind = RecurrenceDescriptor::MRK_Invalid,
299+
FastMathFlags FMF = FastMathFlags(),
299300
ArrayRef<Value *> RedOps = None);
300301

301302
/// Create a target reduction of the given vector. The reduction operation
@@ -308,6 +309,7 @@ Value *createSimpleTargetReduction(IRBuilder<> &B,
308309
unsigned Opcode, Value *Src,
309310
TargetTransformInfo::ReductionFlags Flags =
310311
TargetTransformInfo::ReductionFlags(),
312+
FastMathFlags FMF = FastMathFlags(),
311313
ArrayRef<Value *> RedOps = None);
312314

313315
/// Create a generic target reduction using a recurrence descriptor \p Desc

llvm/lib/Analysis/IVDescriptors.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
251251
Worklist.push_back(Start);
252252
VisitedInsts.insert(Start);
253253

254+
// Start with all flags set because we will intersect this with the reduction
255+
// flags from all the reduction operations.
256+
FastMathFlags FMF = FastMathFlags::getFast();
257+
254258
// A value in the reduction can be used:
255259
// - By the reduction:
256260
// - Reduction operation:
@@ -296,6 +300,8 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
296300
ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
297301
if (!ReduxDesc.isRecurrence())
298302
return false;
303+
if (isa<FPMathOperator>(ReduxDesc.getPatternInst()))
304+
FMF &= ReduxDesc.getPatternInst()->getFastMathFlags();
299305
}
300306

301307
bool IsASelect = isa<SelectInst>(Cur);
@@ -441,7 +447,7 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
441447

442448
// Save the description of this reduction variable.
443449
RecurrenceDescriptor RD(
444-
RdxStart, ExitInstruction, Kind, ReduxDesc.getMinMaxKind(),
450+
RdxStart, ExitInstruction, Kind, FMF, ReduxDesc.getMinMaxKind(),
445451
ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts);
446452
RedDes = RD;
447453

@@ -550,7 +556,7 @@ RecurrenceDescriptor::InstDesc
550556
RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind,
551557
InstDesc &Prev, bool HasFunNoNaNAttr) {
552558
Instruction *UAI = Prev.getUnsafeAlgebraInst();
553-
if (!UAI && isa<FPMathOperator>(I) && !I->isFast())
559+
if (!UAI && isa<FPMathOperator>(I) && !I->hasAllowReassoc())
554560
UAI = I; // Found an unsafe (unvectorizable) algebra instruction.
555561

556562
switch (I->getOpcode()) {

llvm/lib/CodeGen/ExpandReductions.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,11 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
118118
}
119119
if (!TTI->shouldExpandReduction(II))
120120
continue;
121+
FastMathFlags FMF =
122+
isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
121123
Value *Rdx =
122124
IsOrdered ? getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), MRK)
123-
: getShuffleReduction(Builder, Vec, getOpcode(ID), MRK);
125+
: getShuffleReduction(Builder, Vec, getOpcode(ID), MRK, FMF);
124126
II->replaceAllUsesWith(Rdx);
125127
II->eraseFromParent();
126128
Changed = true;

llvm/lib/Transforms/Utils/LoopUtils.cpp

+26-20
Original file line numberDiff line numberDiff line change
@@ -671,13 +671,9 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
671671
return true;
672672
}
673673

674-
/// Adds a 'fast' flag to floating point operations.
675-
static Value *addFastMathFlag(Value *V) {
676-
if (isa<FPMathOperator>(V)) {
677-
FastMathFlags Flags;
678-
Flags.setFast();
679-
cast<Instruction>(V)->setFastMathFlags(Flags);
680-
}
674+
static Value *addFastMathFlag(Value *V, FastMathFlags FMF) {
675+
if (isa<FPMathOperator>(V))
676+
cast<Instruction>(V)->setFastMathFlags(FMF);
681677
return V;
682678
}
683679

@@ -761,7 +757,7 @@ llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src,
761757
Value *
762758
llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
763759
RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
764-
ArrayRef<Value *> RedOps) {
760+
FastMathFlags FMF, ArrayRef<Value *> RedOps) {
765761
unsigned VF = Src->getType()->getVectorNumElements();
766762
// VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
767763
// and vector ops, reducing the set of values being computed by half each
@@ -786,7 +782,8 @@ llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
786782
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
787783
// Apply the caller-supplied fast-math flags to floating-point operations.
788784
TmpVec = addFastMathFlag(Builder.CreateBinOp((Instruction::BinaryOps)Op,
789-
TmpVec, Shuf, "bin.rdx"));
785+
TmpVec, Shuf, "bin.rdx"),
786+
FMF);
790787
} else {
791788
assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid &&
792789
"Invalid min/max");
@@ -803,7 +800,7 @@ llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
803800
/// flags (if generating min/max reductions).
804801
Value *llvm::createSimpleTargetReduction(
805802
IRBuilder<> &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
806-
Value *Src, TargetTransformInfo::ReductionFlags Flags,
803+
Value *Src, TargetTransformInfo::ReductionFlags Flags, FastMathFlags FMF,
807804
ArrayRef<Value *> RedOps) {
808805
assert(isa<VectorType>(Src->getType()) && "Type must be a vector");
809806

@@ -873,7 +870,7 @@ Value *llvm::createSimpleTargetReduction(
873870
}
874871
if (TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags))
875872
return BuildFunc();
876-
return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, RedOps);
873+
return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, FMF, RedOps);
877874
}
878875

879876
/// Create a vector reduction using a given recurrence descriptor.
@@ -888,28 +885,37 @@ Value *llvm::createTargetReduction(IRBuilder<> &B,
888885
Flags.NoNaN = NoNaN;
889886
switch (RecKind) {
890887
case RD::RK_FloatAdd:
891-
return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags);
888+
return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags,
889+
Desc.getFastMathFlags());
892890
case RD::RK_FloatMult:
893-
return createSimpleTargetReduction(B, TTI, Instruction::FMul, Src, Flags);
891+
return createSimpleTargetReduction(B, TTI, Instruction::FMul, Src, Flags,
892+
Desc.getFastMathFlags());
894893
case RD::RK_IntegerAdd:
895-
return createSimpleTargetReduction(B, TTI, Instruction::Add, Src, Flags);
894+
return createSimpleTargetReduction(B, TTI, Instruction::Add, Src, Flags,
895+
Desc.getFastMathFlags());
896896
case RD::RK_IntegerMult:
897-
return createSimpleTargetReduction(B, TTI, Instruction::Mul, Src, Flags);
897+
return createSimpleTargetReduction(B, TTI, Instruction::Mul, Src, Flags,
898+
Desc.getFastMathFlags());
898899
case RD::RK_IntegerAnd:
899-
return createSimpleTargetReduction(B, TTI, Instruction::And, Src, Flags);
900+
return createSimpleTargetReduction(B, TTI, Instruction::And, Src, Flags,
901+
Desc.getFastMathFlags());
900902
case RD::RK_IntegerOr:
901-
return createSimpleTargetReduction(B, TTI, Instruction::Or, Src, Flags);
903+
return createSimpleTargetReduction(B, TTI, Instruction::Or, Src, Flags,
904+
Desc.getFastMathFlags());
902905
case RD::RK_IntegerXor:
903-
return createSimpleTargetReduction(B, TTI, Instruction::Xor, Src, Flags);
906+
return createSimpleTargetReduction(B, TTI, Instruction::Xor, Src, Flags,
907+
Desc.getFastMathFlags());
904908
case RD::RK_IntegerMinMax: {
905909
RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind();
906910
Flags.IsMaxOp = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax);
907911
Flags.IsSigned = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin);
908-
return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags);
912+
return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags,
913+
Desc.getFastMathFlags());
909914
}
910915
case RD::RK_FloatMinMax: {
911916
Flags.IsMaxOp = Desc.getMinMaxRecurrenceKind() == RD::MRK_FloatMax;
912-
return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags);
917+
return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags,
918+
Desc.getFastMathFlags());
913919
}
914920
default:
915921
llvm_unreachable("Unhandled RecKind");

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,14 @@ static unsigned getReciprocalPredBlockProb() { return 2; }
319319

320320
/// A helper function that adds a 'fast' flag to floating-point operations.
321321
static Value *addFastMathFlag(Value *V) {
322-
if (isa<FPMathOperator>(V)) {
323-
FastMathFlags Flags;
324-
Flags.setFast();
325-
cast<Instruction>(V)->setFastMathFlags(Flags);
326-
}
322+
if (isa<FPMathOperator>(V))
323+
cast<Instruction>(V)->setFastMathFlags(FastMathFlags::getFast());
324+
return V;
325+
}
326+
327+
static Value *addFastMathFlag(Value *V, FastMathFlags FMF) {
328+
if (isa<FPMathOperator>(V))
329+
cast<Instruction>(V)->setFastMathFlags(FMF);
327330
return V;
328331
}
329332

@@ -3612,7 +3615,8 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
36123615
// Floating-point operations had to allow reassociation to enable the reduction; propagate the recurrence's fast-math flags.
36133616
ReducedPartRdx = addFastMathFlag(
36143617
Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxPart,
3615-
ReducedPartRdx, "bin.rdx"));
3618+
ReducedPartRdx, "bin.rdx"),
3619+
RdxDesc.getFastMathFlags());
36163620
else
36173621
ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx,
36183622
RdxPart);

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -5929,7 +5929,8 @@ class HorizontalReduction {
59295929
if (!IsPairwiseReduction)
59305930
return createSimpleTargetReduction(
59315931
Builder, TTI, ReductionData.getOpcode(), VectorizedValue,
5932-
ReductionData.getFlags(), ReductionOps.back());
5932+
ReductionData.getFlags(), FastMathFlags::getFast(),
5933+
ReductionOps.back());
59335934

59345935
Value *TmpVec = VectorizedValue;
59355936
for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
; RUN: opt -S -loop-vectorize < %s | FileCheck %s
2+
3+
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
4+
target triple = "x86_64-unknown-linux-gnu"
5+
6+
define float @reduction_sum_float_ieee(i32 %n, float* %array) {
7+
; CHECK-LABEL: define float @reduction_sum_float_ieee(
8+
entry:
9+
%entry.cond = icmp ne i32 0, 4096
10+
br i1 %entry.cond, label %loop, label %loop.exit
11+
12+
loop:
13+
%idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
14+
%sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
15+
%address = getelementptr float, float* %array, i32 %idx
16+
%value = load float, float* %address
17+
%sum.inc = fadd float %sum, %value
18+
%idx.inc = add i32 %idx, 1
19+
%be.cond = icmp ne i32 %idx.inc, 4096
20+
br i1 %be.cond, label %loop, label %loop.exit
21+
22+
loop.exit:
23+
%sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
24+
; CHECK-NOT: %wide.load = load <4 x float>, <4 x float>*
25+
; CHECK: ret float %sum.lcssa
26+
ret float %sum.lcssa
27+
}
28+
29+
define float @reduction_sum_float_fastmath(i32 %n, float* %array) {
30+
; CHECK-LABEL: define float @reduction_sum_float_fastmath(
31+
; CHECK: fadd fast <4 x float>
32+
; CHECK: fadd fast <4 x float>
33+
; CHECK: fadd fast <4 x float>
34+
; CHECK: fadd fast <4 x float>
35+
; CHECK: fadd fast <4 x float>
36+
entry:
37+
%entry.cond = icmp ne i32 0, 4096
38+
br i1 %entry.cond, label %loop, label %loop.exit
39+
40+
loop:
41+
%idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
42+
%sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
43+
%address = getelementptr float, float* %array, i32 %idx
44+
%value = load float, float* %address
45+
%sum.inc = fadd fast float %sum, %value
46+
%idx.inc = add i32 %idx, 1
47+
%be.cond = icmp ne i32 %idx.inc, 4096
48+
br i1 %be.cond, label %loop, label %loop.exit
49+
50+
loop.exit:
51+
%sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
52+
; CHECK: ret float %sum.lcssa
53+
ret float %sum.lcssa
54+
}
55+
56+
define float @reduction_sum_float_only_reassoc(i32 %n, float* %array) {
57+
; CHECK-LABEL: define float @reduction_sum_float_only_reassoc(
58+
; CHECK-NOT: fadd fast
59+
; CHECK: fadd reassoc <4 x float>
60+
; CHECK: fadd reassoc <4 x float>
61+
; CHECK: fadd reassoc <4 x float>
62+
; CHECK: fadd reassoc <4 x float>
63+
; CHECK: fadd reassoc <4 x float>
64+
65+
entry:
66+
%entry.cond = icmp ne i32 0, 4096
67+
br i1 %entry.cond, label %loop, label %loop.exit
68+
69+
loop:
70+
%idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
71+
%sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
72+
%address = getelementptr float, float* %array, i32 %idx
73+
%value = load float, float* %address
74+
%sum.inc = fadd reassoc float %sum, %value
75+
%idx.inc = add i32 %idx, 1
76+
%be.cond = icmp ne i32 %idx.inc, 4096
77+
br i1 %be.cond, label %loop, label %loop.exit
78+
79+
loop.exit:
80+
%sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
81+
; CHECK: ret float %sum.lcssa
82+
ret float %sum.lcssa
83+
}
84+
85+
define float @reduction_sum_float_only_reassoc_and_contract(i32 %n, float* %array) {
86+
; CHECK-LABEL: define float @reduction_sum_float_only_reassoc_and_contract(
87+
; CHECK-NOT: fadd fast
88+
; CHECK: fadd reassoc contract <4 x float>
89+
; CHECK: fadd reassoc contract <4 x float>
90+
; CHECK: fadd reassoc contract <4 x float>
91+
; CHECK: fadd reassoc contract <4 x float>
92+
; CHECK: fadd reassoc contract <4 x float>
93+
94+
entry:
95+
%entry.cond = icmp ne i32 0, 4096
96+
br i1 %entry.cond, label %loop, label %loop.exit
97+
98+
loop:
99+
%idx = phi i32 [ 0, %entry ], [ %idx.inc, %loop ]
100+
%sum = phi float [ 0.000000e+00, %entry ], [ %sum.inc, %loop ]
101+
%address = getelementptr float, float* %array, i32 %idx
102+
%value = load float, float* %address
103+
%sum.inc = fadd reassoc contract float %sum, %value
104+
%idx.inc = add i32 %idx, 1
105+
%be.cond = icmp ne i32 %idx.inc, 4096
106+
br i1 %be.cond, label %loop, label %loop.exit
107+
108+
loop.exit:
109+
%sum.lcssa = phi float [ %sum.inc, %loop ], [ 0.000000e+00, %entry ]
110+
; CHECK: ret float %sum.lcssa
111+
ret float %sum.lcssa
112+
}

0 commit comments

Comments
 (0)