Commit 04570e9
[RISCV] Group the legal vector types into lists we can iterate over in the RISCVISelLowering constructor
Remove the RISCVVMVTs namespace because I don't think it provides a lot of value. If we change the mappings we'd likely have to add or remove things from the list anyway.

Add a wrapper around addRegisterClass that can determine the register class from the fixed size of the type.

Reviewed By: frasercrmck, rogfer01

Differential Revision: https://reviews.llvm.org/D95491
1 parent f30c523 commit 04570e9
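To make the change described above concrete, here is a minimal standalone sketch of the mapping the new addRegClassForRVV lambda performs: the known-minimum size of a scalable vector type decides which LMUL register class it is assigned to. The free function name regClassNameForRVV and its string return value are illustrative only, not part of the patch or of any LLVM API.

// Sketch: map a scalable RVV MVT to a register-class name by known-minimum size.
// Same size buckets as the patch; everything else here is for illustration.
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>

static const char *regClassNameForRVV(llvm::MVT VT) {
  unsigned Size = VT.getSizeInBits().getKnownMinValue();
  assert(Size <= 512 && llvm::isPowerOf2_32(Size) && "unexpected RVV type size");
  if (Size <= 64)   // LMUL <= 1, including fractional LMUL (mf2/mf4/mf8)
    return "VR";    // a single vector register
  if (Size == 128)  // LMUL = 2
    return "VRM2";
  if (Size == 256)  // LMUL = 4
    return "VRM4";
  return "VRM8";    // Size == 512, LMUL = 8
}

// For example: MVT::nxv8i8 has a known-minimum size of 64 bits  -> "VR",
//              MVT::nxv16i8 is 128 bits                         -> "VRM2",
//              MVT::nxv8f64 is 512 bits                         -> "VRM8".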

2 files changed: +78 −150 lines


llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h

+0 −57 lines
@@ -18,7 +18,6 @@
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/MC/MCInstrDesc.h"
 #include "llvm/MC/SubtargetFeature.h"
-#include "llvm/Support/MachineValueType.h"
 
 namespace llvm {
 
@@ -257,62 +256,6 @@ void validate(const Triple &TT, const FeatureBitset &FeatureBits);
 
 } // namespace RISCVFeatures
 
-namespace RISCVVMVTs {
-
-constexpr MVT vint8mf8_t = MVT::nxv1i8;
-constexpr MVT vint8mf4_t = MVT::nxv2i8;
-constexpr MVT vint8mf2_t = MVT::nxv4i8;
-constexpr MVT vint8m1_t = MVT::nxv8i8;
-constexpr MVT vint8m2_t = MVT::nxv16i8;
-constexpr MVT vint8m4_t = MVT::nxv32i8;
-constexpr MVT vint8m8_t = MVT::nxv64i8;
-
-constexpr MVT vint16mf4_t = MVT::nxv1i16;
-constexpr MVT vint16mf2_t = MVT::nxv2i16;
-constexpr MVT vint16m1_t = MVT::nxv4i16;
-constexpr MVT vint16m2_t = MVT::nxv8i16;
-constexpr MVT vint16m4_t = MVT::nxv16i16;
-constexpr MVT vint16m8_t = MVT::nxv32i16;
-
-constexpr MVT vint32mf2_t = MVT::nxv1i32;
-constexpr MVT vint32m1_t = MVT::nxv2i32;
-constexpr MVT vint32m2_t = MVT::nxv4i32;
-constexpr MVT vint32m4_t = MVT::nxv8i32;
-constexpr MVT vint32m8_t = MVT::nxv16i32;
-
-constexpr MVT vint64m1_t = MVT::nxv1i64;
-constexpr MVT vint64m2_t = MVT::nxv2i64;
-constexpr MVT vint64m4_t = MVT::nxv4i64;
-constexpr MVT vint64m8_t = MVT::nxv8i64;
-
-constexpr MVT vfloat16mf4_t = MVT::nxv1f16;
-constexpr MVT vfloat16mf2_t = MVT::nxv2f16;
-constexpr MVT vfloat16m1_t = MVT::nxv4f16;
-constexpr MVT vfloat16m2_t = MVT::nxv8f16;
-constexpr MVT vfloat16m4_t = MVT::nxv16f16;
-constexpr MVT vfloat16m8_t = MVT::nxv32f16;
-
-constexpr MVT vfloat32mf2_t = MVT::nxv1f32;
-constexpr MVT vfloat32m1_t = MVT::nxv2f32;
-constexpr MVT vfloat32m2_t = MVT::nxv4f32;
-constexpr MVT vfloat32m4_t = MVT::nxv8f32;
-constexpr MVT vfloat32m8_t = MVT::nxv16f32;
-
-constexpr MVT vfloat64m1_t = MVT::nxv1f64;
-constexpr MVT vfloat64m2_t = MVT::nxv2f64;
-constexpr MVT vfloat64m4_t = MVT::nxv4f64;
-constexpr MVT vfloat64m8_t = MVT::nxv8f64;
-
-constexpr MVT vbool1_t = MVT::nxv64i1;
-constexpr MVT vbool2_t = MVT::nxv32i1;
-constexpr MVT vbool4_t = MVT::nxv16i1;
-constexpr MVT vbool8_t = MVT::nxv8i1;
-constexpr MVT vbool16_t = MVT::nxv4i1;
-constexpr MVT vbool32_t = MVT::nxv2i1;
-constexpr MVT vbool64_t = MVT::nxv1i1;
-
-} // namespace RISCVVMVTs
-
 enum class RISCVVSEW {
   SEW_8 = 0,
   SEW_16,
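As background on the constants deleted above: the element counts in the corresponding nxv types follow mechanically from SEW and LMUL, because LLVM models an LMUL=1 RVV register as a scalable type with a 64-bit known-minimum size. A small worked sketch of that relationship (the helper below is purely illustrative, not part of the patch):

// Number of elements in the scalable MVT for a whole-register LMUL group,
// assuming an LMUL=1 register is modeled as 64 bits known-minimum.
// Fractional LMULs (mf2/mf4/mf8) divide rather than multiply.
unsigned rvvElementCount(unsigned SEW, unsigned LMUL) {
  return (64 * LMUL) / SEW;
}
// vint8m2_t    -> rvvElementCount(8, 2)  == 16 -> MVT::nxv16i8
// vint32m1_t   -> rvvElementCount(32, 1) == 2  -> MVT::nxv2i32
// vfloat64m8_t -> rvvElementCount(64, 8) == 8  -> MVT::nxv8f64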

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

+78 −93 lines
@@ -90,64 +90,56 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.hasStdExtD())
     addRegisterClass(MVT::f64, &RISCV::FPR64RegClass);
 
+  static const MVT::SimpleValueType BoolVecVTs[] = {
+      MVT::nxv1i1, MVT::nxv2i1, MVT::nxv4i1, MVT::nxv8i1,
+      MVT::nxv16i1, MVT::nxv32i1, MVT::nxv64i1};
+  static const MVT::SimpleValueType IntVecVTs[] = {
+      MVT::nxv1i8, MVT::nxv2i8, MVT::nxv4i8, MVT::nxv8i8, MVT::nxv16i8,
+      MVT::nxv32i8, MVT::nxv64i8, MVT::nxv1i16, MVT::nxv2i16, MVT::nxv4i16,
+      MVT::nxv8i16, MVT::nxv16i16, MVT::nxv32i16, MVT::nxv1i32, MVT::nxv2i32,
+      MVT::nxv4i32, MVT::nxv8i32, MVT::nxv16i32, MVT::nxv1i64, MVT::nxv2i64,
+      MVT::nxv4i64, MVT::nxv8i64};
+  static const MVT::SimpleValueType F16VecVTs[] = {
+      MVT::nxv1f16, MVT::nxv2f16, MVT::nxv4f16,
+      MVT::nxv8f16, MVT::nxv16f16, MVT::nxv32f16};
+  static const MVT::SimpleValueType F32VecVTs[] = {
+      MVT::nxv1f32, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv8f32, MVT::nxv16f32};
+  static const MVT::SimpleValueType F64VecVTs[] = {
+      MVT::nxv1f64, MVT::nxv2f64, MVT::nxv4f64, MVT::nxv8f64};
+
   if (Subtarget.hasStdExtV()) {
-    addRegisterClass(RISCVVMVTs::vbool64_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vbool32_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vbool16_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vbool8_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vbool4_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vbool2_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vbool1_t, &RISCV::VRRegClass);
-
-    addRegisterClass(RISCVVMVTs::vint8mf8_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint8mf4_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint8mf2_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint8m1_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint8m2_t, &RISCV::VRM2RegClass);
-    addRegisterClass(RISCVVMVTs::vint8m4_t, &RISCV::VRM4RegClass);
-    addRegisterClass(RISCVVMVTs::vint8m8_t, &RISCV::VRM8RegClass);
-
-    addRegisterClass(RISCVVMVTs::vint16mf4_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint16mf2_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint16m1_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint16m2_t, &RISCV::VRM2RegClass);
-    addRegisterClass(RISCVVMVTs::vint16m4_t, &RISCV::VRM4RegClass);
-    addRegisterClass(RISCVVMVTs::vint16m8_t, &RISCV::VRM8RegClass);
-
-    addRegisterClass(RISCVVMVTs::vint32mf2_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint32m1_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint32m2_t, &RISCV::VRM2RegClass);
-    addRegisterClass(RISCVVMVTs::vint32m4_t, &RISCV::VRM4RegClass);
-    addRegisterClass(RISCVVMVTs::vint32m8_t, &RISCV::VRM8RegClass);
-
-    addRegisterClass(RISCVVMVTs::vint64m1_t, &RISCV::VRRegClass);
-    addRegisterClass(RISCVVMVTs::vint64m2_t, &RISCV::VRM2RegClass);
-    addRegisterClass(RISCVVMVTs::vint64m4_t, &RISCV::VRM4RegClass);
-    addRegisterClass(RISCVVMVTs::vint64m8_t, &RISCV::VRM8RegClass);
-
-    if (Subtarget.hasStdExtZfh()) {
-      addRegisterClass(RISCVVMVTs::vfloat16mf4_t, &RISCV::VRRegClass);
-      addRegisterClass(RISCVVMVTs::vfloat16mf2_t, &RISCV::VRRegClass);
-      addRegisterClass(RISCVVMVTs::vfloat16m1_t, &RISCV::VRRegClass);
-      addRegisterClass(RISCVVMVTs::vfloat16m2_t, &RISCV::VRM2RegClass);
-      addRegisterClass(RISCVVMVTs::vfloat16m4_t, &RISCV::VRM4RegClass);
-      addRegisterClass(RISCVVMVTs::vfloat16m8_t, &RISCV::VRM8RegClass);
-    }
+    auto addRegClassForRVV = [this](MVT VT) {
+      unsigned Size = VT.getSizeInBits().getKnownMinValue();
+      assert(Size <= 512 && isPowerOf2_32(Size));
+      const TargetRegisterClass *RC;
+      if (Size <= 64)
+        RC = &RISCV::VRRegClass;
+      else if (Size == 128)
+        RC = &RISCV::VRM2RegClass;
+      else if (Size == 256)
+        RC = &RISCV::VRM4RegClass;
+      else
+        RC = &RISCV::VRM8RegClass;
 
-    if (Subtarget.hasStdExtF()) {
-      addRegisterClass(RISCVVMVTs::vfloat32mf2_t, &RISCV::VRRegClass);
-      addRegisterClass(RISCVVMVTs::vfloat32m1_t, &RISCV::VRRegClass);
-      addRegisterClass(RISCVVMVTs::vfloat32m2_t, &RISCV::VRM2RegClass);
-      addRegisterClass(RISCVVMVTs::vfloat32m4_t, &RISCV::VRM4RegClass);
-      addRegisterClass(RISCVVMVTs::vfloat32m8_t, &RISCV::VRM8RegClass);
-    }
+      addRegisterClass(VT, RC);
+    };
 
-    if (Subtarget.hasStdExtD()) {
-      addRegisterClass(RISCVVMVTs::vfloat64m1_t, &RISCV::VRRegClass);
-      addRegisterClass(RISCVVMVTs::vfloat64m2_t, &RISCV::VRM2RegClass);
-      addRegisterClass(RISCVVMVTs::vfloat64m4_t, &RISCV::VRM4RegClass);
-      addRegisterClass(RISCVVMVTs::vfloat64m8_t, &RISCV::VRM8RegClass);
-    }
+    for (MVT VT : BoolVecVTs)
+      addRegClassForRVV(VT);
+    for (MVT VT : IntVecVTs)
+      addRegClassForRVV(VT);
+
+    if (Subtarget.hasStdExtZfh())
+      for (MVT VT : F16VecVTs)
+        addRegClassForRVV(VT);
+
+    if (Subtarget.hasStdExtF())
+      for (MVT VT : F32VecVTs)
+        addRegClassForRVV(VT);
+
+    if (Subtarget.hasStdExtD())
+      for (MVT VT : F64VecVTs)
+        addRegClassForRVV(VT);
   }
 
   // Compute derived properties from the register classes.
@@ -379,9 +371,22 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.is64Bit()) {
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i64, Custom);
     setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i64, Custom);
+  } else {
+    // We must custom-lower certain vXi64 operations on RV32 due to the vector
+    // element type being illegal.
+    setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
+    setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
   }
 
-  for (auto VT : MVT::integer_scalable_vector_valuetypes()) {
+  for (MVT VT : BoolVecVTs) {
+    setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
+
+    // Mask VTs are custom-expanded into a series of standard nodes
+    setOperationAction(ISD::TRUNCATE, VT, Custom);
+  }
+
+  for (MVT VT : IntVecVTs) {
     setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
 
     setOperationAction(ISD::SMIN, VT, Legal);
@@ -392,30 +397,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::ROTL, VT, Expand);
     setOperationAction(ISD::ROTR, VT, Expand);
 
-    if (isTypeLegal(VT)) {
-      // Custom-lower extensions and truncations from/to mask types.
-      setOperationAction(ISD::ANY_EXTEND, VT, Custom);
-      setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
-      setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
-
-      // We custom-lower all legally-typed vector truncates:
-      // 1. Mask VTs are custom-expanded into a series of standard nodes
-      // 2. Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
-      // nodes which truncate by one power of two at a time.
-      setOperationAction(ISD::TRUNCATE, VT, Custom);
-
-      // Custom-lower insert/extract operations to simplify patterns.
-      setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
-      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
-    }
-  }
+    // Custom-lower extensions and truncations from/to mask types.
+    setOperationAction(ISD::ANY_EXTEND, VT, Custom);
+    setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
+    setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
 
-  // We must custom-lower certain vXi64 operations on RV32 due to the vector
-  // element type being illegal.
-  if (!Subtarget.is64Bit()) {
-    setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
-    setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
-    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
+    // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
+    // nodes which truncate by one power of two at a time.
+    setOperationAction(ISD::TRUNCATE, VT, Custom);
+
+    // Custom-lower insert/extract operations to simplify patterns.
+    setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
   }
 
   // Expand various CCs to best match the RVV ISA, which natively supports UNE
@@ -441,25 +434,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setCondCodeAction(CC, VT, Expand);
     };
 
-    if (Subtarget.hasStdExtZfh()) {
-      for (auto VT : {RISCVVMVTs::vfloat16mf4_t, RISCVVMVTs::vfloat16mf2_t,
-                      RISCVVMVTs::vfloat16m1_t, RISCVVMVTs::vfloat16m2_t,
-                      RISCVVMVTs::vfloat16m4_t, RISCVVMVTs::vfloat16m8_t})
+    if (Subtarget.hasStdExtZfh())
+      for (MVT VT : F16VecVTs)
         SetCommonVFPActions(VT);
-    }
 
-    if (Subtarget.hasStdExtF()) {
-      for (auto VT : {RISCVVMVTs::vfloat32mf2_t, RISCVVMVTs::vfloat32m1_t,
-                      RISCVVMVTs::vfloat32m2_t, RISCVVMVTs::vfloat32m4_t,
-                      RISCVVMVTs::vfloat32m8_t})
+    if (Subtarget.hasStdExtF())
+      for (MVT VT : F32VecVTs)
         SetCommonVFPActions(VT);
-    }
 
-    if (Subtarget.hasStdExtD()) {
-      for (auto VT : {RISCVVMVTs::vfloat64m1_t, RISCVVMVTs::vfloat64m2_t,
-                      RISCVVMVTs::vfloat64m4_t, RISCVVMVTs::vfloat64m8_t})
+    if (Subtarget.hasStdExtD())
+      for (MVT VT : F64VecVTs)
         SetCommonVFPActions(VT);
-    }
   }
 
   // Function alignments.
