Add transform dialect op to allow hoisting bounded allocs in a function (#12284)

In the process, this refactors and extends the transformation to perform
hoisting of bounded memref::AllocaOp to also work with memref::AllocOp.

This makes it possible to reuse the utility with GPUs.
diff --git a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp
index 4c56f99..d25831b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp
@@ -4,16 +4,16 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
-#define DEBUG_TYPE "iree-codegen-hoist-statically-bound-allocations"
-
 namespace mlir {
 namespace iree_compiler {
 
@@ -26,54 +26,10 @@
 
 }  // namespace
 
-/// Some uses of a `memref.alloca` can be replaced with a `memref.subview`
-/// easily. Other uses (like a use in a `scf.yield` or `func.return`) are
-/// non-trivial because of compatibility between types of different SSA values.
-static bool isUseReplacableWithSubview(OpOperand &use) {
-  Operation *user = use.getOwner();
-  return isa<linalg::LinalgOp, memref::StoreOp, memref::SubViewOp>(user);
-}
-
 void HoistStaticallyBoundAllocationsPass::runOnOperation() {
   func::FuncOp funcOp = getOperation();
-  SmallVector<memref::AllocaOp> allocaOps;
-
-  // Collect all allocas that are hoistable.
-  funcOp.walk([&](memref::AllocaOp allocaOp) {
-    if (allocaOp->getBlock() == &funcOp.getBody().front()) return;
-    if (allocaOp.getDynamicSizes().empty()) {
-      allocaOps.push_back(allocaOp);
-      return;
-    }
-    if (llvm::all_of(allocaOp->getUses(), [](OpOperand &use) {
-          return isUseReplacableWithSubview(use);
-        })) {
-      allocaOps.push_back(allocaOp);
-      return;
-    }
-  });
-
-  // Hoist the allocas and replace all uses.
-  OpBuilder builder(&getContext());
-  for (auto allocaOp : allocaOps) {
-    LLVM_DEBUG({
-      llvm::dbgs() << "Alloca Op : ";
-      allocaOp->dump();
-      int numUses = std::distance(allocaOp.getResult().use_begin(),
-                                  allocaOp.getResult().use_end());
-      llvm::dbgs() << " num Uses : " << numUses;
-    });
-    std::optional<Value> replacement =
-        hoistStaticallyBoundAllocations(funcOp, builder, allocaOp);
-    if (!replacement) continue;
-    LLVM_DEBUG({
-      llvm::dbgs() << "Replacement : ";
-      replacement->dump();
-    });
-    Value replacementVal = replacement.value();
-    allocaOp.getResult().replaceAllUsesWith(replacementVal);
-    allocaOp->erase();
-  }
+  IRRewriter rewriter(funcOp->getContext());
+  hoistStaticallyBoundAllocationsInFunc<memref::AllocaOp>(rewriter, funcOp);
 }
 
 std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
index e72cd79..fa01046 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
@@ -67,6 +67,7 @@
         "//compiler/src/iree/compiler/Codegen:PassHeaders",
         "//compiler/src/iree/compiler/Codegen/Common:CommonPasses",
         "//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
+        "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index 5bd3d9f..6d6d6de 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -58,6 +58,7 @@
     iree::compiler::Codegen::Common::CommonPasses
     iree::compiler::Codegen::Interfaces::BufferizationInterfaces
     iree::compiler::Codegen::PassHeaders
+    iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 753da93..9983bd5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
 #include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
@@ -54,6 +55,68 @@
       mlir::iree_compiler::IREE::transform_dialect::CommonExtensions>();
 }
 
+// Returns true if every use of `op` is a memref.store, vector.transfer_write
+// or memref.dealloc. memref.subview users are allowed as long as all of their
+// users satisfy the same condition. On success the users are appended to
+// `uses`; on failure `uses` is left unchanged.
+static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
+  std::vector<Operation *> opUses;
+  for (OpOperand &use : op->getUses()) {
+    Operation *useOp = use.getOwner();
+    if (isa<memref::DeallocOp, vector::TransferWriteOp, memref::StoreOp>(
+            useOp) ||
+        (isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
+      opUses.push_back(useOp);
+      continue;
+    }
+    return false;
+  }
+  uses.insert(uses.end(), opUses.begin(), opUses.end());
+  return true;
+}
+
+// Track temporary allocations that are never read from. If this is the case
+// it means both the allocations and associated stores can be removed.
+static void eraseDeadAllocAndStores(Operation *parentOp) {
+  std::vector<Operation *> opToErase;
+  parentOp->walk([&](memref::AllocOp op) {
+    if (allUsesAreStores(op, opToErase)) {
+      opToErase.push_back(op.getOperation());
+    }
+  });
+  for (Operation *op : opToErase) {
+    op->erase();
+  }
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyBufferOptimizationsOp
+//===---------------------------------------------------------------------===//
+DiagnosedSilenceableFailure
+transform_dialect::ApplyBufferOptimizationsOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Apply store to load forwarding and dead store elimination.
+  vector::transferOpflowOpt(target);
+  eraseDeadAllocAndStores(target);
+
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::ApplyBufferOptimizationsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
+void transform_dialect::ApplyBufferOptimizationsOp::build(
+    OpBuilder &builder, OperationState &result, Value target) {
+  result.addOperands(target);
+  result.addTypes({pdl::OperationType::get(target.getContext())});
+}
+
 //===---------------------------------------------------------------------===//
 // ApplyPatternsOp
 //===---------------------------------------------------------------------===//
@@ -311,6 +374,20 @@
 }
 
 //===----------------------------------------------------------------------===//
+// HoistStaticAllocOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform_dialect::HoistStaticAllocOp::applyToOne(
+    func::FuncOp funcOp, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  IRRewriter rewriter(funcOp->getContext());
+  mlir::iree_compiler::hoistStaticallyBoundAllocationsInFunc<memref::AllocOp>(
+      rewriter, funcOp);
+  results.push_back(funcOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
 // ShareForallOperandsOp
 //===----------------------------------------------------------------------===//
 
@@ -482,14 +559,14 @@
   // Step 3. Outline the compute workload region and set up the workload
   // operands, if this has not been done already.
   // Using `transform.iree.tile_to_forall_and_workgroup_count_region` is
-  // the preferred way to set up tiling and workgroup_count region **at the same
-  // time**.
+  // the preferred way to set up tiling and workgroup_count region **at the
+  // same time**.
   //
-  // The block of code below will be retired once there is enough confidence we
-  // can do everything without it. This includes in particular providing custom
-  // fusion heuristics at the flow level: at this time, the only way to fully
-  // control fusion of more advanced cases is to use the transform dialect at
-  // the flow level and explicitly match the ops we want to fuse.
+  // The block of code below will be retired once there is enough confidence
+  // we can do everything without it. This includes in particular providing
+  // custom fusion heuristics at the flow level: at this time, the only way to
+  // fully control fusion of more advanced cases is to use the transform
+  // dialect at the flow level and explicitly match the ops we want to fuse.
   // Once fusion is customizable enough in perpetuity, we can retire this.
   if (exportOp.getWorkgroupCount().empty()) {
     if (llvm::any_of(forallOp.getUpperBound(rewriter), [](Value v) {
@@ -497,9 +574,11 @@
         })) {
       return forallOp->emitError(
           "unsupported dynamic workgroup_count atm --- need to slice out "
-          "workgroup_count computation into ExecutableExport::workgroup_count."
+          "workgroup_count computation into "
+          "ExecutableExport::workgroup_count."
           "\nThis region may require arbitrary computations and cannot "
-          "magically match what the `stream.cmd.dispatch` has already imposed "
+          "magically match what the `stream.cmd.dispatch` has already "
+          "imposed "
           "on us at a distance."
           "\nFor now we must specify the number of values properly when "
           "applying the topLevel tile_to_forall_op");
@@ -628,8 +707,8 @@
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
 
@@ -665,8 +744,8 @@
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
                              staticNumThreads);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   build(builder, result,
@@ -1226,64 +1305,5 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-// Return true if all the uses of op are either Store/transfer_write.
-// There can be SubviewOp users as long as all its users are also
-// StoreOp/transfer_write. If return true it also fills out the uses, if it
-// returns false uses is unchanged.
-static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
-  std::vector<Operation *> opUses;
-  for (OpOperand &use : op->getUses()) {
-    Operation *useOp = use.getOwner();
-    if (isa<memref::DeallocOp, vector::TransferWriteOp, memref::StoreOp>(
-            useOp) ||
-        (isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
-      opUses.push_back(useOp);
-      continue;
-    }
-    return false;
-  }
-  uses.insert(uses.end(), opUses.begin(), opUses.end());
-  return true;
-}
-
-// Track temporary allocations that are never read from. If this is the case
-// it means both the allocations and associated stores can be removed.
-static void eraseDeadAllocAndStores(Operation *parentOp) {
-  std::vector<Operation *> opToErase;
-  parentOp->walk([&](memref::AllocOp op) {
-    if (allUsesAreStores(op, opToErase)) {
-      opToErase.push_back(op.getOperation());
-    }
-  });
-  for (Operation *op : opToErase) {
-    op->erase();
-  }
-}
-
-DiagnosedSilenceableFailure
-transform_dialect::ApplyBufferOptimizationsOp::applyToOne(
-    Operation *target, transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  // Apply store to load forwarding and dead store elimination.
-  vector::transferOpflowOpt(target);
-  eraseDeadAllocAndStores(target);
-
-  results.push_back(target);
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::ApplyBufferOptimizationsOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getTarget(), effects);
-  transform::producesHandle(getResult(), effects);
-  transform::modifiesPayload(effects);
-}
-
-void transform_dialect::ApplyBufferOptimizationsOp::build(
-    OpBuilder &builder, OperationState &result, Value target) {
-  result.addOperands(target);
-  result.addTypes({pdl::OperationType::get(target.getContext())});
-}
-
 #define GET_OP_CLASSES
 #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index ea4491e..2e9de43 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -17,6 +17,42 @@
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyBufferOptimizationsOp :
+  Op<Transform_Dialect, "iree.apply_buffer_optimizations",
+    [TransformEachOpTrait,
+     TransformOpInterface,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    This applies memory optimization on memref. In particular it does store to
+    load forwarding, dead store elimination and dead alloc elimination.
+
+    #### Return modes
+
+    This operation applies a set of memory optimizations on the whole region of
+    the operand.
+
+    If the transformation is successful it returns the handle to the
+    same payload as its operand to allow for simpler composition.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def ApplyPatternsOp : Op<Transform_Dialect, "iree.apply_patterns",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformEachOpTrait,
@@ -125,6 +161,32 @@
   }];
 }
 
+def HoistStaticAllocOp :  Op<Transform_Dialect, "iree.hoist_static_alloc",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Hoist static allocations";
+  let description = [{
+    Find static allocations and hoist them to the top level.
+
+    #### Return modes
+    This operation applies static alloc hoisting to the whole region of the
+    operand. It always returns success.
+  }];
+
+  let arguments = (ins Transform_ConcreteOpType<"func.func">:$target);
+  let results = (outs Transform_ConcreteOpType<"func.func">:$result);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::func::FuncOp funcOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def IREEBufferizeOp : Op<Transform_Dialect, "iree.bufferize",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
@@ -453,40 +515,4 @@
   }];
 }
 
