Skip to content

Commit bd7f7e2

Browse files
committedJun 22, 2021
[GlobalISel] Add scalable property to LLT types.
This patch aims to add the scalable property to LLT. The rest of the patch-series changes the interfaces to take/return ElementCount and TypeSize, which both have the ability to represent the scalable property. The changes are mostly mechanical and aim to be non-functional changes for fixed-width vectors. For scalable vectors some unit tests have been added, but no effort has been put into making any of the GlobalISel algorithms work with scalable vectors yet. That will be left as future work. The work is split into a series of 5 patches to make reviews easier. Reviewed By: arsenm Differential Revision: https://reviews.llvm.org/D104450
1 parent d919b73 commit bd7f7e2

File tree

5 files changed

+143
-56
lines changed

5 files changed

+143
-56
lines changed
 

‎llvm/include/llvm/Support/LowLevelTypeImpl.h

+77-27
Original file line numberDiff line numberDiff line change
@@ -42,31 +42,37 @@ class LLT {
4242
/// Get a low-level scalar or aggregate "bag of bits".
4343
static LLT scalar(unsigned SizeInBits) {
4444
assert(SizeInBits > 0 && "invalid scalar size");
45-
return LLT{/*isPointer=*/false, /*isVector=*/false, /*NumElements=*/0,
46-
SizeInBits, /*AddressSpace=*/0};
45+
return LLT{/*isPointer=*/false, /*isVector=*/false,
46+
ElementCount::getFixed(0), SizeInBits,
47+
/*AddressSpace=*/0};
4748
}
4849

4950
/// Get a low-level pointer in the given address space.
5051
static LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
5152
assert(SizeInBits > 0 && "invalid pointer size");
52-
return LLT{/*isPointer=*/true, /*isVector=*/false, /*NumElements=*/0,
53-
SizeInBits, AddressSpace};
53+
return LLT{/*isPointer=*/true, /*isVector=*/false,
54+
ElementCount::getFixed(0), SizeInBits, AddressSpace};
5455
}
5556

5657
/// Get a low-level vector of some number of elements and element width.
5758
/// \p NumElements must be at least 2.
58-
static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits) {
59-
assert(NumElements > 1 && "invalid number of vector elements");
59+
static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits,
60+
bool Scalable = false) {
61+
assert(((!Scalable && NumElements > 1) || NumElements > 0) &&
62+
"invalid number of vector elements");
6063
assert(ScalarSizeInBits > 0 && "invalid vector element size");
61-
return LLT{/*isPointer=*/false, /*isVector=*/true, NumElements,
62-
ScalarSizeInBits, /*AddressSpace=*/0};
64+
return LLT{/*isPointer=*/false, /*isVector=*/true,
65+
ElementCount::get(NumElements, Scalable), ScalarSizeInBits,
66+
/*AddressSpace=*/0};
6367
}
6468

