Commit 01d8759

[IR][ShuffleVector] Introduce isReplicationMask() matcher
Avid readers of this saga may recall from previous installments that a replication mask replicates (lol) each of the `VF` elements in a vector `ReplicationFactor` times. For example, the mask for `ReplicationFactor=3` and `VF=4` is: `<0,0,0,1,1,1,2,2,2,3,3,3>`.

More importantly, the replication mask is used by LoopVectorizer when using masked interleaved memory operations. As discussed in previous installments, while it is used by LV, and we **seem** to support masked interleaved memory operations on X86, its support in the cost model leaves a lot to be desired: until basically yesterday even for AVX512 we had no cost model for it.

As has been witnessed in the recent AVX2 `X86TTIImpl::getInterleavedMemoryOpCost()` costmodel patches, while it is hard enough to query the cost of a particular assembly sequence [from llvm-mca], afterwards the check lines in the LV costmodel tests must be updated manually. This is, at the very least, boring.

Okay, now we have decent costmodel coverage for interleaving shuffles, but now basically the same mind-killing sequence has to be performed for the replication mask. I think we can improve at least the second half of the problem by teaching `TargetTransformInfoImplCRTPBase::getUserCost()` to recognize `Instruction::ShuffleVector` instructions that are replication masks, and adding exhaustive test coverage using `-cost-model -analyze` + `utils/update_analyze_test_checks.py`. This way we can have good exhaustive coverage for the cost model, and only basic coverage for the LV costmodel.

This patch adds a precise undef-aware `isReplicationMask()`, with exhaustive test coverage.
* `InstructionsTest.ShuffleMaskIsReplicationMask` shows that it correctly detects all the known masks.
* `InstructionsTest.ShuffleMaskIsReplicationMask_undef` shows that replacing some mask elements in a known replication mask with undef still allows us to recognize it as a replication mask. Note, with enough undef elts, we may detect a different tuple.
* `InstructionsTest.ShuffleMaskIsReplicationMask_Exhaustive_Correctness` shows that if we detected the replication mask with given params, then if we actually generate a true replication mask with said params, it matches element-wise ignoring undef mask elements.

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D113214
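To make the mask layout concrete, here is a minimal standalone sketch that builds such a mask. `makeReplicationMask` is a hypothetical helper written for illustration on top of `std::vector`; it is not part of this patch (the tests in the patch use LLVM's `createReplicatedMask` instead).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration helper, not the LLVM API: build the replication
// mask by repeating each source index 0..VF-1 ReplicationFactor times.
std::vector<int> makeReplicationMask(int ReplicationFactor, int VF) {
  assert(ReplicationFactor > 0 && VF > 0 && "Parameters must be positive.");
  std::vector<int> Mask;
  Mask.reserve(static_cast<std::size_t>(ReplicationFactor) * VF);
  for (int Elt = 0; Elt < VF; ++Elt)
    for (int Rep = 0; Rep < ReplicationFactor; ++Rep)
      Mask.push_back(Elt);
  return Mask;
}
```

With `ReplicationFactor=3` and `VF=4` this yields the twelve-element mask quoted above; `ReplicationFactor=1` degenerates to an identity mask and `VF=1` to a broadcast of element 0.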
1 parent 7a98761 commit 01d8759

3 files changed: +183 -1 lines changed

llvm/include/llvm/IR/Instructions.h

+28
@@ -2354,6 +2354,34 @@ class ShuffleVectorInst : public Instruction {
     return isInsertSubvectorMask(ShuffleMask, NumSrcElts, NumSubElts, Index);
   }
 
+  /// Return true if this shuffle mask replicates each of the \p VF elements
+  /// in a vector \p ReplicationFactor times.
+  /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
+  ///   <0,0,0,1,1,1,2,2,2,3,3,3>
+  static bool isReplicationMask(ArrayRef<int> Mask, int &ReplicationFactor,
+                                int &VF);
+  static bool isReplicationMask(const Constant *Mask, int &ReplicationFactor,
+                                int &VF) {
+    assert(Mask->getType()->isVectorTy() && "Shuffle needs vector constant.");
+    // Not possible to express a shuffle mask for a scalable vector for this
+    // case.
+    if (isa<ScalableVectorType>(Mask->getType()))
+      return false;
+    SmallVector<int, 16> MaskAsInts;
+    getShuffleMask(Mask, MaskAsInts);
+    return isReplicationMask(MaskAsInts, ReplicationFactor, VF);
+  }
+
+  /// Return true if this shuffle mask is a replication mask.
+  bool isReplicationMask(int &ReplicationFactor, int &VF) const {
+    // Not possible to express a shuffle mask for a scalable vector for this
+    // case.
+    if (isa<ScalableVectorType>(getType()))
+      return false;
+
+    return isReplicationMask(ShuffleMask, ReplicationFactor, VF);
+  }
+
   /// Change values in a shuffle permute mask assuming the two vector operands
   /// of length InVecNumElts have swapped position.
   static void commuteShuffleMask(MutableArrayRef<int> Mask,

llvm/lib/IR/Instructions.cpp

+66
@@ -2436,6 +2436,72 @@ bool ShuffleVectorInst::isConcat() const {
   return isIdentityMaskImpl(getShuffleMask(), NumMaskElts);
 }
 
+static bool isReplicationMaskWithParams(ArrayRef<int> Mask,
+                                        int ReplicationFactor, int VF) {
+  assert(Mask.size() == (unsigned)ReplicationFactor * VF &&
+         "Unexpected mask size.");
+
+  for (int CurrElt : seq(0, VF)) {
+    ArrayRef<int> CurrSubMask = Mask.take_front(ReplicationFactor);
+    assert(CurrSubMask.size() == (unsigned)ReplicationFactor &&
+           "Run out of mask?");
+    Mask = Mask.drop_front(ReplicationFactor);
+    if (!all_of(CurrSubMask, [CurrElt](int MaskElt) {
+          return MaskElt == UndefMaskElem || MaskElt == CurrElt;
+        }))
+      return false;
+  }
+  assert(Mask.empty() && "Did not consume the whole mask?");
+
+  return true;
+}
+
+bool ShuffleVectorInst::isReplicationMask(ArrayRef<int> Mask,
+                                          int &ReplicationFactor, int &VF) {
+  // undef-less case is trivial.
+  if (none_of(Mask, [](int MaskElt) { return MaskElt == UndefMaskElem; })) {
+    ReplicationFactor =
+        Mask.take_while([](int MaskElt) { return MaskElt == 0; }).size();
+    if (ReplicationFactor == 0 || Mask.size() % ReplicationFactor != 0)
+      return false;
+    VF = Mask.size() / ReplicationFactor;
+    return isReplicationMaskWithParams(Mask, ReplicationFactor, VF);
+  }
+
+  // However, if the mask contains undef's, we have to enumerate possible tuples
+  // and pick one. There are bounds on replication factor: [1, mask size]
+  // (where RF=1 is an identity shuffle, RF=mask size is a broadcast shuffle)
+  // Additionally, mask size is a replication factor multiplied by vector size,
+  // which further significantly reduces the search space.
+
+  // Before doing that, let's perform basic sanity check first.
+  int Largest = -1;
+  for (int MaskElt : Mask) {
+    if (MaskElt == UndefMaskElem)
+      continue;
+    // Elements must be in non-decreasing order.
+    if (MaskElt < Largest)
+      return false;
+    Largest = std::max(Largest, MaskElt);
+  }
+
+  // Prefer larger replication factor if all else equal.
+  for (int PossibleReplicationFactor :
+       reverse(seq_inclusive<unsigned>(1, Mask.size()))) {
+    if (Mask.size() % PossibleReplicationFactor != 0)
+      continue;
+    int PossibleVF = Mask.size() / PossibleReplicationFactor;
+    if (!isReplicationMaskWithParams(Mask, PossibleReplicationFactor,
+                                     PossibleVF))
+      continue;
+    ReplicationFactor = PossibleReplicationFactor;
+    VF = PossibleVF;
+    return true;
+  }
+
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 //                           InsertValueInst Class
 //===----------------------------------------------------------------------===//
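The undef-aware search in this hunk can be tried outside of LLVM with a simplified standalone sketch. It makes several assumptions: the names `matchesWithParams` and `matchReplicationMask` are hypothetical, `std::vector<int>` stands in for `ArrayRef<int>`, and undef is encoded as `-1` (the value the unit tests in this patch write into masks). It also leaves out the undef-free fast path and the monotonicity sanity check, so it illustrates the idea rather than reproducing the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int UndefElt = -1; // Assumption: undef mask elements are encoded as -1.

// Check one candidate (RF, VF): every element of group i is either undef or i.
static bool matchesWithParams(const std::vector<int> &Mask, int RF, int VF) {
  assert(Mask.size() == static_cast<std::size_t>(RF) * VF &&
         "Unexpected mask size.");
  for (int Elt = 0; Elt < VF; ++Elt)
    for (int Rep = 0; Rep < RF; ++Rep) {
      int MaskElt = Mask[static_cast<std::size_t>(Elt) * RF + Rep];
      if (MaskElt != UndefElt && MaskElt != Elt)
        return false;
    }
  return true;
}

// Try every replication factor that divides the mask size, largest first,
// and report the first (RF, VF) pair that the mask is compatible with.
bool matchReplicationMask(const std::vector<int> &Mask, int &RF, int &VF) {
  const int Size = static_cast<int>(Mask.size());
  for (int PossibleRF = Size; PossibleRF >= 1; --PossibleRF) {
    if (Size % PossibleRF != 0)
      continue;
    const int PossibleVF = Size / PossibleRF;
    if (!matchesWithParams(Mask, PossibleRF, PossibleVF))
      continue;
    RF = PossibleRF;
    VF = PossibleVF;
    return true;
  }
  return false;
}
```

Because larger factors are tried first, an all-undef mask of four elements is reported as a broadcast (`RF=4`, `VF=1`) even if it was derived from an `RF=2`, `VF=2` mask, which is the "different tuple" effect mentioned in the commit message.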

llvm/unittests/IR/InstructionsTest.cpp

+89 -1
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/ADT/CombinationGenerator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -1115,6 +1117,92 @@ TEST(InstructionsTest, ShuffleMaskQueries) {
   delete Id15;
 }
 
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask) {
+  for (int ReplicationFactor : seq_inclusive(1, 8)) {
+    for (int VF : seq_inclusive(1, 8)) {
+      const auto ReplicatedMask = createReplicatedMask(ReplicationFactor, VF);
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+      EXPECT_TRUE(ShuffleVectorInst::isReplicationMask(
+          ReplicatedMask, GuessedReplicationFactor, GuessedVF));
+      EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor);
+      EXPECT_EQ(GuessedVF, VF);
+    }
+  }
+}
+
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask_undef) {
+  for (int ReplicationFactor : seq_inclusive(1, 6)) {
+    for (int VF : seq_inclusive(1, 4)) {
+      const auto ReplicatedMask = createReplicatedMask(ReplicationFactor, VF);
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+
+      // If we change some mask elements to undef, we should still match.
+
+      SmallVector<SmallVector<bool>> ElementChoices(ReplicatedMask.size(),
+                                                    {false, true});
+
+      CombinationGenerator<bool, decltype(ElementChoices)::value_type,
+                           /*variable_smallsize=*/4>
+          G(ElementChoices);
+
+      G.generate([&](ArrayRef<bool> UndefOverrides) -> bool {
+        SmallVector<int> AdjustedMask;
+        AdjustedMask.reserve(ReplicatedMask.size());
+        for (auto I : zip(ReplicatedMask, UndefOverrides))
+          AdjustedMask.emplace_back(std::get<1>(I) ? -1 : std::get<0>(I));
+        assert(AdjustedMask.size() == ReplicatedMask.size() &&
+               "Size misprediction");
+
+        EXPECT_TRUE(ShuffleVectorInst::isReplicationMask(
+            AdjustedMask, GuessedReplicationFactor, GuessedVF));
+        // Do not check GuessedReplicationFactor and GuessedVF,
+        // with enough undef's we may deduce a different tuple.
+
+        return /*Abort=*/false;
+      });
+    }
+  }
+}
+
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask_Exhaustive_Correctness) {
+  for (int ShufMaskNumElts : seq_inclusive(1, 8)) {
+    SmallVector<int> PossibleShufMaskElts;
+    PossibleShufMaskElts.reserve(ShufMaskNumElts + 2);
+    for (int PossibleShufMaskElt : seq_inclusive(-1, ShufMaskNumElts))
+      PossibleShufMaskElts.emplace_back(PossibleShufMaskElt);
+    assert(PossibleShufMaskElts.size() == ShufMaskNumElts + 2U &&
+           "Size misprediction");
+
+    SmallVector<SmallVector<int>> ElementChoices(ShufMaskNumElts,
+                                                 PossibleShufMaskElts);
+
+    CombinationGenerator<int, decltype(ElementChoices)::value_type,
+                         /*variable_smallsize=*/4>
+        G(ElementChoices);
+
+    G.generate([&](ArrayRef<int> Mask) -> bool {
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+      bool Match = ShuffleVectorInst::isReplicationMask(
+          Mask, GuessedReplicationFactor, GuessedVF);
+      if (!Match)
+        return /*Abort=*/false;
+
+      const auto ActualMask =
+          createReplicatedMask(GuessedReplicationFactor, GuessedVF);
+      EXPECT_EQ(Mask.size(), ActualMask.size());
+      for (auto I : zip(Mask, ActualMask)) {
+        int Elt = std::get<0>(I);
+        int ActualElt = std::get<1>(I);
+
+        if (Elt != -1)
+          EXPECT_EQ(Elt, ActualElt);
+      }
+
+      return /*Abort=*/false;
+    });
+  }
+}
+
 TEST(InstructionsTest, GetSplat) {
   // Create the elements for various constant vectors.
   LLVMContext Ctx;
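The property that the exhaustive correctness test relies on, that a matched mask agrees element-wise with the true replication mask once undef elements are skipped, can be stated as a small standalone predicate. This is only a sketch: `agreesIgnoringUndef` is a hypothetical name, `std::vector<int>` replaces the LLVM containers, and undef is assumed to be encoded as `-1` as in the tests above.

```cpp
#include <cstddef>
#include <vector>

constexpr int UndefMaskValue = -1; // Assumption: undef is encoded as -1.

// Return true if every defined element of Mask equals the element at the
// same position in Reference. Each mask element is compared against the
// reference element, never against itself.
bool agreesIgnoringUndef(const std::vector<int> &Mask,
                         const std::vector<int> &Reference) {
  if (Mask.size() != Reference.size())
    return false;
  for (std::size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I] != UndefMaskValue && Mask[I] != Reference[I])
      return false;
  return true;
}
```

For instance, `{0, -1, 1, 1}` agrees with the reference `{0, 0, 1, 1}`, while `{0, 1, 1, 1}` does not, because its second element is defined and differs.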
