|  | // Copyright 2022 The IREE Authors | 
|  | // | 
|  | // Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" | 
|  | #include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" | 
|  | #include "mlir/Dialect/Arith/Utils/Utils.h" | 
|  | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" | 
|  | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | 
|  | #include "mlir/Dialect/Linalg/Utils/Utils.h" | 
|  | #include "mlir/Dialect/Tensor/IR/Tensor.h" | 
|  | #include "mlir/IR/Operation.h" | 
|  | #include "mlir/IR/OperationSupport.h" | 
|  | #include "mlir/IR/PatternMatch.h" | 
|  | #include "mlir/Interfaces/TilingInterface.h" | 
|  |  | 
|  | using namespace mlir; | 
|  | using namespace mlir::iree_compiler::IREE::LinalgExt; | 
|  |  | 
|  | FailureOr<FusionResult> LinalgExtFusionPattern::returningMatchAndRewrite( | 
|  | TilingInterface consumerOp, PatternRewriter &rewriter) const { | 
|  | // Try to fuse the producers of all operands to fuse. | 
|  | SmallVector<TilingInterface> fusedOps; | 
|  | for (int64_t operandToFuse : operandsToFuse) { | 
|  | // Check the operand exists. | 
|  | if (operandToFuse >= consumerOp->getNumOperands()) | 
|  | return failure(); | 
|  |  | 
|  | // Check the operand is a slice of a producer result. | 
|  | auto sliceOp = consumerOp->getOperand(operandToFuse) | 
|  | .getDefiningOp<tensor::ExtractSliceOp>(); | 
|  | if (!sliceOp) | 
|  | return failure(); | 
|  | auto producerOp = sliceOp.getSource().getDefiningOp<TilingInterface>(); | 
|  | if (!producerOp || producerOp->getNumResults() != 1) | 
|  | return failure(); | 
|  |  | 
|  | // Tile the producer. | 
|  | FailureOr<Value> tiledProducer = producerOp.generateResultTileValue( | 
|  | rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(), | 
|  | sliceOp.getMixedSizes()); | 
|  | if (failed(tiledProducer)) | 
|  | return failure(); | 
|  | fusedOps.push_back(cast<TilingInterface>(tiledProducer->getDefiningOp())); | 
|  | } | 
|  |  | 
|  | // Update the consumer in-place using the tiled producer results. | 
|  | SmallVector<Value> newOperands = consumerOp->getOperands(); | 
|  | for (auto it : llvm::zip(operandsToFuse, fusedOps)) { | 
|  | int64_t operandToFuse = std::get<0>(it); | 
|  | TilingInterface fusedOp = std::get<1>(it); | 
|  | newOperands[operandToFuse] = fusedOp->getResult(0); | 
|  | } | 
|  | rewriter.updateRootInPlace(consumerOp, | 
|  | [&]() { consumerOp->setOperands(newOperands); }); | 
|  |  | 
|  | return FusionResult{consumerOp, fusedOps}; | 
|  | } |