-def ApplyBufferOptimizationsOp :
-  Op<Transform_Dialect, "iree.apply_buffer_optimizations",
-    [TransformEachOpTrait,
-     TransformOpInterface,
-     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let description = [{
-    This applies memory optimization on memref. In particular it does store to
-    load forwarding, dead store elimination and dead alloc elimination.
-
-    #### Return modes
-
-    This operation applies a set of memory optimization on the whole region of
-    the operand.
-
-    If the transformation is successful it returns the handle to the
-    same payload as its operand to allow for simpler composition.
-  }];
-
-  let arguments = (ins PDL_Operation:$target);
-  let results = (outs PDL_Operation:$result);
-
-  let assemblyFormat = "$target attr-dict";
-  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-
-  let skipDefaultBuilders = 1;
-  let builders = [
-    OpBuilder<(ins "Value":$target)>
-  ];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
 #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 901f93f..d2a44c8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -19,6 +19,7 @@
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/TargetSelect.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index d96ee7b..91bb0d7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -98,8 +98,9 @@
                                         unsigned alignment) {
   auto funcOp = builder.getInsertionPoint()->getParentOfType<func::FuncOp>();
   if (funcOp) {
-    std::optional<Value> hoistedAllocation = hoistStaticallyBoundAllocations(
-        funcOp, builder, loc, memRefType, dynamicSizes, alignment);
+    std::optional<Value> hoistedAllocation =
+        hoistOneStaticallyBoundAllocation<memref::AllocaOp>(
+            funcOp, builder, loc, memRefType, dynamicSizes, alignment);
     if (hoistedAllocation) {
       return hoistedAllocation.value();
     }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
index a65da9a..adbf5e7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
index 616ad87..a60227e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -39,6 +39,7 @@
             "tensor_pad.mlir",
             "tensorcore_vectorization.mlir",
             "tile_on_tensor.mlir",
+            "transform_dialect_hoist_allocs.mlir",
             "transform_dialect_vector_distribution.mlir",
             "transform_dialect_bufferize.mlir",
             "transform_dialect_promote_operands.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index b63550c..78b834c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -36,6 +36,7 @@
     "tensorcore_vectorization.mlir"
     "tile_on_tensor.mlir"
     "transform_dialect_bufferize.mlir"
+    "transform_dialect_hoist_allocs.mlir"
     "transform_dialect_promote_operands.mlir"
     "transform_dialect_vector_distribution.mlir"
     "transform_distribute_forall.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_hoist_allocs.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_hoist_allocs.mlir
new file mode 100644
index 0000000..cabc06b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_hoist_allocs.mlir
@@ -0,0 +1,88 @@
+// RUN: iree-opt --split-input-file -iree-transform-dialect-interpreter -transform-dialect-drop-schedule %s | FileCheck %s
+
+func.func @non_entry_bb_allocs() {
+  cf.br ^bb1
+ ^bb1() :
+  %0 = memref.alloc() : memref<16xi32>
+  memref.dealloc %0 : memref<16xi32>
+  return
+}
+// CHECK-LABEL: func @non_entry_bb_allocs()
+//  CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
+//  CHECK-NEXT:   memref.dealloc %[[ALLOC]] : memref<16xi32>
+//  CHECK-NEXT:   cf.br ^bb1
+//  CHECK-NEXT:   ^bb1:
+//  CHECK-NEXT:   return
+
+transform.sequence failures(propagate) {
+^bb1(%module: !pdl.operation):
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!pdl.operation) -> !transform.op<"func.func">
+    %func_2 = transform.iree.hoist_static_alloc %func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0, 16)>
+func.func @nested_op_alloc_subview_use_static(%arg0 : index, %o0 : index, %o1 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c42 = arith.constant 42 : i32
+  scf.for %iv = %c0 to %arg0 step %c1 {
+    %0 = affine.min #map(%iv)
+    %1 = memref.alloc() : memref<16x16xi32>
+    %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref<16x16xi32> to memref<?x?xi32, strided<[?, 1], offset: ?>>
+    memref.dealloc %1 : memref<16x16xi32>
+    scf.yield
+  }
+  return
+}
+// CHECK-LABEL: func @nested_op_alloc_subview_use_static(
+//  CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32>
+//       CHECK:   scf.for
+//       CHECK:     %[[SIZE:.+]] = affine.min
+//       CHECK:     memref.subview %[[ALLOC]]
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   memref.dealloc %[[ALLOC]] : memref<16x16xi32>
+
+transform.sequence failures(propagate) {
+^bb1(%module: !pdl.operation):
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!pdl.operation) -> !transform.op<"func.func">
+    %func_2 = transform.iree.hoist_static_alloc %func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+}
+
+// -----
+
+#map = affine_map<(d0) -> (d0, 16)>
+func.func @nested_op_alloc_subview_use_dynamic(%arg0 : index, %o0 : index, %o1 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c42 = arith.constant 42 : i32
+  scf.for %iv = %c0 to %arg0 step %c1 {
+    %0 = affine.min #map(%iv)
+    %1 = memref.alloc(%0, %0) : memref<?x?xi32>
+    %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref<?x?xi32> to memref<?x?xi32, strided<[?, 1], offset: ?>>
+    memref.dealloc %1 : memref<?x?xi32>
+    scf.yield
+  }
+  return
+}
+// CHECK-LABEL: func @nested_op_alloc_subview_use_dynamic(
+//  CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32>
+//       CHECK:   scf.for
+//       CHECK:     %[[SIZE:.+]] = affine.min
+//       CHECK:     %[[SUBVIEW1:.+]] = memref.subview %[[ALLOC]][0, 0] [%[[SIZE]], %[[SIZE]]] [1, 1]
+//       CHECK:     memref.subview %[[SUBVIEW1]]
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   memref.dealloc %[[ALLOC]] : memref<16x16xi32>
+
+transform.sequence failures(propagate) {
+^bb1(%module: !pdl.operation):
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!pdl.operation) -> !transform.op<"func.func">
+    %func_2 = transform.iree.hoist_static_alloc %func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+}
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index c0e0fce..32ad122 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -12,8 +12,12 @@
 
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 
+#include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#define DEBUG_TYPE "iree-codegen-transforms"
 
 namespace mlir {
 namespace iree_compiler {
@@ -84,32 +88,36 @@
       loadOp.getMixedSizes(), loadOp.getMixedStrides(), loadOp.getSourceDims());
 }
 
