blob: 9bef01a1c96ced25a8ac56ee8ff2685d03902a1a [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- BufferizeCopyOnlyDispatchesPass.cpp --------------------------------===//
//
// This pass converts dispatches that are copy only into a form where backends
// can tile and distribute them appropriately.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Bufferizes copy-only dispatches at an early stage, so that backends can
/// work with the `linalg.generic` operation produced when lowering such a
/// dispatch.
struct BufferizeCopyOnlyDispatchesPass
    : public BufferizeCopyOnlyDispatchesBase<BufferizeCopyOnlyDispatchesPass> {
  /// Registers every dialect this pass declares a dependency on.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
    registry.insert<bufferization::BufferizationDialect>();
    registry.insert<IREE::Flow::FlowDialect>();
    registry.insert<linalg::LinalgDialect>();
    registry.insert<memref::MemRefDialect>();
    registry.insert<tensor::TensorDialect>();
  }

  /// Entry point of the pass; the definition is provided out of line.
  void runOnOperation() override;
};
} // namespace
void BufferizeCopyOnlyDispatchesPass::runOnOperation() {
ModuleOp module = getOperation();
SmallVector<Operation *> copyOnlyFunctions;
auto funcOps = module.getOps<func::FuncOp>();
for (auto funcOp : funcOps) {
/// Check if the dispatch has all sources for `flow.dispatch.tensor.store`
/// operations coming from `flow.dispatch.tensor.load` operations. If so,
/// this dispatch is just a copy dispatch.
auto walkResult = funcOp.walk(
[&](IREE::Flow::DispatchTensorStoreOp storeOp) -> WalkResult {
return success(isReadOnly(storeOp.value()));
});
if (walkResult.wasInterrupted()) continue;
// The function is just a copy.
copyOnlyFunctions.push_back(funcOp);
}
// There are no copy-only functions. So nothing to do.
if (copyOnlyFunctions.empty()) return;
// Bufferize the dispatch to create a `linalg.generic` as a copy operation.
// This can then be used by the backends to tile and distribute.
// Currently bufferization does not handle single function bufferization. So
// check that all functions are copy only and can be bufferized.
if (copyOnlyFunctions.size() !=
std::distance(funcOps.begin(), funcOps.end())) {
module.emitOpError(
"module contains functions that are both copy only and not copy only. "
"This is currently unhandled.");
return signalPassFailure();
}
// Apply the bufferization passes.
OpPassManager bufferizationPipeline(module.getOperationName());
addLinalgBufferizePasses(bufferizationPipeline);
if (failed(runPipeline(bufferizationPipeline, module))) {
return signalPassFailure();
}
// Check that there are no allocs created.
auto hasAlloc = module.walk(
[&](memref::AllocOp /*op*/) -> WalkResult { return failure(); });
if (hasAlloc.wasInterrupted()) {
module.emitOpError(
"unexpected allocations while bufferizing copy dispatch");
return signalPassFailure();
}
}
/// Creates a new `BufferizeCopyOnlyDispatchesPass` instance and hands it back
/// as a pass that operates on `ModuleOp`.
std::unique_ptr<OperationPass<ModuleOp>>
createBufferizeCopyOnlyDispatchesPass() {
  using PassT = BufferizeCopyOnlyDispatchesPass;
  return std::make_unique<PassT>();
}
} // namespace iree_compiler
} // namespace mlir