// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
| |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| struct MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> { |
| using OpRewritePattern<memref::CopyOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(memref::CopyOp copyOp, |
| PatternRewriter &rewriter) const override { |
| Operation *linalgCopy = |
| createLinalgCopyOp(rewriter, copyOp.getLoc(), copyOp.source(), |
| copyOp.target(), copyOp->getAttrs()); |
| rewriter.replaceOp(copyOp, linalgCopy->getResults()); |
| return success(); |
| } |
| }; |
| |
| struct MemrefCopyToLinalgPass |
| : public MemrefCopyToLinalgPassBase<MemrefCopyToLinalgPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<linalg::LinalgDialect>(); |
| } |
| |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| RewritePatternSet patterns(&getContext()); |
| patterns.insert<MemrefCopyOpToLinalg>(context); |
| if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<FuncOp>> createMemrefCopyToLinalgPass() { |
| return std::make_unique<MemrefCopyToLinalgPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |