Skip to content

Commit c7fae59

Browse files
authored
[mlir][vector] Move extract_strided_slice canonicalization to folding (#135676)
Folders are preferred over rewrite patterns for this kind of canonicalization: https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations. Also included here: some missing `// -----` separators between test cases in a lit test file that is run through mlir-opt with the `-split-input-file` flag.
1 parent 7fd0c8a commit c7fae59

File tree

2 files changed

+69
-103
lines changed

2 files changed

+69
-103
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+62-103
Original file line numberDiff line numberDiff line change
@@ -3714,12 +3714,67 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
37143714
return failure();
37153715
}
37163716

3717+
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3718+
static OpFoldResult
3719+
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
3720+
Attribute foldInput) {
3721+
3722+
auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
3723+
if (!dense)
3724+
return {};
3725+
3726+
// TODO: Handle non-unit strides when they become available.
3727+
if (op.hasNonUnitStrides())
3728+
return {};
3729+
3730+
VectorType sourceVecTy = op.getSourceVectorType();
3731+
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3732+
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3733+
3734+
VectorType sliceVecTy = op.getType();
3735+
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3736+
int64_t rank = sliceVecTy.getRank();
3737+
3738+
// Expand offsets and sizes to match the vector rank.
3739+
SmallVector<int64_t, 4> offsets(rank, 0);
3740+
copy(getI64SubArray(op.getOffsets()), offsets.begin());
3741+
3742+
SmallVector<int64_t, 4> sizes(sourceShape);
3743+
copy(getI64SubArray(op.getSizes()), sizes.begin());
3744+
3745+
// Calculate the slice elements by enumerating all slice positions and
3746+
// linearizing them. The enumeration order is lexicographic which yields a
3747+
// sequence of monotonically increasing linearized position indices.
3748+
const auto denseValuesBegin = dense.value_begin<Attribute>();
3749+
SmallVector<Attribute> sliceValues;
3750+
sliceValues.reserve(sliceVecTy.getNumElements());
3751+
SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3752+
do {
3753+
int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3754+
assert(linearizedPosition < sourceVecTy.getNumElements() &&
3755+
"Invalid index");
3756+
sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3757+
} while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3758+
3759+
assert(static_cast<int64_t>(sliceValues.size()) ==
3760+
sliceVecTy.getNumElements() &&
3761+
"Invalid number of slice elements");
3762+
return DenseElementsAttr::get(sliceVecTy, sliceValues);
3763+
}
3764+
37173765
/// Folder for `vector.extract_strided_slice`. Tries, in order:
///   1. Source and result types are identical: return the source vector.
///   2. `foldExtractStridedOpFromInsertChain` succeeds: return this op's own
///      result (presumably the helper updated the op in place; see its
///      definition to confirm).
///   3. The source is a splat constant: return a splat constant of the result
///      type carrying the same splat value.
///   4. The source is any other dense constant: delegate to
///      `foldExtractStridedSliceNonSplatConstant`.
/// Returns a null `OpFoldResult` when none of the above applies.
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getVector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();

  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  // The result must be returned here: otherwise the freshly built attribute
  // is discarded and splat sources fall through to the element-wise path
  // below, which bails out on non-unit strides.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());

  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
}
37243779

37253780
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
@@ -3783,98 +3838,6 @@ class StridedSliceConstantMaskFolder final
37833838
}
37843839
};
37853840