-std::optional<Value> hoistStaticallyBoundAllocations(
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
     func::FuncOp funcOp, OpBuilder &builder, Location loc,
-    MemRefType allocaType, ValueRange dynamicSizes,
+    MemRefType allocLikeType, ValueRange dynamicSizes,
     std::optional<uint64_t> alignment) {
   IntegerAttr alignmentAttr =
       alignment ? builder.getI64IntegerAttr(alignment.value()) : nullptr;
-  // For static case just create a new allocation in the entry block of the
-  // same size. No need to insert a subview.
+  // For static case just create a new allocation in the entry block of the same
+  // size. No need to insert a subview.
   if (dynamicSizes.empty()) {
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToStart(&funcOp.getBody().front());
     Value allocation =
-        builder.create<memref::AllocaOp>(loc, allocaType, alignmentAttr);
+        builder.create<AllocLikeOpType>(loc, allocLikeType, alignmentAttr);
+    if (std::is_same<AllocLikeOpType, memref::AllocOp>::value) {
+      builder.setInsertionPoint(funcOp.getBody().front().getTerminator());
+      builder.create<memref::DeallocOp>(loc, allocation);
+    }
     return allocation;
   }
 
-  /// For the dynamic but bounded case, insert an allocation
-  /// of the shape of the bounds, and a subview of the
-  /// required size to be used as a replacement.
+  /// For the dynamic but bounded case, insert an allocation of the shape of the
+  /// bounds, and a subview of the required size to be used as a replacement.
   SmallVector<int64_t> staticShape;
   SmallVector<OpFoldResult> subviewSizes;
-  staticShape.reserve(allocaType.getRank());
-  subviewSizes.reserve(allocaType.getRank());
+  staticShape.reserve(allocLikeType.getRank());
+  subviewSizes.reserve(allocLikeType.getRank());
 
   int index = 0;
-  for (auto dimSize : allocaType.getShape()) {
+  for (auto dimSize : allocLikeType.getShape()) {
     if (!ShapedType::isDynamic(dimSize)) {
       staticShape.push_back(dimSize);
       subviewSizes.push_back(builder.getIndexAttr(dimSize));
@@ -123,9 +131,9 @@
     staticShape.push_back(ub.value());
     subviewSizes.push_back(dynamicSize);
   }
-  SmallVector<OpFoldResult> offsets(allocaType.getRank(),
+  SmallVector<OpFoldResult> offsets(allocLikeType.getRank(),
                                     builder.getIndexAttr(0));
-  SmallVector<OpFoldResult> strides(allocaType.getRank(),
+  SmallVector<OpFoldResult> strides(allocLikeType.getRank(),
                                     builder.getIndexAttr(1));
 
   Value allocation;
@@ -133,22 +141,96 @@
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToStart(&funcOp.getBody().front());
     auto allocationType =
-        MemRefType::get(staticShape, allocaType.getElementType());
+        MemRefType::get(staticShape, allocLikeType.getElementType());
     allocation =
-        builder.create<memref::AllocaOp>(loc, allocationType, alignmentAttr);
+        builder.create<AllocLikeOpType>(loc, allocationType, alignmentAttr);
   }
 
   Value subviewOp = builder.create<memref::SubViewOp>(loc, allocation, offsets,
                                                       subviewSizes, strides);
+
+  if (std::is_same<AllocLikeOpType, memref::AllocOp>::value) {
+    builder.setInsertionPoint(funcOp.getBody().front().getTerminator());
+    builder.create<memref::DeallocOp>(loc, allocation);
+  }
   return subviewOp;
 }
-std::optional<Value> hoistStaticallyBoundAllocations(
-    func::FuncOp funcOp, OpBuilder &builder, memref::AllocaOp allocaOp) {
+
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
+    func::FuncOp funcOp, OpBuilder &builder, AllocLikeOpType allocLikeOp) {
   OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPoint(allocaOp);
-  return hoistStaticallyBoundAllocations(
-      funcOp, builder, allocaOp.getLoc(), allocaOp.getType(),
-      allocaOp.getDynamicSizes(), allocaOp.getAlignment());
+  builder.setInsertionPoint(allocLikeOp);  // New subview (if any) lands here.
+  return hoistOneStaticallyBoundAllocation<AllocLikeOpType>(
+      funcOp, builder, allocLikeOp.getLoc(), allocLikeOp.getType(),
+      allocLikeOp.getDynamicSizes(), allocLikeOp.getAlignment());
+}
+
+/// Returns true if `use` is owned by a `linalg::LinalgOp`, `memref.dealloc`,
+/// `memref.store` or `memref.subview`, whose operand may become a subview.
+/// Other users (e.g. `scf.yield`, `func.return`) need compatible SSA types.
+static bool isUseReplaceableWithSubview(OpOperand &use) {
+  Operation *user = use.getOwner();
+  return isa<linalg::LinalgOp, memref::DeallocOp, memref::StoreOp,
+             memref::SubViewOp>(user);
+}
+
+/// Explicit instantiations of `hoistStaticallyBoundAllocationsInFunc` for
+/// `memref::AllocOp` and `memref::AllocaOp`. They also instantiate the
+/// versions of `hoistOneStaticallyBoundAllocation` that the function calls.
+template void hoistStaticallyBoundAllocationsInFunc<memref::AllocOp>(
+    RewriterBase &rewriter, func::FuncOp funcOp);
+template void hoistStaticallyBoundAllocationsInFunc<memref::AllocaOp>(
+    RewriterBase &rewriter, func::FuncOp funcOp);
+
+template <typename AllocLikeOpType>
+void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter,
+                                           func::FuncOp funcOp) {
+  SmallVector<AllocLikeOpType> allocLikeOps;
+
+  // Collect the ops to hoist; ops already in the entry block are skipped.
+  funcOp.walk([&](AllocLikeOpType allocLikeOp) {
+    if (allocLikeOp->getBlock() == &funcOp.getBody().front()) return;
+    if (allocLikeOp.getDynamicSizes().empty()) {  // No dynamic sizes.
+      allocLikeOps.push_back(allocLikeOp);
+      return;
+    }
+    if (llvm::all_of(allocLikeOp->getUses(), [](OpOperand &use) {
+          return isUseReplaceableWithSubview(use);
+        })) {  // Dynamic sizes: every use must accept a subview.
+      allocLikeOps.push_back(allocLikeOp);
+      return;
+    }
+  });
+
+  // Hoist each collected op and redirect its uses to the replacement value.
+  for (auto allocLikeOp : allocLikeOps) {
+    // Record the op's memref.dealloc users; they are erased after hoisting.
+    SmallVector<memref::DeallocOp> deallocOps;
+    for (Operation *user : allocLikeOp->getUsers()) {
+      auto dealloc = dyn_cast<memref::DeallocOp>(user);
+      if (dealloc) deallocOps.push_back(dealloc);
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Alloca Op : ";
+      allocLikeOp->dump();
+      int numUses = std::distance(allocLikeOp.getResult().use_begin(),
+                                  allocLikeOp.getResult().use_end());
+      llvm::dbgs() << " num Uses : " << numUses;
+    });
+    std::optional<Value> replacement =
+        hoistOneStaticallyBoundAllocation(funcOp, rewriter, allocLikeOp);
+    if (!replacement) continue;  // No hoisted value; leave the op as is.
+    LLVM_DEBUG({
+      llvm::dbgs() << "Replacement : ";
+      replacement->dump();
+    });
+    Value replacementVal = replacement.value();
+    rewriter.replaceOp(allocLikeOp, replacementVal);
+    // Drop the original deallocs, whose operand is now the replacement value.
+    for (memref::DeallocOp deallocOp : deallocOps) rewriter.eraseOp(deallocOp);
+  }
 }
 
 }  // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
index ce6cbb8..e073ce0 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
@@ -14,7 +14,6 @@
 
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -45,7 +44,8 @@
 /// allocations creates an allocation, and inserts a subview to match the
 /// dynamic shape of the allocation. Returns std::nullopt if the method
 /// couldnt creat an allocation in the entry block.
-std::optional<Value> hoistStaticallyBoundAllocations(
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
     func::FuncOp funcOp, OpBuilder &builder, Location loc,
     MemRefType allocaType, ValueRange dynamicSizes,
     std::optional<uint64_t> alignment);
@@ -56,9 +56,15 @@
 /// allocations creates an allocation, and inserts a subview to match the
 /// dynamic shape of the allocation. The method returns a value, but
 /// does not replace the uses of the `allocaOp`.
-std::optional<Value> hoistStaticallyBoundAllocations(func::FuncOp funcOp,
-                                                     OpBuilder &builder,
-                                                     memref::AllocaOp allocaOp);
+template <typename AllocLikeOpType>
+std::optional<Value> hoistOneStaticallyBoundAllocation(
+    func::FuncOp funcOp, OpBuilder &builder, AllocLikeOpType allocaOp);
+
+/// Walk funcOp and try to hoist every `AllocLikeOpType` op (memref.alloc or
+/// memref.alloca) to the entry block if its size is statically bounded.
+template <typename AllocLikeOpType>
+void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter,
+                                           func::FuncOp funcOp);
 
 /// Insert patterns to perform folding of AffineMinOp by matching the
 /// pattern generated by tile and distribute. Try to fold a affine.min op by