@@ -3714,12 +3714,67 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3714
3714
return failure ();
3715
3715
}
3716
3716
3717
+ // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3718
+ static OpFoldResult
3719
+ foldExtractStridedSliceNonSplatConstant (ExtractStridedSliceOp op,
3720
+ Attribute foldInput) {
3721
+
3722
+ auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
3723
+ if (!dense)
3724
+ return {};
3725
+
3726
+ // TODO: Handle non-unit strides when they become available.
3727
+ if (op.hasNonUnitStrides ())
3728
+ return {};
3729
+
3730
+ VectorType sourceVecTy = op.getSourceVectorType ();
3731
+ ArrayRef<int64_t > sourceShape = sourceVecTy.getShape ();
3732
+ SmallVector<int64_t , 4 > sourceStrides = computeStrides (sourceShape);
3733
+
3734
+ VectorType sliceVecTy = op.getType ();
3735
+ ArrayRef<int64_t > sliceShape = sliceVecTy.getShape ();
3736
+ int64_t rank = sliceVecTy.getRank ();
3737
+
3738
+ // Expand offsets and sizes to match the vector rank.
3739
+ SmallVector<int64_t , 4 > offsets (rank, 0 );
3740
+ copy (getI64SubArray (op.getOffsets ()), offsets.begin ());
3741
+
3742
+ SmallVector<int64_t , 4 > sizes (sourceShape);
3743
+ copy (getI64SubArray (op.getSizes ()), sizes.begin ());
3744
+
3745
+ // Calculate the slice elements by enumerating all slice positions and
3746
+ // linearizing them. The enumeration order is lexicographic which yields a
3747
+ // sequence of monotonically increasing linearized position indices.
3748
+ const auto denseValuesBegin = dense.value_begin <Attribute>();
3749
+ SmallVector<Attribute> sliceValues;
3750
+ sliceValues.reserve (sliceVecTy.getNumElements ());
3751
+ SmallVector<int64_t > currSlicePosition (offsets.begin (), offsets.end ());
3752
+ do {
3753
+ int64_t linearizedPosition = linearize (currSlicePosition, sourceStrides);
3754
+ assert (linearizedPosition < sourceVecTy.getNumElements () &&
3755
+ " Invalid index" );
3756
+ sliceValues.push_back (*(denseValuesBegin + linearizedPosition));
3757
+ } while (succeeded (incSlicePosition (currSlicePosition, sliceShape, offsets)));
3758
+
3759
+ assert (static_cast <int64_t >(sliceValues.size ()) ==
3760
+ sliceVecTy.getNumElements () &&
3761
+ " Invalid number of slice elements" );
3762
+ return DenseElementsAttr::get (sliceVecTy, sliceValues);
3763
+ }
3764
+
3717
3765
OpFoldResult ExtractStridedSliceOp::fold (FoldAdaptor adaptor) {
3718
3766
if (getSourceVectorType () == getResult ().getType ())
3719
3767
return getVector ();
3720
3768
if (succeeded (foldExtractStridedOpFromInsertChain (*this )))
3721
3769
return getResult ();
3722
- return {};
3770
+
3771
+ // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3772
+ if (auto splat =
3773
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector ()))
3774
+ DenseElementsAttr::get (getType (), splat.getSplatValue <Attribute>());
3775
+
3776
+ // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
3777
+ return foldExtractStridedSliceNonSplatConstant (*this , adaptor.getVector ());
3723
3778
}
3724
3779
3725
3780
void ExtractStridedSliceOp::getOffsets (SmallVectorImpl<int64_t > &results) {
@@ -3783,98 +3838,6 @@ class StridedSliceConstantMaskFolder final
3783
3838
}
3784
3839
};
3785
3840
3786
- // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3787
- class StridedSliceSplatConstantFolder final
3788
- : public OpRewritePattern<ExtractStridedSliceOp> {
3789
- public:
3790
- using OpRewritePattern::OpRewritePattern;
3791
-
3792
- LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
3793
- PatternRewriter &rewriter) const override {
3794
- // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3795
- // ConstantOp.
3796
- Value sourceVector = extractStridedSliceOp.getVector ();
3797
- Attribute vectorCst;
3798
- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
3799
- return failure ();
3800
-
3801
- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3802
- if (!splat)
3803
- return failure ();
3804
-
3805
- auto newAttr = SplatElementsAttr::get (extractStridedSliceOp.getType (),
3806
- splat.getSplatValue <Attribute>());
3807
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractStridedSliceOp,
3808
- newAttr);
3809
- return success ();
3810
- }
3811
- };
3812
-
3813
- // Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
3814
- // ConstantOp.
3815
- class StridedSliceNonSplatConstantFolder final
3816
- : public OpRewritePattern<ExtractStridedSliceOp> {
3817
- public:
3818
- using OpRewritePattern::OpRewritePattern;
3819
-
3820
- LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
3821
- PatternRewriter &rewriter) const override {
3822
- // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3823
- // ConstantOp.
3824
- Value sourceVector = extractStridedSliceOp.getVector ();
3825
- Attribute vectorCst;
3826
- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
3827
- return failure ();
3828
-
3829
- // The splat case is handled by `StridedSliceSplatConstantFolder`.
3830
- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3831
- if (!dense || dense.isSplat ())
3832
- return failure ();
3833
-
3834
- // TODO: Handle non-unit strides when they become available.
3835
- if (extractStridedSliceOp.hasNonUnitStrides ())
3836
- return failure ();
3837
-
3838
- auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType ());
3839
- ArrayRef<int64_t > sourceShape = sourceVecTy.getShape ();
3840
- SmallVector<int64_t , 4 > sourceStrides = computeStrides (sourceShape);
3841
-
3842
- VectorType sliceVecTy = extractStridedSliceOp.getType ();
3843
- ArrayRef<int64_t > sliceShape = sliceVecTy.getShape ();
3844
- int64_t sliceRank = sliceVecTy.getRank ();
3845
-
3846
- // Expand offsets and sizes to match the vector rank.
3847
- SmallVector<int64_t , 4 > offsets (sliceRank, 0 );
3848
- copy (getI64SubArray (extractStridedSliceOp.getOffsets ()), offsets.begin ());
3849
-
3850
- SmallVector<int64_t , 4 > sizes (sourceShape);
3851
- copy (getI64SubArray (extractStridedSliceOp.getSizes ()), sizes.begin ());
3852
-
3853
- // Calculate the slice elements by enumerating all slice positions and
3854
- // linearizing them. The enumeration order is lexicographic which yields a
3855
- // sequence of monotonically increasing linearized position indices.
3856
- auto denseValuesBegin = dense.value_begin <Attribute>();
3857
- SmallVector<Attribute> sliceValues;
3858
- sliceValues.reserve (sliceVecTy.getNumElements ());
3859
- SmallVector<int64_t > currSlicePosition (offsets.begin (), offsets.end ());
3860
- do {
3861
- int64_t linearizedPosition = linearize (currSlicePosition, sourceStrides);
3862
- assert (linearizedPosition < sourceVecTy.getNumElements () &&
3863
- " Invalid index" );
3864
- sliceValues.push_back (*(denseValuesBegin + linearizedPosition));
3865
- } while (
3866
- succeeded (incSlicePosition (currSlicePosition, sliceShape, offsets)));
3867
-
3868
- assert (static_cast <int64_t >(sliceValues.size ()) ==
3869
- sliceVecTy.getNumElements () &&
3870
- " Invalid number of slice elements" );
3871
- auto newAttr = DenseElementsAttr::get (sliceVecTy, sliceValues);
3872
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractStridedSliceOp,
3873
- newAttr);
3874
- return success ();
3875
- }
3876
- };
3877
-
3878
3841
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3879
3842
// BroadcastOp(ExtractStrideSliceOp).
3880
3843
class StridedSliceBroadcast final
@@ -4018,8 +3981,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
4018
3981
RewritePatternSet &results, MLIRContext *context) {
4019
3982
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
4020
3983
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4021
- results.add <StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4022
- StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3984
+ results.add <StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4023
3985
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4024
3986
context);
4025
3987
}
@@ -5659,10 +5621,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5659
5621
5660
5622
// shape_cast(constant) -> constant
5661
5623
if (auto splatAttr =
5662
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ())) {
5663
- return DenseElementsAttr::get (resultType,
5664
- splatAttr.getSplatValue <Attribute>());
5665
- }
5624
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ()))
5625
+ return splatAttr.reshape (getType ());
5666
5626
5667
5627
// shape_cast(poison) -> poison
5668
5628
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource ())) {
@@ -6006,10 +5966,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6006
5966
6007
5967
OpFoldResult vector::TransposeOp::fold (FoldAdaptor adaptor) {
6008
5968
// Eliminate splat constant transpose ops.
6009
- if (auto attr =
6010
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector ()))
6011
- if (attr.isSplat ())
6012
- return attr.reshape (getResultVectorType ());
5969
+ if (auto splat =
5970
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector ()))
5971
+ return splat.reshape (getResultVectorType ());
6013
5972
6014
5973
// Eliminate poison transpose ops.
6015
5974
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector ()))
0 commit comments