3786-
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3787-
class StridedSliceSplatConstantFolder final
3788-
: public OpRewritePattern<ExtractStridedSliceOp> {
3789-
public:
3790-
using OpRewritePattern::OpRewritePattern;
3791-
3792-
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3793-
PatternRewriter &rewriter) const override {
3794-
// Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3795-
// ConstantOp.
3796-
Value sourceVector = extractStridedSliceOp.getVector();
3797-
Attribute vectorCst;
3798-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3799-
return failure();
3800-
3801-
auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3802-
if (!splat)
3803-
return failure();
3804-
3805-
auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3806-
splat.getSplatValue<Attribute>());
3807-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3808-
newAttr);
3809-
return success();
3810-
}
3811-
};
3812-
3813-
// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
3814-
// ConstantOp.
3815-
class StridedSliceNonSplatConstantFolder final
3816-
: public OpRewritePattern<ExtractStridedSliceOp> {
3817-
public:
3818-
using OpRewritePattern::OpRewritePattern;
3819-
3820-
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3821-
PatternRewriter &rewriter) const override {
3822-
// Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3823-
// ConstantOp.
3824-
Value sourceVector = extractStridedSliceOp.getVector();
3825-
Attribute vectorCst;
3826-
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3827-
return failure();
3828-
3829-
// The splat case is handled by `StridedSliceSplatConstantFolder`.
3830-
auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3831-
if (!dense || dense.isSplat())
3832-
return failure();
3833-
3834-
// TODO: Handle non-unit strides when they become available.
3835-
if (extractStridedSliceOp.hasNonUnitStrides())
3836-
return failure();
3837-
3838-
auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3839-
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3840-
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3841-
3842-
VectorType sliceVecTy = extractStridedSliceOp.getType();
3843-
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3844-
int64_t sliceRank = sliceVecTy.getRank();
3845-
3846-
// Expand offsets and sizes to match the vector rank.
3847-
SmallVector<int64_t, 4> offsets(sliceRank, 0);
3848-
copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3849-
3850-
SmallVector<int64_t, 4> sizes(sourceShape);
3851-
copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3852-
3853-
// Calculate the slice elements by enumerating all slice positions and
3854-
// linearizing them. The enumeration order is lexicographic which yields a
3855-
// sequence of monotonically increasing linearized position indices.
3856-
auto denseValuesBegin = dense.value_begin<Attribute>();
3857-
SmallVector<Attribute> sliceValues;
3858-
sliceValues.reserve(sliceVecTy.getNumElements());
3859-
SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3860-
do {
3861-
int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3862-
assert(linearizedPosition < sourceVecTy.getNumElements() &&
3863-
"Invalid index");
3864-
sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3865-
} while (
3866-
succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3867-
3868-
assert(static_cast<int64_t>(sliceValues.size()) ==
3869-
sliceVecTy.getNumElements() &&
3870-
"Invalid number of slice elements");
3871-
auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3872-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3873-
newAttr);
3874-
return success();
3875-
}
3876-
};
3877-
38783841
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
38793842
// BroadcastOp(ExtractStrideSliceOp).
38803843
class StridedSliceBroadcast final
@@ -4018,8 +3981,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
40183981
RewritePatternSet &results, MLIRContext *context) {
40193982
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
40203983
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4021-
results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4022-
StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3984+
results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
40233985
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
40243986
context);
40253987
}
@@ -5659,10 +5621,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56595621

56605622
// shape_cast(constant) -> constant
56615623
if (auto splatAttr =
5662-
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
5663-
return DenseElementsAttr::get(resultType,
5664-
splatAttr.getSplatValue<Attribute>());
5665-
}
5624+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
5625+
return splatAttr.reshape(getType());
56665626

56675627
// shape_cast(poison) -> poison
56685628
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6006,10 +5966,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
60065966

60075967
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
60085968
// Eliminate splat constant transpose ops.
6009-
if (auto attr =
6010-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
6011-
if (attr.isSplat())
6012-
return attr.reshape(getResultVectorType());
5969+
if (auto splat =
5970+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
5971+
return splat.reshape(getResultVectorType());
60135972

60145973
// Eliminate poison transpose ops.
60155974
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))

mlir/test/Dialect/Vector/canonicalize.mlir

+7
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
11211121
return %0, %2 : vector<4x8xf32>, vector<2xi32>
11221122
}
11231123

1124+
// -----
1125+
11241126
// CHECK-LABEL: func @bitcast_f16_to_f32
11251127
// bit pattern: 0x40004000
11261128
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
11351137
return %cast0, %cast1: vector<4xf32>, vector<4xf32>
11361138
}
11371139

1140+
// -----
1141+
11381142
// CHECK-LABEL: func @bitcast_i8_to_i32
11391143
// bit pattern: 0xA0A0A0A0
11401144
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1732,6 +1736,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
17321736
}
17331737

17341738
// -----
1739+
17351740
// CHECK-LABEL: func.func @vector_multi_reduction_scalable(
17361741
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
17371742
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2249,6 +2254,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
22492254
return %0 : vector<8x4xf32>
22502255
}
22512256

2257+
// -----
2258+
22522259
// CHECK-LABEL: func @transpose_splat2(
22532260
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
22542261
// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>

0 commit comments

Comments
 (0)