[Codegen] Push up the extract slice op (#19680)

Push extract_slice ops up to their earliest legal position in the block,
i.e., to the top of the block when all their operands are block
arguments. This lets the bufferization framework know about the presence
of a subset buffer that can be reused.
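
For illustration, a minimal before/after sketch distilled from the lit
test added below (value names shortened, indices and attributes
abbreviated with "..."):

  // Before: the extract_slice is pinned below an unrelated write.
  %3 = vector.transfer_write %2, %0[%c0, %c0] ... : vector<64x64xf16>, tensor<64x64xf16>
  %s = tensor.extract_slice %arg2[...] : tensor<2x4096x10x64xf16> to tensor<1x64x1x64xf16>
  %r = tensor.insert_slice %3 into %s[...] : tensor<64x64xf16> into tensor<1x64x1x64xf16>

  // After: the extract_slice moves up as soon as its operands are
  // available, exposing the reusable subset destination to bufferization.
  %s = tensor.extract_slice %arg2[...] : tensor<2x4096x10x64xf16> to tensor<1x64x1x64xf16>
  %3 = vector.transfer_write %2, %0[%c0, %c0] ... : vector<64x64xf16>, tensor<64x64xf16>
  %r = tensor.insert_slice %3 into %s[...] : tensor<64x64xf16> into tensor<1x64x1x64xf16>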
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
index 9f47175..2d72364 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -253,10 +254,42 @@
 };
 } // namespace
 
+// Returns the operation after which `op` can legally be inserted: the last
+// operation in `block` that defines one of `op`'s operands, or the block's
+// first operation if all operands are defined outside the block.
+static Operation *getEarliestInsertionPointInsideBlock(Block *block,
+                                                       Operation *op) {
+  Operation *currInsertionPoint = &block->front();
+  DominanceInfo dominanceInfo(currInsertionPoint);
+
+  for (Value operand : op->getOperands()) {
+    // Block arguments are always available at the top of the block.
+    if (isa<BlockArgument>(operand))
+      continue;
+    // If this operand's defining op does not dominate the current insertion
+    // point, the insertion point has to move down to that defining op.
+    Operation *defOp = operand.getDefiningOp();
+    if (!dominanceInfo.dominates(defOp, currInsertionPoint)) {
+      currInsertionPoint = defOp;
+    }
+  }
+  return currInsertionPoint;
+}
+
 void OptimizeTensorInsertExtractSlicesPass::runOnOperation() {
   auto funcOp = getOperation();
   IRRewriter rewriter(funcOp->getContext());
 
+  // TODO: This is a temporary hack to help bufferization get rid of
+  // empty buffers.
+  // Tracked in: https://github.com/llvm/llvm-project/issues/122869
+  funcOp.walk([&](tensor::ExtractSliceOp extractSliceOp) {
+    Block *currBlock = extractSliceOp->getBlock();
+    Operation *earliestInsertionPoint =
+        getEarliestInsertionPointInsideBlock(currBlock, extractSliceOp);
+    extractSliceOp->moveAfter(earliestInsertionPoint);
+  });
+
   funcOp.walk([&](scf::ForOp forOp) { moveLoopInvariantCode(forOp); });
   LDBG("after hoisting loop invariant code\n" << funcOp);
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
index cbb76b0..0aacf4e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
@@ -321,3 +321,21 @@
 // CHECK-LABEL: @fold_identity_extract_slice
 //       CHECK:   %[[ARG0:.+]]: tensor<?xf32>
 //       CHECK:   return %[[ARG0]]
+
+// -----
+
+func.func @push_up_extract_slice(%arg0: index, %arg1: vector<64x64xf32>, %arg2: tensor<2x4096x10x64xf16>) -> tensor<1x64x1x64xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<64x64xf16>
+  %c2 = arith.constant 2 : index
+  %1 = arith.addi %arg0, %c2 : index
+  %2 = arith.truncf %arg1 : vector<64x64xf32> to vector<64x64xf16>
+  %3 = vector.transfer_write %2, %0[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf16>, tensor<64x64xf16>
+  %extracted_slice = tensor.extract_slice %arg2[%arg0, %c2, %1, %arg0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<2x4096x10x64xf16> to tensor<1x64x1x64xf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [1, 64, 1, 64] [1, 1, 1, 1] : tensor<64x64xf16> into tensor<1x64x1x64xf16>
+  return %inserted_slice : tensor<1x64x1x64xf16>
+}
+
+// CHECK-LABEL: @push_up_extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   vector.transfer_write
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
index 3d3504d..3466bf8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
@@ -142,14 +142,6 @@
 //          CHECK:   scf.forall ({{.*}}) in (17, 81) {
 //          CHECK:     %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C721]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
 //          CHECK:       gpu.barrier
-//      CHECK-DAG:       %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<1xf16>
-//      CHECK-DAG:       vector.transfer_write %[[LHS_RD]]
-// Note that to simplify the test we are not showing the mapping of the RHS_RD
-// to its buffer as it goes through an scf.if/else control structure
-// involving allocas.
-//      CHECK-DAG:       %[[RHS_RD:.+]] = vector.transfer_read {{.*}} vector<1xf16>
-//      CHECK-DAG:       vector.transfer_write %[[RHS_RD]]
-//          CHECK:       gpu.barrier
 //      CHECK-DAG:       %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<4xf16>
 //      CHECK-DAG:       %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<4x1x1xf16>
 // CHECK-COUNT-1:       amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 1a521e6..75dc8b2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -1151,11 +1151,6 @@
 //       CHECK:   scf.forall ({{.*}}) in (12, 37, 10) {
 //       CHECK:     %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c145 step %c1 {{.*}} -> (vector<1x1x1x4x1xf32>)
 //       CHECK:       gpu.barrier
-//   CHECK-DAG:       %[[LHS_RD:.+]] = vector.transfer_read {{.*}} vector<4xf32>
-//   CHECK-DAG:       vector.transfer_write %[[LHS_RD]]
-//   CHECK-DAG:       %[[RHS_RD:.+]] = vector.transfer_read {{.*}} vector<1xf32>
-//   CHECK-DAG:       vector.transfer_write %[[RHS_RD]]
-//       CHECK:       gpu.barrier
 //   CHECK-DAG:       vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<1xf32>
 //   CHECK-DAG:       vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<1xf32>
 // CHECK-COUNT-1:     amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
index 4396888..0a27cb9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
@@ -552,16 +552,16 @@
 // CHECK-DAG:     %[[RHS_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<64x1281x1281xf16, #hal.descriptor_type<storage_buffer>>
 // CHECK-DAG:     %[[OUT_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : memref<64x968x1281xf16, #hal.descriptor_type<storage_buffer>>
 // CHECK-DAG:     %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
-// CHECK:         %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
+// CHECK-DAG:     %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK-DAG:     %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
-// CHECK:         %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
+// CHECK-DAG:     %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK:         vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
 // CHECK:         vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
 // CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
 // CHECK-DAG:       %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
-// CHECK:           %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
+// CHECK-DAG:       %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
 // CHECK-DAG:       %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
-// CHECK:           %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
+// CHECK-DAG:       %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK:           gpu.barrier
 // CHECK-DAG:       %{{.+}} = vector.transfer_read %[[LHS_SHARED]]
 // CHECK-DAG:       %{{.+}} = vector.transfer_read %[[RHS_SHARED]]