Skip to content

Commit ad9b5a4

Browse files
nirvedhmeshramThomasRaoux
authored andcommitted
[mlir][vector] Add pattern to drop lead unit dim for Contraction Op
If the result operand has a unit leading dim it is removed from all operands. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D119206
1 parent 5565b38 commit ad9b5a4

File tree

5 files changed

+409
-108
lines changed

5 files changed

+409
-108
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,10 @@ def Vector_ContractionOp :
200200
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>,
201201
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
202202
"ArrayRef<ArrayRef<AffineExpr>>":$indexingExprs,
203-
"ArrayRef<StringRef>":$iteratorTypes)>
203+
"ArrayRef<StringRef>":$iteratorTypes)>,
204+
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
205+
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
206+
"CombiningKind":$kind)>
204207
];
205208
let extraClassDeclaration = [{
206209
VectorType getLhsType() {

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,13 +502,20 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
502502
Value lhs, Value rhs, Value acc,
503503
ArrayAttr indexingMaps,
504504
ArrayAttr iteratorTypes) {
505+
build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
506+
ContractionOp::getDefaultKind());
507+
}
508+
509+
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
510+
Value lhs, Value rhs, Value acc,
511+
ArrayAttr indexingMaps,
512+
ArrayAttr iteratorTypes, CombiningKind kind) {
505513
result.addOperands({lhs, rhs, acc});
506514
result.addTypes(acc.getType());
507515
result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
508516
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
509517
result.addAttribute(ContractionOp::getKindAttrName(),
510-
CombiningKindAttr::get(ContractionOp::getDefaultKind(),
511-
builder.getContext()));
518+
CombiningKindAttr::get(kind, builder.getContext()));
512519
}
513520

514521
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
910
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1011
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1112
#include "mlir/IR/Builders.h"
@@ -220,6 +221,128 @@ struct CastAwayTransferWriteLeadingOneDim
220221
}
221222
};
222223

224+
/// Turns vector.contract on vector with leading 1 dimensions into
225+
/// vector.extract followed by vector.contract on vector without leading
226+
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
227+
/// prior to extract.
228+
struct CastAwayContractionLeadingOneDim
229+
: public OpRewritePattern<vector::ContractionOp> {
230+
using OpRewritePattern::OpRewritePattern;
231+
232+
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
233+
PatternRewriter &rewriter) const override {
234+
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
235+
if (oldAccType == nullptr)
236+
return failure();
237+
if (oldAccType.getRank() < 2)
238+
return failure();
239+
// TODO: implement masks.
240+
if (llvm::size(contractOp.masks()) != 0)
241+
return failure();
242+
if (oldAccType.getShape()[0] != 1)
243+
return failure();
244+
// currently we support only dropping one dim but the pattern can be applied
245+
// greedily to drop more.
246+
int64_t dropDim = 1;
247+
248+
auto oldIndexingMaps = contractOp.getIndexingMaps();
249+
SmallVector<AffineMap> newIndexingMaps;
250+
251+
auto oldIteratorTypes = contractOp.iterator_types();
252+
SmallVector<Attribute> newIteratorTypes;
253+
254+
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
255+
256+
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
257+
// only parallel type iterators can be dropped.
258+
return failure();
259+
260+
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
261+
int64_t currDim = it.index();
262+
if (currDim == dimToDrop)
263+
continue;
264+
newIteratorTypes.push_back(it.value());
265+
}
266+
267+
SmallVector<Value> operands = {contractOp.lhs(), contractOp.rhs(),
268+
contractOp.acc()};
269+
SmallVector<Value> newOperands;
270+
271+
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
272+
// Check if the dim to be dropped exists as a leading dim in the operand
273+
// if it does then we use vector.extract to drop it.
274+
bool validExtract = false;
275+
SmallVector<AffineExpr> results;
276+
auto map = it.value();
277+
int64_t orginalZeroDim = it.value().getDimPosition(0);
278+
if (orginalZeroDim != dimToDrop) {
279+
// There are two reasons to be in this path, 1. We need to
280+
// tranpose the operand to make the dim to be dropped
281+
// leading. 2. The dim to be dropped does not exist and in
282+
// that case we dont want to add a unit tranpose but we must
283+
// check all the indices to make sure this is the case.
284+
bool tranposeNeeded = false;
285+
SmallVector<int64_t> perm;
286+
SmallVector<AffineExpr> transposeResults;
287+
288+
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
289+
int64_t currDim = map.getDimPosition(i);
290+
if (currDim == dimToDrop) {
291+
tranposeNeeded = true;
292+
perm.insert(perm.begin(), i);
293+
auto targetExpr = rewriter.getAffineDimExpr(currDim);
294+
transposeResults.insert(transposeResults.begin(), targetExpr);
295+
} else {
296+
perm.push_back(i);
297+
auto targetExpr = rewriter.getAffineDimExpr(currDim);
298+
transposeResults.push_back(targetExpr);
299+
}
300+
}
301+
// Do the tranpose now if needed so that we can drop the
302+
// correct dim using extract later.
303+
if (tranposeNeeded) {
304+
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
305+
contractOp.getContext());
306+
operands[it.index()] = rewriter.create<vector::TransposeOp>(
307+
contractOp.getLoc(), operands[it.index()], perm);
308+
}
309+
}
310+
// We have taken care to have the dim to be dropped be
311+
// the leading dim. If its still not leading that means it
312+
// does not exist in this operand and hence we do not need
313+
// an extract.
314+
if (map.getDimPosition(0) == dimToDrop)
315+
validExtract = true;
316+
317+
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
318+
int64_t currDim = map.getDimPosition(i);
319+
if (currDim == dimToDrop)
320+
// This is the dim we are dropping.
321+
continue;
322+
auto targetExpr = rewriter.getAffineDimExpr(
323+
currDim < dimToDrop ? currDim : currDim - 1);
324+
results.push_back(targetExpr);
325+
}
326+
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
327+
contractOp.getContext()));
328+
// Extract if its a valid extraction, otherwise use the operand
329+
// without extraction.
330+
newOperands.push_back(validExtract
331+
? rewriter.create<vector::ExtractOp>(
332+
contractOp.getLoc(), operands[it.index()],
333+
splatZero(dropDim))
334+
: operands[it.index()]);
335+
}
336+
auto newContractOp = rewriter.create<vector::ContractionOp>(
337+
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
338+
rewriter.getAffineMapArrayAttr(newIndexingMaps),
339+
rewriter.getArrayAttr(newIteratorTypes), contractOp.kind());
340+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
341+
contractOp, contractOp->getResultTypes()[0], newContractOp);
342+
return success();
343+
}
344+
};
345+
223346
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
224347
public:
225348
CastAwayElementwiseLeadingOneDim(MLIRContext *context)
@@ -260,10 +383,11 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
260383

261384
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
262385
RewritePatternSet &patterns) {
263-
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
264-
CastAwayInsertStridedSliceLeadingOneDim,
265-
CastAwayTransferReadLeadingOneDim,
266-
CastAwayTransferWriteLeadingOneDim,
267-
CastAwayElementwiseLeadingOneDim>(patterns.getContext());
386+
patterns
387+
.add<CastAwayExtractStridedSliceLeadingOneDim,
388+
CastAwayInsertStridedSliceLeadingOneDim,
389+
CastAwayTransferReadLeadingOneDim,
390+
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
391+
CastAwayContractionLeadingOneDim>(patterns.getContext());
268392
populateShapeCastFoldingPatterns(patterns);
269393
}

0 commit comments

Comments
 (0)