Skip to content

Commit a3d0dce

Browse files
committed
[ARM][TTI] Prevents constants in a min(max) or max(min) pattern from being hoisted when in a loop
Changes TTI function getIntImmCostInst to take an additional Instruction parameter, which enables us to be able to check it is part of a min(max())/max(min()) pattern that will match SSAT. We can then mark the constant used as free to prevent it being hoisted so SSAT can still be generated. Required minor changes in some non-ARM backends to allow for the optional parameter to be included. Differential Revision: https://reviews.llvm.org/D87457
1 parent b5e49e9 commit a3d0dce

18 files changed

+135
-116
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+9-6
Original file line numberDiff line numberDiff line change
@@ -810,8 +810,9 @@ class TargetTransformInfo {
810810
/// Return the expected cost of materialization for the given integer
811811
/// immediate of the specified type for a given instruction. The cost can be
812812
/// zero if the immediate can be folded into the specified instruction.
813-
int getIntImmCostInst(unsigned Opc, unsigned Idx, const APInt &Imm,
814-
Type *Ty, TargetCostKind CostKind) const;
813+
int getIntImmCostInst(unsigned Opc, unsigned Idx, const APInt &Imm, Type *Ty,
814+
TargetCostKind CostKind,
815+
Instruction *Inst = nullptr) const;
815816
int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
816817
Type *Ty, TargetCostKind CostKind) const;
817818

@@ -1461,7 +1462,8 @@ class TargetTransformInfo::Concept {
14611462
virtual int getIntImmCost(const APInt &Imm, Type *Ty,
14621463
TargetCostKind CostKind) = 0;
14631464
virtual int getIntImmCostInst(unsigned Opc, unsigned Idx, const APInt &Imm,
1464-
Type *Ty, TargetCostKind CostKind) = 0;
1465+
Type *Ty, TargetCostKind CostKind,
1466+
Instruction *Inst = nullptr) = 0;
14651467
virtual int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
14661468
const APInt &Imm, Type *Ty,
14671469
TargetCostKind CostKind) = 0;
@@ -1850,9 +1852,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
18501852
TargetCostKind CostKind) override {
18511853
return Impl.getIntImmCost(Imm, Ty, CostKind);
18521854
}
1853-
int getIntImmCostInst(unsigned Opc, unsigned Idx, const APInt &Imm,
1854-
Type *Ty, TargetCostKind CostKind) override {
1855-
return Impl.getIntImmCostInst(Opc, Idx, Imm, Ty, CostKind);
1855+
int getIntImmCostInst(unsigned Opc, unsigned Idx, const APInt &Imm, Type *Ty,
1856+
TargetCostKind CostKind,
1857+
Instruction *Inst = nullptr) override {
1858+
return Impl.getIntImmCostInst(Opc, Idx, Imm, Ty, CostKind, Inst);
18561859
}
18571860
int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
18581861
Type *Ty, TargetCostKind CostKind) override {

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ class TargetTransformInfoImplBase {
314314
}
315315

316316
unsigned getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
317-
Type *Ty, TTI::TargetCostKind CostKind) {
317+
Type *Ty, TTI::TargetCostKind CostKind,
318+
Instruction *Inst = nullptr) {
318319
return TTI::TCC_Free;
319320
}
320321

llvm/lib/Analysis/TargetTransformInfo.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -570,11 +570,11 @@ int TargetTransformInfo::getIntImmCost(const APInt &Imm, Type *Ty,
570570
return Cost;
571571
}
572572

573-
int
574-
TargetTransformInfo::getIntImmCostInst(unsigned Opcode, unsigned Idx,
575-
const APInt &Imm, Type *Ty,
576-
TTI::TargetCostKind CostKind) const {
577-
int Cost = TTIImpl->getIntImmCostInst(Opcode, Idx, Imm, Ty, CostKind);
573+
int TargetTransformInfo::getIntImmCostInst(unsigned Opcode, unsigned Idx,
574+
const APInt &Imm, Type *Ty,
575+
TTI::TargetCostKind CostKind,
576+
Instruction *Inst) const {
577+
int Cost = TTIImpl->getIntImmCostInst(Opcode, Idx, Imm, Ty, CostKind, Inst);
578578
assert(Cost >= 0 && "TTI should not produce negative costs!");
579579
return Cost;
580580
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ int AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
8484

8585
int AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
8686
const APInt &Imm, Type *Ty,
87-
TTI::TargetCostKind CostKind) {
87+
TTI::TargetCostKind CostKind,
88+
Instruction *Inst) {
8889
assert(Ty->isIntegerTy());
8990

9091
unsigned BitSize = Ty->getPrimitiveSizeInBits();

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
7474
int getIntImmCost(int64_t Val);
7575
int getIntImmCost(const APInt &Imm, Type *Ty, TTI::TargetCostKind CostKind);
7676
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
77-
Type *Ty, TTI::TargetCostKind CostKind);
77+
Type *Ty, TTI::TargetCostKind CostKind,
78+
Instruction *Inst = nullptr);
7879
int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
7980
Type *Ty, TTI::TargetCostKind CostKind);
8081
TTI::PopcntSupportKind getPopcntSupport(unsigned TyWidth);

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

