blob: 9704f5fb2529e228327a75f02cca28db2a5774c7 [file]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-dispatch-creation-producers-into-dispatch-regions"
namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_FUSEENCODINGOPSINTODISPATCHREGIONSPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
/// Decides whether `op` may absorb a consuming SetEncodingOp into the same
/// dispatch region. The policy is intentionally narrow for now: only ops that
/// implement the Linalg structured-op interface are accepted.
static bool isFusableWithSetEncoding(Operation *op) {
  auto structuredOp = dyn_cast<linalg::LinalgOp>(op);
  return static_cast<bool>(structuredOp);
}
/// Places every `set_encoding` op that is not yet inside a dispatch into a
/// dispatch region.
///
/// - If the encoding's source is produced by a `flow.dispatch.region` whose
///   yielded value comes from a fusable op (see `isFusableWithSetEncoding`)
///   nested in that region, the encoding op is moved into that producer
///   dispatch.
/// - Otherwise the encoding op is wrapped in a dispatch region of its own.
///
/// Afterwards, `tensor.dim`-style ops are resolved so that dynamic dims do not
/// refer to values whose dominance changed during the moves.
struct FuseEncodingOpsIntoDispatchRegionsPass
    : public DispatchCreation::impl::FuseEncodingOpsIntoDispatchRegionsPassBase<
          FuseEncodingOpsIntoDispatchRegionsPass> {
  void runOnOperation() override {
    mlir::FunctionOpInterface funcOp = getOperation();
    MLIRContext *context = &getContext();
    IRRewriter rewriter(context);

    // Collect up front so the IR can be rewritten while iterating. Encoding
    // ops that already live inside a dispatch are left alone: wrapping them
    // again would nest a dispatch region inside another dispatch.
    SmallVector<IREE::Encoding::SetEncodingOp> encodingOps;
    funcOp->walk([&](IREE::Encoding::SetEncodingOp encodingOp) {
      if (!isInsideDispatch(encodingOp.getOperation())) {
        encodingOps.push_back(encodingOp);
      }
    });

    for (IREE::Encoding::SetEncodingOp encodingOp : encodingOps) {
      // Recomputed per op: fusing an earlier encoding op may have replaced
      // the producer dispatch with a new one (uses are updated accordingly).
      IREE::Flow::DispatchRegionOp producerDispatch =
          getFusableProducerDispatch(encodingOp);

      // No suitable producer to fuse with, so give the `encodingOp` its own
      // dispatch region.
      if (!producerDispatch) {
        if (failed(IREE::Flow::wrapOpInDispatchRegion(rewriter, encodingOp))) {
          return signalPassFailure();
        }
        continue;
      }

      // Fuse the `encodingOp` into the producer dispatch region.
      if (failed(moveFollowingOpIntoDispatchRegion(rewriter, encodingOp,
                                                   producerDispatch))) {
        return signalPassFailure();
      }
    }

    // Dynamic dims may have dominance issues after pulling encoding ops into
    // producer dispatch regions, so we need to resolve tensor.dim ops.
    RewritePatternSet patterns(context);
    memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
    if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }
  }

private:
  /// Returns true if `op` is nested (at any depth) inside a dispatch region or
  /// a dispatch workgroups op.
  static bool isInsideDispatch(Operation *op) {
    return op->getParentOfType<IREE::Flow::DispatchRegionOp>() ||
           op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>();
  }

  /// Returns the dispatch region that `encodingOp` should be fused into, or a
  /// null op if it should instead be wrapped in its own dispatch.
  ///
  /// A producer dispatch qualifies when:
  ///   1. the encoding's source is a result of a `flow.dispatch.region`,
  ///   2. the value yielded for that result is an op result (not a block
  ///      argument),
  ///   3. the op defining that value actually lives inside the dispatch (a
  ///      value merely forwarded from above the region has no producer to
  ///      fuse with), and
  ///   4. that op is accepted by `isFusableWithSetEncoding`.
  static IREE::Flow::DispatchRegionOp
  getFusableProducerDispatch(IREE::Encoding::SetEncodingOp encodingOp) {
    Value source = encodingOp.getSource();
    auto producerDispatch =
        source.getDefiningOp<IREE::Flow::DispatchRegionOp>();
    if (!producerDispatch) {
      return {};
    }

    // Map the dispatch result back to the value yielded by the region's
    // terminator to find the producer operation inside the region.
    auto result = cast<OpResult>(source);
    auto dispatchReturnOp = cast<IREE::Flow::ReturnOp>(
        producerDispatch.getBody().front().getTerminator());
    auto producerInRegion = dyn_cast<OpResult>(
        dispatchReturnOp->getOperand(result.getResultNumber()));
    if (!producerInRegion) {
      return {};
    }

    Operation *producerOp = producerInRegion.getOwner();
    if (!producerDispatch->isProperAncestor(producerOp)) {
      return {};
    }
    if (!isFusableWithSetEncoding(producerOp)) {
      return {};
    }
    return producerDispatch;
  }
};
} // namespace
} // namespace mlir::iree_compiler::DispatchCreation