6569
/// Get a low-level vector of some number of elements and element type.
66-
static LLT vector(uint16_t NumElements, LLT ScalarTy) {
67-
assert(NumElements > 1 && "invalid number of vector elements");
70+
static LLT vector(uint16_t NumElements, LLT ScalarTy, bool Scalable = false) {
71+
assert(((!Scalable && NumElements > 1) || NumElements > 0) &&
72+
"invalid number of vector elements");
6873
assert(!ScalarTy.isVector() && "invalid vector element type");
69-
return LLT{ScalarTy.isPointer(), /*isVector=*/true, NumElements,
74+
return LLT{ScalarTy.isPointer(), /*isVector=*/true,
75+
ElementCount::get(NumElements, Scalable),
7076
ScalarTy.getSizeInBits(),
7177
ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
7278
}
@@ -79,9 +85,9 @@ class LLT {
7985
return scalarOrVector(NumElements, LLT::scalar(ScalarSize));
8086
}
8187

82-
explicit LLT(bool isPointer, bool isVector, uint16_t NumElements,
88+
explicit LLT(bool isPointer, bool isVector, ElementCount EC,
8389
unsigned SizeInBits, unsigned AddressSpace) {
84-
init(isPointer, isVector, NumElements, SizeInBits, AddressSpace);
90+
init(isPointer, isVector, EC, SizeInBits, AddressSpace);
8591
}
8692
explicit LLT() : IsPointer(false), IsVector(false), RawData(0) {}
8793

@@ -98,18 +104,37 @@ class LLT {
98104
/// Returns the number of elements in a vector LLT. Must only be called on
99105
/// vector types.
100106
uint16_t getNumElements() const {
107+
if (isScalable())
108+
llvm::reportInvalidSizeRequest(
109+
"Possible incorrect use of LLT::getNumElements() for "
110+
"scalable vector. Scalable flag may be dropped, use "
111+
"LLT::getElementCount() instead");
112+
return getElementCount().getKnownMinValue();
113+
}
114+
115+
/// Returns true if the LLT is a scalable vector. Must only be called on
116+
/// vector types.
117+
bool isScalable() const {
118+
assert(isVector() && "Expected a vector type");
119+
return IsPointer ? getFieldValue(PointerVectorScalableFieldInfo)
120+
: getFieldValue(VectorScalableFieldInfo);
121+
}
122+
123+
ElementCount getElementCount() const {
101124
assert(IsVector && "cannot get number of elements on scalar/aggregate");
102-
if (!IsPointer)
103-
return getFieldValue(VectorElementsFieldInfo);
104-
else
105-
return getFieldValue(PointerVectorElementsFieldInfo);
125+
return ElementCount::get(IsPointer
126+
? getFieldValue(PointerVectorElementsFieldInfo)
127+
: getFieldValue(VectorElementsFieldInfo),
128+
isScalable());
106129
}
107130

108131
/// Returns the total size of the type. Must only be called on sized types.
109132
unsigned getSizeInBits() const {
110133
if (isPointer() || isScalar())
111134
return getScalarSizeInBits();
112-
return getScalarSizeInBits() * getNumElements();
135+
// FIXME: This should return a TypeSize in order to work for scalable
136+
// vectors.
137+
return getScalarSizeInBits() * getElementCount().getKnownMinValue();
113138
}
114139

115140
/// Returns the total size of the type in bytes, i.e. number of whole bytes
@@ -125,7 +150,9 @@ class LLT {
125150
/// If this type is a vector, return a vector with the same number of elements
126151
/// but the new element type. Otherwise, return the new element type.
127152
LLT changeElementType(LLT NewEltTy) const {
128-
return isVector() ? LLT::vector(getNumElements(), NewEltTy) : NewEltTy;
153+
return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
154+
NewEltTy, isScalable())
155+
: NewEltTy;
129156
}
130157

131158
/// If this type is a vector, return a vector with the same number of elements
@@ -134,13 +161,16 @@ class LLT {
134161
LLT changeElementSize(unsigned NewEltSize) const {
135162
assert(!getScalarType().isPointer() &&
136163
"invalid to directly change element size for pointers");
137-
return isVector() ? LLT::vector(getNumElements(), NewEltSize)
164+
return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
165+
NewEltSize, isScalable())
138166
: LLT::scalar(NewEltSize);
139167
}
140168

141169
/// Return a vector or scalar with the same element type and the new number of
142170
/// elements.
143171
LLT changeNumElements(unsigned NewNumElts) const {
172+
assert((!isVector() || !isScalable()) &&
173+
"Cannot use changeNumElements on a scalable vector");
144174
return LLT::scalarOrVector(NewNumElts, getScalarType());
145175
}
146176

@@ -237,22 +267,37 @@ class LLT {
237267
static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 0};
238268
static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{
239269
24, PointerSizeFieldInfo[0] + PointerSizeFieldInfo[1]};
270+
static_assert((PointerAddressSpaceFieldInfo[0] +
271+
PointerAddressSpaceFieldInfo[1]) <= 62,
272+
"Insufficient bits to encode all data");
240273
/// * Vector-of-non-pointer (isPointer == 0 && isVector == 1):
241274
/// NumElements: 16;
242275
/// SizeOfElement: 32;
276+
/// Scalable: 1;
243277
static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 0};
244278
static const constexpr BitFieldInfo VectorSizeFieldInfo{
245279
32, VectorElementsFieldInfo[0] + VectorElementsFieldInfo[1]};
280+
static const constexpr BitFieldInfo VectorScalableFieldInfo{
281+
1, VectorSizeFieldInfo[0] + VectorSizeFieldInfo[1]};
282+
static_assert((VectorSizeFieldInfo[0] + VectorSizeFieldInfo[1]) <= 62,
283+
"Insufficient bits to encode all data");
246284
/// * Vector-of-pointer (isPointer == 1 && isVector == 1):
247285
/// NumElements: 16;
248286
/// SizeOfElement: 16;
249287
/// AddressSpace: 24;
288+
/// Scalable: 1;
250289
static const constexpr BitFieldInfo PointerVectorElementsFieldInfo{16, 0};
251290
static const constexpr BitFieldInfo PointerVectorSizeFieldInfo{
252291
16,
253292
PointerVectorElementsFieldInfo[1] + PointerVectorElementsFieldInfo[0]};
254293
static const constexpr BitFieldInfo PointerVectorAddressSpaceFieldInfo{
255294
24, PointerVectorSizeFieldInfo[1] + PointerVectorSizeFieldInfo[0]};
295+
static const constexpr BitFieldInfo PointerVectorScalableFieldInfo{
296+
1, PointerVectorAddressSpaceFieldInfo[0] +
297+
PointerVectorAddressSpaceFieldInfo[1]};
298+
static_assert((PointerVectorAddressSpaceFieldInfo[0] +
299+
PointerVectorAddressSpaceFieldInfo[1]) <= 62,
300+
"Insufficient bits to encode all data");
256301

257302
uint64_t IsPointer : 1;
258303
uint64_t IsVector : 1;
@@ -273,8 +318,8 @@ class LLT {
273318
return getMask(FieldInfo) & (RawData >> FieldInfo[1]);
274319
}
275320

276-
void init(bool IsPointer, bool IsVector, uint16_t NumElements,
277-
unsigned SizeInBits, unsigned AddressSpace) {
321+
void init(bool IsPointer, bool IsVector, ElementCount EC, unsigned SizeInBits,
322+
unsigned AddressSpace) {
278323
this->IsPointer = IsPointer;
279324
this->IsVector = IsVector;
280325
if (!IsVector) {
@@ -284,15 +329,20 @@ class LLT {
284329
RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) |
285330
maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo);
286331
} else {
287-
assert(NumElements > 1 && "invalid number of vector elements");
332+
assert(EC.isVector() && "invalid number of vector elements");
288333
if (!IsPointer)
289-
RawData = maskAndShift(NumElements, VectorElementsFieldInfo) |
290-
maskAndShift(SizeInBits, VectorSizeFieldInfo);
334+
RawData =
335+
maskAndShift(EC.getKnownMinValue(), VectorElementsFieldInfo) |
336+
maskAndShift(SizeInBits, VectorSizeFieldInfo) |
337+
maskAndShift(EC.isScalable() ? 1 : 0, VectorScalableFieldInfo);
291338
else
292339
RawData =
293-
maskAndShift(NumElements, PointerVectorElementsFieldInfo) |
340+
maskAndShift(EC.getKnownMinValue(),
341+
PointerVectorElementsFieldInfo) |
294342
maskAndShift(SizeInBits, PointerVectorSizeFieldInfo) |
295-
maskAndShift(AddressSpace, PointerVectorAddressSpaceFieldInfo);
343+
maskAndShift(AddressSpace, PointerVectorAddressSpaceFieldInfo) |
344+
maskAndShift(EC.isScalable() ? 1 : 0,
345+
PointerVectorScalableFieldInfo);
296346
}
297347
}
298348

‎llvm/lib/CodeGen/LowLevelType.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ using namespace llvm;
2020

2121
LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
2222
if (auto VTy = dyn_cast<VectorType>(&Ty)) {
23-
auto NumElements = cast<FixedVectorType>(VTy)->getNumElements();
23+
auto EC = VTy->getElementCount();
2424
LLT ScalarTy = getLLTForType(*VTy->getElementType(), DL);
25-
if (NumElements == 1)
25+
if (EC.isScalar())
2626
return ScalarTy;
27-
return LLT::vector(NumElements, ScalarTy);
27+
return LLT::vector(EC.getKnownMinValue(), ScalarTy, EC.isScalable());
2828
}
2929

3030
if (auto PTy = dyn_cast<PointerType>(&Ty)) {

‎llvm/lib/Support/LowLevelType.cpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ using namespace llvm;
1818
LLT::LLT(MVT VT) {
1919
if (VT.isVector()) {
2020
init(/*IsPointer=*/false, VT.getVectorNumElements() > 1,
21-
VT.getVectorNumElements(), VT.getVectorElementType().getSizeInBits(),
21+
VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
2222
/*AddressSpace=*/0);
2323
} else if (VT.isValid()) {
2424
// Aggregates are no different from real scalars as far as GlobalISel is
2525
// concerned.
2626
assert(VT.getSizeInBits().isNonZero() && "invalid zero-sized type");
27-
init(/*IsPointer=*/false, /*IsVector=*/false, /*NumElements=*/0,
27+
init(/*IsPointer=*/false, /*IsVector=*/false, ElementCount::getFixed(0),
2828
VT.getSizeInBits(), /*AddressSpace=*/0);
2929
} else {
3030
IsPointer = false;
@@ -34,9 +34,10 @@ LLT::LLT(MVT VT) {
3434
}
3535

3636
void LLT::print(raw_ostream &OS) const {
37-
if (isVector())
38-
OS << "<" << getNumElements() << " x " << getElementType() << ">";
39-
else if (isPointer())
37+
if (isVector()) {
38+
OS << "<";
39+
OS << getElementCount() << " x " << getElementType() << ">";
40+
} else if (isPointer())
4041
OS << "p" << getAddressSpace();
4142
else if (isValid()) {
4243
assert(isScalar() && "unexpected type");
@@ -49,7 +50,9 @@ const constexpr LLT::BitFieldInfo LLT::ScalarSizeFieldInfo;
4950
const constexpr LLT::BitFieldInfo LLT::PointerSizeFieldInfo;
5051
const constexpr LLT::BitFieldInfo LLT::PointerAddressSpaceFieldInfo;
5152
const constexpr LLT::BitFieldInfo LLT::VectorElementsFieldInfo;
53+
const constexpr LLT::BitFieldInfo LLT::VectorScalableFieldInfo;
5254
const constexpr LLT::BitFieldInfo LLT::VectorSizeFieldInfo;
5355
const constexpr LLT::BitFieldInfo LLT::PointerVectorElementsFieldInfo;
56+
const constexpr LLT::BitFieldInfo LLT::PointerVectorScalableFieldInfo;
5457
const constexpr LLT::BitFieldInfo LLT::PointerVectorSizeFieldInfo;
5558
const constexpr LLT::BitFieldInfo LLT::PointerVectorAddressSpaceFieldInfo;

‎llvm/unittests/CodeGen/LowLevelTypeTest.cpp

+39-10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/IR/DerivedTypes.h"
1212
#include "llvm/IR/LLVMContext.h"
1313
#include "llvm/IR/Type.h"
14+
#include "llvm/Support/TypeSize.h"
1415
#include "gtest/gtest.h"
1516

1617
using namespace llvm;
@@ -50,13 +51,19 @@ TEST(LowLevelTypeTest, Vector) {
5051
DataLayout DL("");
5152

5253
for (unsigned S : {1U, 17U, 32U, 64U, 0xfffU}) {
53-
for (uint16_t Elts : {2U, 3U, 4U, 32U, 0xffU}) {
54+
for (auto EC :
55+
{ElementCount::getFixed(2), ElementCount::getFixed(3),
56+
ElementCount::getFixed(4), ElementCount::getFixed(32),
57+
ElementCount::getFixed(0xff), ElementCount::getScalable(2),
58+
ElementCount::getScalable(3), ElementCount::getScalable(4),
59+
ElementCount::getScalable(32), ElementCount::getScalable(0xff)}) {
5460
const LLT STy = LLT::scalar(S);
55-
const LLT VTy = LLT::vector(Elts, S);
61+
const LLT VTy = LLT::vector(EC.getKnownMinValue(), S, EC.isScalable());
5662

5763
// Test the alternative vector().
5864
{
59-
const LLT VSTy = LLT::vector(Elts, STy);
65+
const LLT VSTy =
66+
LLT::vector(EC.getKnownMinValue(), STy, EC.isScalable());
6067
EXPECT_EQ(VTy, VSTy);
6168
}
6269

@@ -71,9 +78,10 @@ TEST(LowLevelTypeTest, Vector) {
7178
ASSERT_FALSE(VTy.isPointer());
7279

7380
// Test sizes.
74-
EXPECT_EQ(S * Elts, VTy.getSizeInBits());
7581
EXPECT_EQ(S, VTy.getScalarSizeInBits());
76-
EXPECT_EQ(Elts, VTy.getNumElements());
82+
EXPECT_EQ(EC, VTy.getElementCount());
83+
if (!EC.isScalable())
84+
EXPECT_EQ(S * EC.getFixedValue(), VTy.getSizeInBits());
7785

7886
// Test equality operators.
7987
EXPECT_TRUE(VTy == VTy);
@@ -85,7 +93,7 @@ TEST(LowLevelTypeTest, Vector) {
8593

8694
// Test Type->LLT conversion.
8795
Type *IRSTy = IntegerType::get(C, S);
88-
Type *IRTy = FixedVectorType::get(IRSTy, Elts);
96+
Type *IRTy = VectorType::get(IRSTy, EC);
8997
EXPECT_EQ(VTy, getLLTForType(*IRTy, DL));
9098
}
9199
}
@@ -136,6 +144,22 @@ TEST(LowLevelTypeTest, ChangeElementType) {
136144

137145
EXPECT_EQ(V2P1, V2P0.changeElementType(P1));
138146
EXPECT_EQ(V2S32, V2P0.changeElementType(S32));
147+
148+
// Similar tests for for scalable vectors.
149+
const LLT NXV2S32 = LLT::vector(2, 32, true);
150+
const LLT NXV2S64 = LLT::vector(2, 64, true);
151+
152+
const LLT NXV2P0 = LLT::vector(2, P0, true);
153+
const LLT NXV2P1 = LLT::vector(2, P1, true);
154+
155+
EXPECT_EQ(NXV2S64, NXV2S32.changeElementType(S64));
156+
EXPECT_EQ(NXV2S32, NXV2S64.changeElementType(S32));
157+
158+
EXPECT_EQ(NXV2S64, NXV2S32.changeElementSize(64));
159+
EXPECT_EQ(NXV2S32, NXV2S64.changeElementSize(32));
160+
161+
EXPECT_EQ(NXV2P1, NXV2P0.changeElementType(P1));
162+
EXPECT_EQ(NXV2S32, NXV2P0.changeElementType(S32));
139163
}
140164

141165
TEST(LowLevelTypeTest, ChangeNumElements) {
@@ -191,9 +215,14 @@ TEST(LowLevelTypeTest, Pointer) {
191215
for (unsigned AS : {0U, 1U, 127U, 0xffffU,
192216
static_cast<unsigned>(maxUIntN(23)),
193217
static_cast<unsigned>(maxUIntN(24))}) {
194-
for (unsigned NumElts : {2, 3, 4, 256, 65535}) {
218+
for (ElementCount EC :
219+
{ElementCount::getFixed(2), ElementCount::getFixed(3),
220+
ElementCount::getFixed(4), ElementCount::getFixed(256),
221+
ElementCount::getFixed(65535), ElementCount::getScalable(2),
222+
ElementCount::getScalable(3), ElementCount::getScalable(4),
223+
ElementCount::getScalable(256), ElementCount::getScalable(65535)}) {
195224
const LLT Ty = LLT::pointer(AS, DL.getPointerSizeInBits(AS));
196-
const LLT VTy = LLT::vector(NumElts, Ty);
225+
const LLT VTy = LLT::vector(EC.getKnownMinValue(), Ty, EC.isScalable());
197226

198227
// Test kind.
199228
ASSERT_TRUE(Ty.isValid());
@@ -222,8 +251,8 @@ TEST(LowLevelTypeTest, Pointer) {
222251
// Test Type->LLT conversion.
223252
Type *IRTy = PointerType::get(IntegerType::get(C, 8), AS);
224253
EXPECT_EQ(Ty, getLLTForType(*IRTy, DL));
225-
Type *IRVTy = FixedVectorType::get(
226-
PointerType::get(IntegerType::get(C, 8), AS), NumElts);
254+
Type *IRVTy =
255+
VectorType::get(PointerType::get(IntegerType::get(C, 8), AS), EC);
227256
EXPECT_EQ(VTy, getLLTForType(*IRVTy, DL));
228257
}
229258
}

0 commit comments

Comments
 (0)