+43-2
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,43 @@ int ARMTTIImpl::getIntImmCodeSizeCost(unsigned Opcode, unsigned Idx,
284284
return 1;
285285
}
286286

287-
int ARMTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
288-
Type *Ty, TTI::TargetCostKind CostKind) {
287+
// Checks whether Inst is part of a min(max()) or max(min()) pattern
288+
// that will match to an SSAT instruction
289+
static bool isSSATMinMaxPattern(Instruction *Inst, const APInt &Imm) {
290+
Value *LHS, *RHS;
291+
ConstantInt *C;
292+
SelectPatternFlavor InstSPF = matchSelectPattern(Inst, LHS, RHS).Flavor;
293+
294+
if (InstSPF == SPF_SMAX &&
295+
PatternMatch::match(RHS, PatternMatch::m_ConstantInt(C)) &&
296+
C->getValue() == Imm && Imm.isNegative() && (-Imm).isPowerOf2()) {
297+
298+
auto isSSatMin = [&](Value *MinInst) {
299+
if (isa<SelectInst>(MinInst)) {
300+
Value *MinLHS, *MinRHS;
301+
ConstantInt *MinC;
302+
SelectPatternFlavor MinSPF =
303+
matchSelectPattern(MinInst, MinLHS, MinRHS).Flavor;
304+
if (MinSPF == SPF_SMIN &&
305+
PatternMatch::match(MinRHS, PatternMatch::m_ConstantInt(MinC)) &&
306+
MinC->getValue() == ((-Imm) - 1))
307+
return true;
308+
}
309+
return false;
310+
};
311+
312+
if (isSSatMin(Inst->getOperand(1)) ||
313+
(Inst->hasNUses(2) && (isSSatMin(*Inst->user_begin()) ||
314+
isSSatMin(*(++Inst->user_begin())))))
315+
return true;
316+
}
317+
return false;
318+
}
319+
320+
int ARMTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
321+
const APInt &Imm, Type *Ty,
322+
TTI::TargetCostKind CostKind,
323+
Instruction *Inst) {
289324
// Division by a constant can be turned into multiplication, but only if we
290325
// know it's constant. So it's not so much that the immediate is cheap (it's
291326
// not), but that the alternative is worse.
@@ -324,6 +359,12 @@ int ARMTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Im
324359
if (Opcode == Instruction::Xor && Imm.isAllOnesValue())
325360
return 0;
326361

362+
// Ensures negative constant of min(max()) or max(min()) patterns that
363+
// match to SSAT instructions don't get hoisted
364+
if (Inst && ((ST->hasV6Ops() && !ST->isThumb()) || ST->isThumb2()) &&
365+
Ty->getIntegerBitWidth() <= 32 && isSSATMinMaxPattern(Inst, Imm))
366+
return 0;
367+
327368
return getIntImmCost(Imm, Ty, CostKind);
328369
}
329370

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
126126
int getIntImmCost(const APInt &Imm, Type *Ty, TTI::TargetCostKind CostKind);
127127

128128
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
129-
Type *Ty, TTI::TargetCostKind CostKind);
129+
Type *Ty, TTI::TargetCostKind CostKind,
130+
Instruction *Inst = nullptr);
130131

131132
/// @}
132133

llvm/lib/Target/Lanai/LanaiTargetTransformInfo.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ class LanaiTTIImpl : public BasicTTIImplBase<LanaiTTIImpl> {
6767
}
6868

6969
int getIntImmCostInst(unsigned Opc, unsigned Idx, const APInt &Imm, Type *Ty,
70-
TTI::TargetCostKind CostKind) {
70+
TTI::TargetCostKind CostKind,
71+
Instruction *Inst = nullptr) {
7172
return getIntImmCost(Imm, Ty, CostKind);
7273
}
7374

llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,10 @@ int PPCTTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
234234

235235
int PPCTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
236236
const APInt &Imm, Type *Ty,
237-
TTI::TargetCostKind CostKind) {
237+
TTI::TargetCostKind CostKind,
238+
Instruction *Inst) {
238239
if (DisablePPCConstHoist)
239-
return BaseT::getIntImmCostInst(Opcode, Idx, Imm, Ty, CostKind);
240+
return BaseT::getIntImmCostInst(Opcode, Idx, Imm, Ty, CostKind, Inst);
240241

241242
assert(Ty->isIntegerTy());
242243

llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class PPCTTIImpl : public BasicTTIImplBase<PPCTTIImpl> {
5252
TTI::TargetCostKind CostKind);
5353

5454
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
55-
Type *Ty, TTI::TargetCostKind CostKind);
55+
Type *Ty, TTI::TargetCostKind CostKind,
56+
Instruction *Inst = nullptr);
5657
int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
5758
Type *Ty, TTI::TargetCostKind CostKind);
5859

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ int RISCVTTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
3030
getST()->is64Bit());
3131
}
3232

33-
int RISCVTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
34-
Type *Ty, TTI::TargetCostKind CostKind) {
33+
int RISCVTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
34+
const APInt &Imm, Type *Ty,
35+
TTI::TargetCostKind CostKind,
36+
Instruction *Inst) {
3537
assert(Ty->isIntegerTy() &&
3638
"getIntImmCost can only estimate cost of materialising integers");
3739

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
4242
TLI(ST->getTargetLowering()) {}
4343

4444
int getIntImmCost(const APInt &Imm, Type *Ty, TTI::TargetCostKind CostKind);
45-
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm, Type *Ty,
46-
TTI::TargetCostKind CostKind);
45+
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
46+
Type *Ty, TTI::TargetCostKind CostKind,
47+
Instruction *Inst = nullptr);
4748
int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
4849
Type *Ty, TTI::TargetCostKind CostKind);
4950
};

llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ int SystemZTTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
6464
}
6565

6666
int SystemZTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
67-
const APInt &Imm, Type *Ty,
68-
TTI::TargetCostKind CostKind) {
67+
const APInt &Imm, Type *Ty,
68+
TTI::TargetCostKind CostKind,
69+
Instruction *Inst) {
6970
assert(Ty->isIntegerTy());
7071

7172
unsigned BitSize = Ty->getPrimitiveSizeInBits();

llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
4141
int getIntImmCost(const APInt &Imm, Type *Ty, TTI::TargetCostKind CostKind);
4242

4343
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
44-
Type *Ty, TTI::TargetCostKind CostKind);
44+
Type *Ty, TTI::TargetCostKind CostKind,
45+
Instruction *Inst = nullptr);
4546
int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
4647
Type *Ty, TTI::TargetCostKind CostKind);
4748

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -3841,8 +3841,10 @@ int X86TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
38413841
return std::max(1, Cost);
38423842
}
38433843

3844-
int X86TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
3845-
Type *Ty, TTI::TargetCostKind CostKind) {
3844+
int X86TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
3845+
const APInt &Imm, Type *Ty,
3846+
TTI::TargetCostKind CostKind,
3847+
Instruction *Inst) {
38463848
assert(Ty->isIntegerTy());
38473849

38483850
unsigned BitSize = Ty->getPrimitiveSizeInBits();

llvm/lib/Target/X86/X86TargetTransformInfo.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,9 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
203203

204204
unsigned getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind);
205205

206-
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm, Type *Ty,
207-
TTI::TargetCostKind CostKind);
206+
int getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm,
207+
Type *Ty, TTI::TargetCostKind CostKind,
208+
Instruction *Inst = nullptr);
208209
int getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
209210
Type *Ty, TTI::TargetCostKind CostKind);
210211
bool isLSRCostLess(TargetTransformInfo::LSRCost &C1,

llvm/lib/Transforms/Scalar/ConstantHoisting.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,9 @@ void ConstantHoistingPass::collectConstantCandidates(
366366
ConstInt->getValue(), ConstInt->getType(),
367367
TargetTransformInfo::TCK_SizeAndLatency);
368368
else
369-
Cost = TTI->getIntImmCostInst(Inst->getOpcode(), Idx, ConstInt->getValue(),
370-
ConstInt->getType(),
371-
TargetTransformInfo::TCK_SizeAndLatency);
369+
Cost = TTI->getIntImmCostInst(
370+
Inst->getOpcode(), Idx, ConstInt->getValue(), ConstInt->getType(),
371+
TargetTransformInfo::TCK_SizeAndLatency, Inst);
372372

373373
// Ignore cheap integer constants.
374374
if (Cost > TargetTransformInfo::TCC_Basic) {
@@ -418,8 +418,9 @@ void ConstantHoistingPass::collectConstantCandidates(
418418
// usually lowered to a load from constant pool. Such operation is unlikely
419419
// to be cheaper than compute it by <Base + Offset>, which can be lowered to
420420
// an ADD instruction or folded into Load/Store instruction.
421-
int Cost = TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy,
422-
TargetTransformInfo::TCK_SizeAndLatency);
421+
int Cost =
422+
TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy,
423+
TargetTransformInfo::TCK_SizeAndLatency, Inst);
423424
ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV];
424425
ConstCandMapType::iterator Itr;
425426
bool Inserted;

0 commit comments

Comments
 (0)