|
6 | 6 | //
|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
| 9 | +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
9 | 10 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
10 | 11 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
11 | 12 | #include "mlir/IR/Builders.h"
|
@@ -220,6 +221,128 @@ struct CastAwayTransferWriteLeadingOneDim
|
220 | 221 | }
|
221 | 222 | };
|
222 | 223 |
|
| 224 | +/// Turns vector.contract on vector with leading 1 dimensions into |
| 225 | +/// vector.extract followed by vector.contract on vector without leading |
| 226 | +/// 1 dimensions. Also performs transpose of lhs and rhs operands if required |
| 227 | +/// prior to extract. The new result is broadcast back to the original result type. |
| 228 | +struct CastAwayContractionLeadingOneDim |
| 229 | + : public OpRewritePattern<vector::ContractionOp> { |
| 230 | + using OpRewritePattern::OpRewritePattern; |
| 231 | + |
| 232 | + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, |
| 233 | + PatternRewriter &rewriter) const override { |
| 234 | + VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>(); |
| 235 | + if (oldAccType == nullptr) |
| 236 | + return failure(); |
| 237 | + if (oldAccType.getRank() < 2) |
| 238 | + return failure(); |
| 239 | + // TODO: implement masks; contractions that carry masks are rejected for now. |
| 240 | + if (llvm::size(contractOp.masks()) != 0) |
| 241 | + return failure(); |
| 242 | + if (oldAccType.getShape()[0] != 1) |
| 243 | + return failure(); |
| 244 | + // Currently we support dropping only one dim, but the pattern can be applied |
| 245 | + // greedily to drop more. |
| 246 | + int64_t dropDim = 1; |
| 247 | + |
| 248 | + auto oldIndexingMaps = contractOp.getIndexingMaps(); |
| 249 | + SmallVector<AffineMap> newIndexingMaps; |
| 250 | + |
| 251 | + auto oldIteratorTypes = contractOp.iterator_types(); |
| 252 | + SmallVector<Attribute> newIteratorTypes; |
| 253 | + |
| 254 | + int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); |
| 255 | + |
| 256 | + if (!isParallelIterator(oldIteratorTypes[dimToDrop])) |
| 257 | + // Only parallel iterators can be dropped. |
| 258 | + return failure(); |
| 259 | + |
| 260 | + for (const auto &it : llvm::enumerate(oldIteratorTypes)) { |
| 261 | + int64_t currDim = it.index(); |
| 262 | + if (currDim == dimToDrop) |
| 263 | + continue; |
| 264 | + newIteratorTypes.push_back(it.value()); |
| 265 | + } |
| 266 | + |
| 267 | + SmallVector<Value> operands = {contractOp.lhs(), contractOp.rhs(), |
| 268 | + contractOp.acc()}; |
| 269 | + SmallVector<Value> newOperands; |
| 270 | + |
| 271 | + for (const auto &it : llvm::enumerate(oldIndexingMaps)) { |
| 272 | + // Check if the dim to be dropped exists as a leading dim in the operand; |
| 273 | + // if it does, then we use vector.extract to drop it. |
| 274 | + bool validExtract = false; |
| 275 | + SmallVector<AffineExpr> results; |
| 276 | + auto map = it.value(); |
| 277 | + int64_t orginalZeroDim = it.value().getDimPosition(0); |
| 278 | + if (orginalZeroDim != dimToDrop) { |
| 279 | + // There are two reasons to be in this path: 1. We need to |
| 280 | + // transpose the operand to make the dim to be dropped |
| 281 | + // leading. 2. The dim to be dropped does not exist, and in |
| 282 | + // that case we don't want to add a unit transpose, but we must |
| 283 | + // check all the indices to make sure this is the case. |
| 284 | + bool tranposeNeeded = false; |
| 285 | + SmallVector<int64_t> perm; |
| 286 | + SmallVector<AffineExpr> transposeResults; |
| 287 | + |
| 288 | + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { |
| 289 | + int64_t currDim = map.getDimPosition(i); |
| 290 | + if (currDim == dimToDrop) { |
| 291 | + tranposeNeeded = true; |
| 292 | + perm.insert(perm.begin(), i); |
| 293 | + auto targetExpr = rewriter.getAffineDimExpr(currDim); |
| 294 | + transposeResults.insert(transposeResults.begin(), targetExpr); |
| 295 | + } else { |
| 296 | + perm.push_back(i); |
| 297 | + auto targetExpr = rewriter.getAffineDimExpr(currDim); |
| 298 | + transposeResults.push_back(targetExpr); |
| 299 | + } |
| 300 | + } |
| 301 | + // Do the transpose now, if needed, so that we can drop the |
| 302 | + // correct dim using extract later. |
| 303 | + if (tranposeNeeded) { |
| 304 | + map = AffineMap::get(map.getNumDims(), 0, transposeResults, |
| 305 | + contractOp.getContext()); |
| 306 | + operands[it.index()] = rewriter.create<vector::TransposeOp>( |
| 307 | + contractOp.getLoc(), operands[it.index()], perm); |
| 308 | + } |
| 309 | + } |
| 310 | + // We have taken care to have the dim to be dropped be |
| 311 | + // the leading dim. If it's still not leading, that means it |
| 312 | + // does not exist in this operand and hence we do not need |
| 313 | + // an extract. |
| 314 | + if (map.getDimPosition(0) == dimToDrop) |
| 315 | + validExtract = true; |
| 316 | + |
| 317 | + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { |
| 318 | + int64_t currDim = map.getDimPosition(i); |
| 319 | + if (currDim == dimToDrop) |
| 320 | + // This is the dim we are dropping. |
| 321 | + continue; |
| 322 | + auto targetExpr = rewriter.getAffineDimExpr( |
| 323 | + currDim < dimToDrop ? currDim : currDim - 1); |
| 324 | + results.push_back(targetExpr); |
| 325 | + } |
| 326 | + newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, |
| 327 | + contractOp.getContext())); |
| 328 | + // Extract if it's a valid extraction; otherwise use the operand |
| 329 | + // without extraction. |
| 330 | + newOperands.push_back(validExtract |
| 331 | + ? rewriter.create<vector::ExtractOp>( |
| 332 | + contractOp.getLoc(), operands[it.index()], |
| 333 | + splatZero(dropDim)) |
| 334 | + : operands[it.index()]); |
| 335 | + } |
| 336 | + auto newContractOp = rewriter.create<vector::ContractionOp>( |
| 337 | + contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], |
| 338 | + rewriter.getAffineMapArrayAttr(newIndexingMaps), |
| 339 | + rewriter.getArrayAttr(newIteratorTypes), contractOp.kind()); |
| 340 | + rewriter.replaceOpWithNewOp<vector::BroadcastOp>( |
| 341 | + contractOp, contractOp->getResultTypes()[0], newContractOp); |
| 342 | + return success(); |
| 343 | + } |
| 344 | +}; |
| 345 | + |
223 | 346 | class CastAwayElementwiseLeadingOneDim : public RewritePattern {
|
224 | 347 | public:
|
225 | 348 | CastAwayElementwiseLeadingOneDim(MLIRContext *context)
|
@@ -260,10 +383,11 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
|
260 | 383 |
|
261 | 384 | void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
|
262 | 385 | RewritePatternSet &patterns) { // Adds the CastAway*LeadingOneDim patterns listed below to `patterns`, then calls populateShapeCastFoldingPatterns on the same set.
|
263
| - patterns.add<CastAwayExtractStridedSliceLeadingOneDim, |
264
| - CastAwayInsertStridedSliceLeadingOneDim, |
265
| - CastAwayTransferReadLeadingOneDim, |
266
| - CastAwayTransferWriteLeadingOneDim, |
267
| - CastAwayElementwiseLeadingOneDim>(patterns.getContext()); |
| 386 | + patterns |
| 387 | + .add<CastAwayExtractStridedSliceLeadingOneDim, |
| 388 | + CastAwayInsertStridedSliceLeadingOneDim, |
| 389 | + CastAwayTransferReadLeadingOneDim, |
| 390 | + CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, |
| 391 | + CastAwayContractionLeadingOneDim>(patterns.getContext()); |
268 | 392 | populateShapeCastFoldingPatterns(patterns);
|
269 | 393 | }
|
0 commit comments