@@ -671,13 +671,9 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
671
671
return true ;
672
672
}
673
673
674
- // / Adds a 'fast' flag to floating point operations.
675
- static Value *addFastMathFlag (Value *V) {
676
- if (isa<FPMathOperator>(V)) {
677
- FastMathFlags Flags;
678
- Flags.setFast ();
679
- cast<Instruction>(V)->setFastMathFlags (Flags);
680
- }
674
+ static Value *addFastMathFlag (Value *V, FastMathFlags FMF) {
675
+ if (isa<FPMathOperator>(V))
676
+ cast<Instruction>(V)->setFastMathFlags (FMF);
681
677
return V;
682
678
}
683
679
@@ -761,7 +757,7 @@ llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src,
761
757
Value *
762
758
llvm::getShuffleReduction (IRBuilder<> &Builder, Value *Src, unsigned Op,
763
759
RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
764
- ArrayRef<Value *> RedOps) {
760
+ FastMathFlags FMF, ArrayRef<Value *> RedOps) {
765
761
unsigned VF = Src->getType ()->getVectorNumElements ();
766
762
// VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
767
763
// and vector ops, reducing the set of values being computed by half each
@@ -786,7 +782,8 @@ llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
786
782
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
787
783
// Floating point operations had to be 'fast' to enable the reduction.
788
784
TmpVec = addFastMathFlag (Builder.CreateBinOp ((Instruction::BinaryOps)Op,
789
- TmpVec, Shuf, " bin.rdx" ));
785
+ TmpVec, Shuf, " bin.rdx" ),
786
+ FMF);
790
787
} else {
791
788
assert (MinMaxKind != RecurrenceDescriptor::MRK_Invalid &&
792
789
" Invalid min/max" );
@@ -803,7 +800,7 @@ llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
803
800
// / flags (if generating min/max reductions).
804
801
Value *llvm::createSimpleTargetReduction (
805
802
IRBuilder<> &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
806
- Value *Src, TargetTransformInfo::ReductionFlags Flags,
803
+ Value *Src, TargetTransformInfo::ReductionFlags Flags, FastMathFlags FMF,
807
804
ArrayRef<Value *> RedOps) {
808
805
assert (isa<VectorType>(Src->getType ()) && " Type must be a vector" );
809
806
@@ -873,7 +870,7 @@ Value *llvm::createSimpleTargetReduction(
873
870
}
874
871
if (TTI->useReductionIntrinsic (Opcode, Src->getType (), Flags))
875
872
return BuildFunc ();
876
- return getShuffleReduction (Builder, Src, Opcode, MinMaxKind, RedOps);
873
+ return getShuffleReduction (Builder, Src, Opcode, MinMaxKind, FMF, RedOps);
877
874
}
878
875
879
876
// / Create a vector reduction using a given recurrence descriptor.
@@ -888,28 +885,37 @@ Value *llvm::createTargetReduction(IRBuilder<> &B,
888
885
Flags.NoNaN = NoNaN;
889
886
switch (RecKind) {
890
887
case RD::RK_FloatAdd:
891
- return createSimpleTargetReduction (B, TTI, Instruction::FAdd, Src, Flags);
888
+ return createSimpleTargetReduction (B, TTI, Instruction::FAdd, Src, Flags,
889
+ Desc.getFastMathFlags ());
892
890
case RD::RK_FloatMult:
893
- return createSimpleTargetReduction (B, TTI, Instruction::FMul, Src, Flags);
891
+ return createSimpleTargetReduction (B, TTI, Instruction::FMul, Src, Flags,
892
+ Desc.getFastMathFlags ());
894
893
case RD::RK_IntegerAdd:
895
- return createSimpleTargetReduction (B, TTI, Instruction::Add, Src, Flags);
894
+ return createSimpleTargetReduction (B, TTI, Instruction::Add, Src, Flags,
895
+ Desc.getFastMathFlags ());
896
896
case RD::RK_IntegerMult:
897
- return createSimpleTargetReduction (B, TTI, Instruction::Mul, Src, Flags);
897
+ return createSimpleTargetReduction (B, TTI, Instruction::Mul, Src, Flags,
898
+ Desc.getFastMathFlags ());
898
899
case RD::RK_IntegerAnd:
899
- return createSimpleTargetReduction (B, TTI, Instruction::And, Src, Flags);
900
+ return createSimpleTargetReduction (B, TTI, Instruction::And, Src, Flags,
901
+ Desc.getFastMathFlags ());
900
902
case RD::RK_IntegerOr:
901
- return createSimpleTargetReduction (B, TTI, Instruction::Or, Src, Flags);
903
+ return createSimpleTargetReduction (B, TTI, Instruction::Or, Src, Flags,
904
+ Desc.getFastMathFlags ());
902
905
case RD::RK_IntegerXor:
903
- return createSimpleTargetReduction (B, TTI, Instruction::Xor, Src, Flags);
906
+ return createSimpleTargetReduction (B, TTI, Instruction::Xor, Src, Flags,
907
+ Desc.getFastMathFlags ());
904
908
case RD::RK_IntegerMinMax: {
905
909
RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind ();
906
910
Flags.IsMaxOp = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax);
907
911
Flags.IsSigned = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin);
908
- return createSimpleTargetReduction (B, TTI, Instruction::ICmp, Src, Flags);
912
+ return createSimpleTargetReduction (B, TTI, Instruction::ICmp, Src, Flags,
913
+ Desc.getFastMathFlags ());
909
914
}
910
915
case RD::RK_FloatMinMax: {
911
916
Flags.IsMaxOp = Desc.getMinMaxRecurrenceKind () == RD::MRK_FloatMax;
912
- return createSimpleTargetReduction (B, TTI, Instruction::FCmp, Src, Flags);
917
+ return createSimpleTargetReduction (B, TTI, Instruction::FCmp, Src, Flags,
918
+ Desc.getFastMathFlags ());
913
919
}
914
920
default :
915
921
llvm_unreachable (" Unhandled RecKind" );
0 commit comments