Add transform dialect op to allow hoisting bounded allocs in a function (#12284)
In the process, this refactors the transformation that hoists bounded
memref::AllocaOp and extends it to also work with memref::AllocOp.
This makes it possible to reuse the utility with GPUs.
diff --git a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp
index 4c56f99..d25831b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp
@@ -4,16 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#define DEBUG_TYPE "iree-codegen-hoist-statically-bound-allocations"
-
namespace mlir {
namespace iree_compiler {
@@ -26,54 +26,10 @@
} // namespace
-/// Some uses of a `memref.alloca` can be replaced with a `memref.subview`
-/// easily. Other uses (like a use in a `scf.yield` or `func.return`) are
-/// non-trivial because of compatibility between types of different SSA values.
-static bool isUseReplacableWithSubview(OpOperand &use) {
- Operation *user = use.getOwner();
- return isa<linalg::LinalgOp, memref::StoreOp, memref::SubViewOp>(user);
-}
-
void HoistStaticallyBoundAllocationsPass::runOnOperation() {
func::FuncOp funcOp = getOperation();
- SmallVector<memref::AllocaOp> allocaOps;
-
- // Collect all allocas that are hoistable.
- funcOp.walk([&](memref::AllocaOp allocaOp) {
- if (allocaOp->getBlock() == &funcOp.getBody().front()) return;
- if (allocaOp.getDynamicSizes().empty()) {
- allocaOps.push_back(allocaOp);
- return;
- }
- if (llvm::all_of(allocaOp->getUses(), [](OpOperand &use) {
- return isUseReplacableWithSubview(use);
- })) {
- allocaOps.push_back(allocaOp);
- return;
- }
- });
-
- // Hoist the allocas and replace all uses.
- OpBuilder builder(&getContext());
- for (auto allocaOp : allocaOps) {
- LLVM_DEBUG({
- llvm::dbgs() << "Alloca Op : ";
- allocaOp->dump();
- int numUses = std::distance(allocaOp.getResult().use_begin(),
- allocaOp.getResult().use_end());
- llvm::dbgs() << " num Uses : " << numUses;
- });
- std::optional<Value> replacement =
- hoistStaticallyBoundAllocations(funcOp, builder, allocaOp);
- if (!replacement) continue;
- LLVM_DEBUG({
- llvm::dbgs() << "Replacement : ";
- replacement->dump();
- });
- Value replacementVal = replacement.value();
- allocaOp.getResult().replaceAllUsesWith(replacementVal);
- allocaOp->erase();
- }
+ IRRewriter rewriter(funcOp->getContext());
+ hoistStaticallyBoundAllocationsInFunc<memref::AllocaOp>(rewriter, funcOp);
}
std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
index e72cd79..fa01046 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
@@ -67,6 +67,7 @@
"//compiler/src/iree/compiler/Codegen:PassHeaders",
"//compiler/src/iree/compiler/Codegen/Common:CommonPasses",
"//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
+ "//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index 5bd3d9f..6d6d6de 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -58,6 +58,7 @@
iree::compiler::Codegen::Common::CommonPasses
iree::compiler::Codegen::Interfaces::BufferizationInterfaces
iree::compiler::Codegen::PassHeaders
+ iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 753da93..9983bd5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
@@ -54,6 +55,68 @@
mlir::iree_compiler::IREE::transform_dialect::CommonExtensions>();
}
+// Returns true if every use of `op` is a memref.store, vector.transfer_write
+// or memref.dealloc. memref.subview users are allowed as long as all of their
+// users satisfy the same condition. On success the users are appended to
+// `uses`; on failure `uses` is left unchanged.
+static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
+ std::vector<Operation *> opUses;
+ for (OpOperand &use : op->getUses()) {
+ Operation *useOp = use.getOwner();
+ if (isa<memref::DeallocOp, vector::TransferWriteOp, memref::StoreOp>(
+ useOp) ||
+ (isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
+ opUses.push_back(useOp);
+ continue;
+ }
+ return false;
+ }
+ uses.insert(uses.end(), opUses.begin(), opUses.end());
+ return true;
+}
+
+// Track temporary allocations that are never read from. If this is the case
+// it means both the allocations and associated stores can be removed.
+static void eraseDeadAllocAndStores(Operation *parentOp) {
+ std::vector<Operation *> opToErase;
+ parentOp->walk([&](memref::AllocOp op) {
+ if (allUsesAreStores(op, opToErase)) {
+ opToErase.push_back(op.getOperation());
+ }
+ });
+ for (Operation *op : opToErase) {
+ op->erase();
+ }
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyBufferOptimizationsOp
+//===---------------------------------------------------------------------===//
+DiagnosedSilenceableFailure
+transform_dialect::ApplyBufferOptimizationsOp::applyToOne(
+ Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ // Apply store to load forwarding and dead store elimination.
+ vector::transferOpflowOpt(target);
+ eraseDeadAllocAndStores(target);
+
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::ApplyBufferOptimizationsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTarget(), effects);
+ transform::producesHandle(getResult(), effects);
+ transform::modifiesPayload(effects);
+}
+
+void transform_dialect::ApplyBufferOptimizationsOp::build(
+ OpBuilder &builder, OperationState &result, Value target) {
+ result.addOperands(target);
+ result.addTypes({pdl::OperationType::get(target.getContext())});
+}
+
//===---------------------------------------------------------------------===//
// ApplyPatternsOp
//===---------------------------------------------------------------------===//
@@ -311,6 +374,20 @@
}
//===----------------------------------------------------------------------===//
+// HoistStaticAllocOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform_dialect::HoistStaticAllocOp::applyToOne(
+ func::FuncOp funcOp, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ IRRewriter rewriter(funcOp->getContext());
+ mlir::iree_compiler::hoistStaticallyBoundAllocationsInFunc<memref::AllocOp>(
+ rewriter, funcOp);
+ results.push_back(funcOp);
+ return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
// ShareForallOperandsOp
//===----------------------------------------------------------------------===//
@@ -482,14 +559,14 @@
// Step 3. Outline the compute workload region and set up the workload
// operands, if this has not been done already.
// Using `transform.iree.tile_to_forall_and_workgroup_count_region` is
- // the preferred way to set up tiling and workgroup_count region **at the same
- // time**.
+ // the preferred way to set up tiling and workgroup_count region **at the
+ // same time**.
//
- // The block of code below will be retired once there is enough confidence we
- // can do everything without it. This includes in particular providing custom
- // fusion heuristics at the flow level: at this time, the only way to fully
- // control fusion of more advanced cases is to use the transform dialect at
- // the flow level and explicitly match the ops we want to fuse.
+ // The block of code below will be retired once there is enough confidence
+ // we can do everything without it. This includes in particular providing
+ // custom fusion heuristics at the flow level: at this time, the only way to
+ // fully control fusion of more advanced cases is to use the transform
+ // dialect at the flow level and explicitly match the ops we want to fuse.
// Once fusion is customizable enough in perpetuity, we can retire this.
if (exportOp.getWorkgroupCount().empty()) {
if (llvm::any_of(forallOp.getUpperBound(rewriter), [](Value v) {
@@ -497,9 +574,11 @@
})) {
return forallOp->emitError(
"unsupported dynamic workgroup_count atm --- need to slice out "
- "workgroup_count computation into ExecutableExport::workgroup_count."
+ "workgroup_count computation into "
+ "ExecutableExport::workgroup_count."
"\nThis region may require arbitrary computations and cannot "
- "magically match what the `stream.cmd.dispatch` has already imposed "
+ "magically match what the `stream.cmd.dispatch` has already "
+ "imposed "
"on us at a distance."
"\nFor now we must specify the number of values properly when "
"applying the topLevel tile_to_forall_op");
@@ -628,8 +707,8 @@
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
// Call the default builder which sets up the proper operands segment sizes
- // attributes for multiple variadic operands. In the absence of this, horrible
- // bugs ensue.
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
@@ -665,8 +744,8 @@
dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
staticNumThreads);
// Call the default builder which sets up the proper operands segment sizes
- // attributes for multiple variadic operands. In the absence of this, horrible
- // bugs ensue.
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
build(builder, result,
@@ -1226,64 +1305,5 @@
return DiagnosedSilenceableFailure::success();
}
-// Return true if all the uses of op are either Store/transfer_write.
-// There can be SubviewOp users as long as all its users are also
-// StoreOp/transfer_write. If return true it also fills out the uses, if it
-// returns false uses is unchanged.
-static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
- std::vector<Operation *> opUses;
- for (OpOperand &use : op->getUses()) {
- Operation *useOp = use.getOwner();
- if (isa<memref::DeallocOp, vector::TransferWriteOp, memref::StoreOp>(
- useOp) ||
- (isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
- opUses.push_back(useOp);
- continue;
- }
- return false;
- }
- uses.insert(uses.end(), opUses.begin(), opUses.end());
- return true;
-}
-
-// Track temporary allocations that are never read from. If this is the case
-// it means both the allocations and associated stores can be removed.
-static void eraseDeadAllocAndStores(Operation *parentOp) {
- std::vector<Operation *> opToErase;
- parentOp->walk([&](memref::AllocOp op) {
- if (allUsesAreStores(op, opToErase)) {
- opToErase.push_back(op.getOperation());
- }
- });
- for (Operation *op : opToErase) {
- op->erase();
- }
-}
-
-DiagnosedSilenceableFailure
-transform_dialect::ApplyBufferOptimizationsOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- // Apply store to load forwarding and dead store elimination.
- vector::transferOpflowOpt(target);
- eraseDeadAllocAndStores(target);
-
- results.push_back(target);
- return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::ApplyBufferOptimizationsOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getTarget(), effects);
- transform::producesHandle(getResult(), effects);
- transform::modifiesPayload(effects);
-}
-
-void transform_dialect::ApplyBufferOptimizationsOp::build(
- OpBuilder &builder, OperationState &result, Value target) {
- result.addOperands(target);
- result.addTypes({pdl::OperationType::get(target.getContext())});
-}
-
#define GET_OP_CLASSES
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index ea4491e..2e9de43 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -17,6 +17,42 @@
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
+def ApplyBufferOptimizationsOp :
+ Op<Transform_Dialect, "iree.apply_buffer_optimizations",
+ [TransformEachOpTrait,
+ TransformOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ This applies memory optimization on memref. In particular it does store to
+ load forwarding, dead store elimination and dead alloc elimination.
+
+ #### Return modes
+
+ This operation applies a set of memory optimization on the whole region of
+ the operand.
+
+ If the transformation is successful it returns the handle to the
+ same payload as its operand to allow for simpler composition.
+ }];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$result);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
def ApplyPatternsOp : Op<Transform_Dialect, "iree.apply_patterns",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
@@ -125,6 +161,32 @@
}];
}
+def HoistStaticAllocOp : Op<Transform_Dialect, "iree.hoist_static_alloc",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let summary = "Hoist static allocations";
+ let description = [{
+ Find static allocations and hoist them to the top level.
+
+ #### Return modes
+ This operation applies static alloc hoisting to the whole region of the operand.
+ It always returns success.
+ }];
+
+ let arguments = (ins Transform_ConcreteOpType<"func.func">:$target);
+ let results = (outs Transform_ConcreteOpType<"func.func">:$result);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::func::FuncOp funcOp,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
def IREEBufferizeOp : Op<Transform_Dialect, "iree.bufferize",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
@@ -453,40 +515,4 @@
}];
}
-def ApplyBufferOptimizationsOp :
- Op<Transform_Dialect, "iree.apply_buffer_optimizations",
- [TransformEachOpTrait,
- TransformOpInterface,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let description = [{
- This applies memory optimization on memref. In particular it does store to
- load forwarding, dead store elimination and dead alloc elimination.
-
- #### Return modes
-
- This operation applies a set of memory optimization on the whole region of
- the operand.
-
- If the transformation is successful it returns the handle to the
- same payload as its operand to allow for simpler composition.
- }];
-
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$result);
-
- let assemblyFormat = "$target attr-dict";
- let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<(ins "Value":$target)>
- ];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 901f93f..d2a44c8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/TargetSelect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index d96ee7b..91bb0d7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -98,8 +98,9 @@
unsigned alignment) {
auto funcOp = builder.getInsertionPoint()->getParentOfType<func::FuncOp>();
if (funcOp) {
- std::optional<Value> hoistedAllocation = hoistStaticallyBoundAllocations(
- funcOp, builder, loc, memRefType, dynamicSizes, alignment);
+ std::optional<Value> hoistedAllocation =
+ hoistOneStaticallyBoundAllocation<memref::AllocaOp>(
+ funcOp, builder, loc, memRefType, dynamicSizes, alignment);
if (hoistedAllocation) {
return hoistedAllocation.value();
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
index a65da9a..adbf5e7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
index 616ad87..a60227e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -39,6 +39,7 @@
"tensor_pad.mlir",
"tensorcore_vectorization.mlir",
"tile_on_tensor.mlir",
+ "transform_dialect_hoist_allocs.mlir",
"transform_dialect_vector_distribution.mlir",
"transform_dialect_bufferize.mlir",
"transform_dialect_promote_operands.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index b63550c..78b834c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -36,6 +36,7 @@
"tensorcore_vectorization.mlir"
"tile_on_tensor.mlir"
"transform_dialect_bufferize.mlir"
+ "transform_dialect_hoist_allocs.mlir"
"transform_dialect_promote_operands.mlir"
"transform_dialect_vector_distribution.mlir"
"transform_distribute_forall.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_hoist_allocs.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_hoist_allocs.mlir
new file mode 100644
index 0000000..cabc06b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_hoist_allocs.mlir
@@ -0,0 +1,88 @@
+// RUN: iree-opt --split-input-file -iree-transform-dialect-interpreter -transform-dialect-drop-schedule %s | FileCheck %s
+
+func.func @non_entry_bb_allocs() {
+ cf.br ^bb1
+ ^bb1() :
+ %0 = memref.alloc() : memref<16xi32>
+ memref.dealloc %0 : memref<16xi32>
+ return
+}
+// CHECK-LABEL: func @non_entry_bb_allocs()
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16xi32>
+// CHECK-NEXT: cf.br ^bb1
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: return
+
+transform.sequence failures(propagate) {
+^bb1(%module: !pdl.operation):
+ %func = transform.structured.match ops{["func.func"]} in %module
+ : (!pdl.operation) -> !transform.op<"func.func">
+ %func_2 = transform.iree.hoist_static_alloc %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0, 16)>
+func.func @nested_op_alloc_subview_use_static(%arg0 : index, %o0 : index, %o1 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c42 = arith.constant 42 : i32
+ scf.for %iv = %c0 to %arg0 step %c1 {
+ %0 = affine.min #map(%iv)
+ %1 = memref.alloc() : memref<16x16xi32>
+ %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref<16x16xi32> to memref<?x?xi32, strided<[?, 1], offset: ?>>
+ memref.dealloc %1 : memref<16x16xi32>
+ scf.yield
+ }
+ return
+}
+// CHECK-LABEL: func @nested_op_alloc_subview_use_static(
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32>
+// CHECK: scf.for
+// CHECK: %[[SIZE:.+]] = affine.min
+// CHECK: memref.subview %[[ALLOC]]
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16x16xi32>
+
+transform.sequence failures(propagate) {
+^bb1(%module: !pdl.operation):
+ %func = transform.structured.match ops{["func.func"]} in %module
+ : (!pdl.operation) -> !transform.op<"func.func">
+ %func_2 = transform.iree.hoist_static_alloc %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0, 16)>
+func.func @nested_op_alloc_subview_use_dynamic(%arg0 : index, %o0 : index, %o1 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c42 = arith.constant 42 : i32
+ scf.for %iv = %c0 to %arg0 step %c1 {
+ %0 = affine.min #map(%iv)
+ %1 = memref.alloc(%0, %0) : memref<?x?xi32>
+ %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref<?x?xi32> to memref<?x?xi32, strided<[?, 1], offset: ?>>
+ memref.dealloc %1 : memref<?x?xi32>
+ scf.yield
+ }
+ return
+}
+// CHECK-LABEL: func @nested_op_alloc_subview_use_dynamic(
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32>
+// CHECK: scf.for
+// CHECK: %[[SIZE:.+]] = affine.min
+// CHECK: %[[SUBVIEW1:.+]] = memref.subview %[[ALLOC]][0, 0] [%[[SIZE]], %[[SIZE]]] [1, 1]
+// CHECK: memref.subview %[[SUBVIEW1]]
+// CHECK-NEXT: }
+// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16x16xi32>
+
+transform.sequence failures(propagate) {
+^bb1(%module: !pdl.operation):
+ %func = transform.structured.match ops{["func.func"]} in %module
+ : (!pdl.operation) -> !transform.op<"func.func">
+ %func_2 = transform.iree.hoist_static_alloc %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+}
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index c0e0fce..32ad122 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -12,8 +12,12 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#define DEBUG_TYPE "iree-codegen-transforms"
namespace mlir {
namespace iree_compiler {
@@ -84,32 +88,36 @@
loadOp.getMixedSizes(), loadOp.getMixedStrides(), loadOp.getSourceDims());
}
-std::optional<Value> hoistStaticallyBoundAllocations(
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
func::FuncOp funcOp, OpBuilder &builder, Location loc,
- MemRefType allocaType, ValueRange dynamicSizes,
+ MemRefType allocLikeType, ValueRange dynamicSizes,
std::optional<uint64_t> alignment) {
IntegerAttr alignmentAttr =
alignment ? builder.getI64IntegerAttr(alignment.value()) : nullptr;
- // For static case just create a new allocation in the entry block of the
- // same size. No need to insert a subview.
+ // For static case just create a new allocation in the entry block of the same
+ // size. No need to insert a subview.
if (dynamicSizes.empty()) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToStart(&funcOp.getBody().front());
Value allocation =
- builder.create<memref::AllocaOp>(loc, allocaType, alignmentAttr);
+ builder.create<AllocLikeOpType>(loc, allocLikeType, alignmentAttr);
+ if (std::is_same<AllocLikeOpType, memref::AllocOp>::value) {
+ builder.setInsertionPoint(funcOp.getBody().front().getTerminator());
+ builder.create<memref::DeallocOp>(loc, allocation);
+ }
return allocation;
}
- /// For the dynamic but bounded case, insert an allocation
- /// of the shape of the bounds, and a subview of the
- /// required size to be used as a replacement.
+ /// For the dynamic but bounded case, insert an allocation of the shape of the
+ /// bounds, and a subview of the required size to be used as a replacement.
SmallVector<int64_t> staticShape;
SmallVector<OpFoldResult> subviewSizes;
- staticShape.reserve(allocaType.getRank());
- subviewSizes.reserve(allocaType.getRank());
+ staticShape.reserve(allocLikeType.getRank());
+ subviewSizes.reserve(allocLikeType.getRank());
int index = 0;
- for (auto dimSize : allocaType.getShape()) {
+ for (auto dimSize : allocLikeType.getShape()) {
if (!ShapedType::isDynamic(dimSize)) {
staticShape.push_back(dimSize);
subviewSizes.push_back(builder.getIndexAttr(dimSize));
@@ -123,9 +131,9 @@
staticShape.push_back(ub.value());
subviewSizes.push_back(dynamicSize);
}
- SmallVector<OpFoldResult> offsets(allocaType.getRank(),
+ SmallVector<OpFoldResult> offsets(allocLikeType.getRank(),
builder.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(allocaType.getRank(),
+ SmallVector<OpFoldResult> strides(allocLikeType.getRank(),
builder.getIndexAttr(1));
Value allocation;
@@ -133,22 +141,96 @@
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToStart(&funcOp.getBody().front());
auto allocationType =
- MemRefType::get(staticShape, allocaType.getElementType());
+ MemRefType::get(staticShape, allocLikeType.getElementType());
allocation =
- builder.create<memref::AllocaOp>(loc, allocationType, alignmentAttr);
+ builder.create<AllocLikeOpType>(loc, allocationType, alignmentAttr);
}
Value subviewOp = builder.create<memref::SubViewOp>(loc, allocation, offsets,
subviewSizes, strides);
+
+ if (std::is_same<AllocLikeOpType, memref::AllocOp>::value) {
+ builder.setInsertionPoint(funcOp.getBody().front().getTerminator());
+ builder.create<memref::DeallocOp>(loc, allocation);
+ }
return subviewOp;
}
-std::optional<Value> hoistStaticallyBoundAllocations(
- func::FuncOp funcOp, OpBuilder &builder, memref::AllocaOp allocaOp) {
+
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
+ func::FuncOp funcOp, OpBuilder &builder, AllocLikeOpType allocLikeOp) {
OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(allocaOp);
- return hoistStaticallyBoundAllocations(
- funcOp, builder, allocaOp.getLoc(), allocaOp.getType(),
- allocaOp.getDynamicSizes(), allocaOp.getAlignment());
+ builder.setInsertionPoint(allocLikeOp);
+ return hoistOneStaticallyBoundAllocation<AllocLikeOpType>(
+ funcOp, builder, allocLikeOp.getLoc(), allocLikeOp.getType(),
+ allocLikeOp.getDynamicSizes(), allocLikeOp.getAlignment());
+}
+
+/// Some uses of an alloc-like op can be replaced with a `memref.subview`
+/// easily. Other uses (like a use in a `scf.yield` or `func.return`) are
+/// non-trivial because of compatibility between types of different SSA values.
+static bool isUseReplaceableWithSubview(OpOperand &use) {
+ Operation *user = use.getOwner();
+ return isa<linalg::LinalgOp, memref::DeallocOp, memref::StoreOp,
+ memref::SubViewOp>(user);
+}
+
+/// Explicit instantiations for `hoistStaticallyBoundAllocationsInFunc`.
+/// These implicitly instantiate the versions of
+/// `hoistOneStaticallyBoundAllocation` that they call.
+template void hoistStaticallyBoundAllocationsInFunc<memref::AllocOp>(
+ RewriterBase &rewriter, func::FuncOp funcOp);
+template void hoistStaticallyBoundAllocationsInFunc<memref::AllocaOp>(
+ RewriterBase &rewriter, func::FuncOp funcOp);
+
+template <typename AllocLikeOpType>
+void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter,
+ func::FuncOp funcOp) {
+ SmallVector<AllocLikeOpType> allocLikeOps;
+
+ // Collect all allocLikes that are hoistable.
+ funcOp.walk([&](AllocLikeOpType allocLikeOp) {
+ if (allocLikeOp->getBlock() == &funcOp.getBody().front()) return;
+ if (allocLikeOp.getDynamicSizes().empty()) {
+ allocLikeOps.push_back(allocLikeOp);
+ return;
+ }
+ if (llvm::all_of(allocLikeOp->getUses(), [](OpOperand &use) {
+ return isUseReplaceableWithSubview(use);
+ })) {
+ allocLikeOps.push_back(allocLikeOp);
+ return;
+ }
+ });
+
+ // Hoist the allocLikes and replace all uses.
+ for (auto allocLikeOp : allocLikeOps) {
+ // Record potential memref::DeallocOps to clean up after hoisting occurs.
+ SmallVector<memref::DeallocOp> deallocOps;
+ for (Operation *user : allocLikeOp->getUsers()) {
+ auto dealloc = dyn_cast<memref::DeallocOp>(user);
+ if (dealloc) deallocOps.push_back(dealloc);
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Alloca Op : ";
+ allocLikeOp->dump();
+ int numUses = std::distance(allocLikeOp.getResult().use_begin(),
+ allocLikeOp.getResult().use_end());
+ llvm::dbgs() << " num Uses : " << numUses;
+ });
+ std::optional<Value> replacement =
+ hoistOneStaticallyBoundAllocation(funcOp, rewriter, allocLikeOp);
+ if (!replacement) continue;
+ LLVM_DEBUG({
+ llvm::dbgs() << "Replacement : ";
+ replacement->dump();
+ });
+ Value replacementVal = replacement.value();
+ rewriter.replaceOp(allocLikeOp, replacementVal);
+
+ for (memref::DeallocOp deallocOp : deallocOps) rewriter.eraseOp(deallocOp);
+ }
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
index ce6cbb8..e073ce0 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -45,7 +44,8 @@
/// allocations creates an allocation, and inserts a subview to match the
/// dynamic shape of the allocation. Returns std::nullopt if the method
/// couldnt creat an allocation in the entry block.
-std::optional<Value> hoistStaticallyBoundAllocations(
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
func::FuncOp funcOp, OpBuilder &builder, Location loc,
MemRefType allocaType, ValueRange dynamicSizes,
std::optional<uint64_t> alignment);
@@ -56,9 +56,15 @@
/// allocations creates an allocation, and inserts a subview to match the
/// dynamic shape of the allocation. The method returns a value, but
/// does not replace the uses of the `allocaOp`.
-std::optional<Value> hoistStaticallyBoundAllocations(func::FuncOp funcOp,
- OpBuilder &builder,
- memref::AllocaOp allocaOp);
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
+ func::FuncOp funcOp, OpBuilder &builder, AllocLikeOpType allocaOp);
+
+/// Traverse funcOp and try to hoist every AllocLikeOpType op to the entry
+/// block of the function if its size is statically bounded.
+template <typename AllocLikeOpType>
+void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter,
+ func::FuncOp funcOp);
/// Insert patterns to perform folding of AffineMinOp by matching the
/// pattern generated by tile and distribute. Try to fold a affine.min op by