Skip to content

Commit 16b75cd

Browse files
[mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions
`DenseI64ArrayAttr` provides a better API than `I64ArrayAttr`. E.g., accessors returning `ArrayRef<int64_t>` (instead of `ArrayAttr`) are generated. Differential Revision: https://reviews.llvm.org/D156684
1 parent aba0ef7 commit 16b75cd

File tree

14 files changed

+100
-163
lines changed

14 files changed

+100
-163
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ def Vector_ExtractOp :
573573
PredOpTrait<"operand and result have same element type",
574574
TCresVTEtIsSameAsOpBase<0, 0>>,
575575
InferTypeOpAdaptorWithIsCompatible]>,
576-
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
576+
Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
577577
Results<(outs AnyType)> {
578578
let summary = "extract operation";
579579
let description = [{
@@ -589,7 +589,6 @@ def Vector_ExtractOp :
589589
```
590590
}];
591591
let builders = [
592-
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
593592
// Convenience builder which assumes the values in `position` are defined by
594593
// ConstantIndexOp.
595594
OpBuilder<(ins "Value":$source, "ValueRange":$position)>
@@ -689,7 +688,7 @@ def Vector_InsertOp :
689688
PredOpTrait<"source operand and result have same element type",
690689
TCresVTEtIsSameAsOpBase<0, 0>>,
691690
AllTypesMatch<["dest", "res"]>]>,
692-
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>,
691+
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
693692
Results<(outs AnyVectorOfAnyRank:$res)> {
694693
let summary = "insert operation";
695694
let description = [{
@@ -711,8 +710,6 @@ def Vector_InsertOp :
711710
}];
712711

713712
let builders = [
714-
OpBuilder<(ins "Value":$source, "Value":$dest,
715-
"ArrayRef<int64_t>":$position)>,
716713
// Convenience builder which assumes all values are constant indices.
717714
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
718715
];

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -807,8 +807,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
807807

808808
Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
809809
op.getSource(), newIndices);
810-
result = rewriter.create<vector::InsertOp>(loc, el, result,
811-
rewriter.getI64ArrayAttr(i));
810+
result = rewriter.create<vector::InsertOp>(loc, el, result, i);
812811
}
813812
} else {
814813
if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
@@ -832,7 +831,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
832831
Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
833832
op.getSource(), newIndices);
834833
result = rewriter.create<vector::InsertOp>(
835-
op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
834+
op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
836835
}
837836
}
838837
}

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,44 +1025,37 @@ class VectorExtractOpConversion
10251025
auto loc = extractOp->getLoc();
10261026
auto resultType = extractOp.getResult().getType();
10271027
auto llvmResultType = typeConverter->convertType(resultType);
1028-
auto positionArrayAttr = extractOp.getPosition();
1028+
ArrayRef<int64_t> positionArray = extractOp.getPosition();
10291029

10301030
// Bail if result type cannot be lowered.
10311031
if (!llvmResultType)
10321032
return failure();
10331033

10341034
// Extract entire vector. Should be handled by folder, but just to be safe.
1035-
if (positionArrayAttr.empty()) {
1035+
if (positionArray.empty()) {
10361036
rewriter.replaceOp(extractOp, adaptor.getVector());
10371037
return success();
10381038
}
10391039

10401040
// One-shot extraction of vector from array (only requires extractvalue).
10411041
if (isa<VectorType>(resultType)) {
1042-
SmallVector<int64_t> indices;
1043-
for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
1044-
indices.push_back(idx.getInt());
10451042
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1046-
loc, adaptor.getVector(), indices);
1043+
loc, adaptor.getVector(), positionArray);
10471044
rewriter.replaceOp(extractOp, extracted);
10481045
return success();
10491046
}
10501047

10511048
// Potential extraction of 1-D vector from array.
10521049
Value extracted = adaptor.getVector();
1053-
auto positionAttrs = positionArrayAttr.getValue();
1054-
if (positionAttrs.size() > 1) {
1055-
SmallVector<int64_t> nMinusOnePosition;
1056-
for (auto idx : positionAttrs.drop_back())
1057-
nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
1058-
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1059-
nMinusOnePosition);
1050+
if (positionArray.size() > 1) {
1051+
extracted = rewriter.create<LLVM::ExtractValueOp>(
1052+
loc, extracted, positionArray.drop_back());
10601053
}
10611054

10621055
// Remaining extraction of element from 1-D LLVM vector
1063-
auto position = cast<IntegerAttr>(positionAttrs.back());
10641056
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
1065-
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
1057+
auto constant =
1058+
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
10661059
extracted =
10671060
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
10681061
rewriter.replaceOp(extractOp, extracted);
@@ -1147,52 +1140,48 @@ class VectorInsertOpConversion
11471140
auto sourceType = insertOp.getSourceType();
11481141
auto destVectorType = insertOp.getDestVectorType();
11491142
auto llvmResultType = typeConverter->convertType(destVectorType);
1150-
auto positionArrayAttr = insertOp.getPosition();
1143+
ArrayRef<int64_t> positionArray = insertOp.getPosition();
11511144

11521145
// Bail if result type cannot be lowered.
11531146
if (!llvmResultType)
11541147
return failure();
11551148

11561149
// Overwrite entire vector with value. Should be handled by folder, but
11571150
// just to be safe.
1158-
if (positionArrayAttr.empty()) {
1151+
if (positionArray.empty()) {
11591152
rewriter.replaceOp(insertOp, adaptor.getSource());
11601153
return success();
11611154
}
11621155

11631156
// One-shot insertion of a vector into an array (only requires insertvalue).
11641157
if (isa<VectorType>(sourceType)) {
11651158
Value inserted = rewriter.create<LLVM::InsertValueOp>(
1166-
loc, adaptor.getDest(), adaptor.getSource(),
1167-
LLVM::convertArrayToIndices(positionArrayAttr));
1159+
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
11681160
rewriter.replaceOp(insertOp, inserted);
11691161
return success();
11701162
}
11711163

11721164
// Potential extraction of 1-D vector from array.
11731165
Value extracted = adaptor.getDest();
1174-
auto positionAttrs = positionArrayAttr.getValue();
1175-
auto position = cast<IntegerAttr>(positionAttrs.back());
11761166
auto oneDVectorType = destVectorType;
1177-
if (positionAttrs.size() > 1) {
1167+
if (positionArray.size() > 1) {
11781168
oneDVectorType = reducedVectorTypeBack(destVectorType);
11791169
extracted = rewriter.create<LLVM::ExtractValueOp>(
1180-
loc, extracted,
1181-
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
1170+
loc, extracted, positionArray.drop_back());
11821171
}
11831172

11841173
// Insertion of an element into a 1-D LLVM vector.
11851174
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
1186-
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
1175+
auto constant =
1176+
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
11871177
Value inserted = rewriter.create<LLVM::InsertElementOp>(
11881178
loc, typeConverter->convertType(oneDVectorType), extracted,
11891179
adaptor.getSource(), constant);
11901180

11911181
// Potential insertion of resulting 1-D vector into array.
1192-
if (positionAttrs.size() > 1) {
1182+
if (positionArray.size() > 1) {
11931183
inserted = rewriter.create<LLVM::InsertValueOp>(
1194-
loc, adaptor.getDest(), inserted,
1195-
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
1184+
loc, adaptor.getDest(), inserted, positionArray.drop_back());
11961185
}
11971186

11981187
rewriter.replaceOp(insertOp, inserted);

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -886,10 +886,9 @@ struct UnrollTransferReadConversion
886886
/// vector::InsertOp, return that operation's indices.
887887
void getInsertionIndices(TransferReadOp xferOp,
888888
SmallVector<int64_t, 8> &indices) const {
889-
if (auto insertOp = getInsertOp(xferOp)) {
890-
for (Attribute attr : insertOp.getPosition())
891-
indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
892-
}
889+
if (auto insertOp = getInsertOp(xferOp))
890+
indices.assign(insertOp.getPosition().begin(),
891+
insertOp.getPosition().end());
893892
}
894893

895894
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1013,10 +1012,9 @@ struct UnrollTransferWriteConversion
10131012
/// indices.
10141013
void getExtractionIndices(TransferWriteOp xferOp,
10151014
SmallVector<int64_t, 8> &indices) const {
1016-
if (auto extractOp = getExtractOp(xferOp)) {
1017-
for (Attribute attr : extractOp.getPosition())
1018-
indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
1019-
}
1015+
if (auto extractOp = getExtractOp(xferOp))
1016+
indices.assign(extractOp.getPosition().begin(),
1017+
extractOp.getPosition().end());
10201018
}
10211019

10221020
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ struct VectorExtractOpConvert final
152152
return success();
153153
}
154154

155-
int32_t id = getFirstIntValue(extractOp.getPosition());
155+
int32_t id = extractOp.getPosition()[0];
156156
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
157157
extractOp, adaptor.getVector(), id);
158158
return success();
@@ -232,7 +232,7 @@ struct VectorInsertOpConvert final
232232
return success();
233233
}
234234

235-
int32_t id = getFirstIntValue(insertOp.getPosition());
235+
int32_t id = insertOp.getPosition()[0];
236236
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
237237
insertOp, adaptor.getSource(), adaptor.getDest(), id);
238238
return success();

0 commit comments

Comments
 (0)