Split iree.bufferize_op to enable additional canonicalization (#11570)

This avoids copies when split-k does not divide the input size evenly
(`reduction_v2_uneven.mlir`).

Also add new patterns to ApplyPatternsOp and run them before
bufferization in `reduction_v2_uneven.mlir`.

This change includes #11495, which was dropped at some point.
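
For illustration, a minimal sketch of what a codegen spec looks like after
this split, mirroring the updated test specs below (the exact pattern sets
and the `target_gpu` unit attribute depend on the backend):

```mlir
transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%variant_op: !pdl.operation):
  // Extra canonicalization is now possible before bufferization.
  %func = transform.structured.match ops{["func.func"]} in %variant_op
  transform.iree.apply_patterns %func { fold_reassociative_reshapes }
  // iree.bufferize is split: empty-tensor elimination is its own op, and
  // more patterns may run in between.
  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
  transform.iree.apply_patterns %func_2 { erase_unnecessary_tensor_operands }
  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
  transform.iree.erase_hal_descriptor_type_from_memref %memref_func
}
```

Both `transform.iree.eliminate_empty_tensors` and `transform.iree.bufferize`
now consume their target handle and produce a new one, so downstream matches
must use the latest variant handle.
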
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
index 0984be7..2d6abb8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
@@ -52,9 +52,11 @@
 
 // TODO: significantly better namespacing.
 using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
+using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::ConfigExtractPart;
 using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
 using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
+using iree_compiler::IREE::transform_dialect::IREEEliminateEmptyTensorsOp;
 using iree_compiler::IREE::transform_dialect::
     IREEEraseHALDescriptorTypeFromMemRefOp;
 using iree_compiler::IREE::transform_dialect::
@@ -213,13 +215,20 @@
 // TODO: configure patterns.
 Value mlir::iree_compiler::buildVectorize(ImplicitLocOpBuilder &b,
                                           Value funcH) {
-  funcH = b.create<ApplyPatternsOp>(funcH, /*rankReducing=*/true);
+  ApplyPatternsOpPatterns patterns;
+  patterns.rankReducing = true;
+  funcH = b.create<ApplyPatternsOp>(funcH, patterns);
   return b.create<VectorizeOp>(funcH);
 }
 
 /// Bufferize and drop HAL descriptor from memref ops.
 Value mlir::iree_compiler::buildBufferize(ImplicitLocOpBuilder &b,
                                           Value variantH, bool targetGpu) {
+  Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+  ApplyPatternsOpPatterns patterns;
+  patterns.foldReassociativeReshapes = true;
+  funcH = b.create<ApplyPatternsOp>(funcH, patterns);
+  variantH = b.create<IREEEliminateEmptyTensorsOp>(variantH);
   variantH = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
   Value memrefFunc =
       b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
@@ -244,7 +253,9 @@
 Value mlir::iree_compiler::buildDistributeVectors(ImplicitLocOpBuilder &b,
                                                   Value variantH, Value funcH,
                                                   int64_t warpSize) {
-  funcH = b.create<ApplyPatternsOp>(funcH, /*rankReducing=*/true);
+  ApplyPatternsOpPatterns patterns;
+  patterns.rankReducing = true;
+  funcH = b.create<ApplyPatternsOp>(funcH, patterns);
   Value ifH = b.create<MatchOp>(funcH, scf::IfOp::getOperationName());
   // Locally suppress failures for this op only because it doesn't cover the
   // `threadIdx.x == 0 && threadIdx.y == 0` case at the moment.
@@ -300,9 +311,9 @@
   }
 
   auto funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  auto applyPatterns = b.create<ApplyPatternsOp>(funcH, /*rankReducing=*/false);
-  applyPatterns->setAttr(applyPatterns.getBubbleCollapseExpandAttrName(),
-                         b.getUnitAttr());
+  ApplyPatternsOpPatterns patterns;
+  patterns.bubbleCollapseExpand = true;
+  b.create<ApplyPatternsOp>(funcH, patterns);
   std::tie(result.originalFillH, result.splitFillH) =
       matchAndUnpack<2>(b, variantH, linalg::FillOp::getOperationName());
   if (hasTrailingEltwise) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
index 41376fd..e72cd79 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
@@ -91,6 +91,7 @@
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TensorTransforms",
         "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",
     ],
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index 04ee20b..5bd3d9f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -52,6 +52,7 @@
     MLIRPass
     MLIRSCFDialect
     MLIRTensorDialect
+    MLIRTensorTransforms
     MLIRTransformDialect
     MLIRTransforms
     iree::compiler::Codegen::Common::CommonPasses
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index a0fd531..2680e1c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/PassManager.h"
@@ -55,16 +56,30 @@
 //===---------------------------------------------------------------------===//
 // ApplyPatternsOp
 //===---------------------------------------------------------------------===//
-void transform_dialect::ApplyPatternsOp::build(OpBuilder &builder,
-                                               OperationState &result,
-                                               Value target,
-                                               bool rankReducing) {
+void transform_dialect::ApplyPatternsOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    const ApplyPatternsOpPatterns &patterns) {
   MLIRContext *ctx = builder.getContext();
   result.addOperands(target);
-  if (rankReducing) {
-    result.addAttribute(ApplyPatternsOp::getRankReducingAttrName(result.name),
-                        builder.getUnitAttr());
-  }
+  auto unitAttr = builder.getUnitAttr();
+#define ADD_PATTERN(NAME, ATTR) \
+  if (patterns.NAME)            \
+    result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr);
+  ADD_PATTERN(additionalIreePatterns, getAdditionalIreePatternsAttrName)
+  ADD_PATTERN(bubbleCollapseExpand, getBubbleCollapseExpandAttrName)
+  ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
+  ADD_PATTERN(eraseUnnecessaryTensorOperands,
+              getEraseUnnecessaryTensorOperandsAttrName)
+  ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
+  ADD_PATTERN(promoteForeachThreadCaptureToShared,
+              getPromoteForeachThreadCaptureToSharedAttrName)
+  ADD_PATTERN(rankReducing, getRankReducingAttrName)
+  ADD_PATTERN(expandMemrefStridedMetadata,
+              getExpandMemrefStridedMetadataAttrName)
+  ADD_PATTERN(swapPaddingElideConditional,
+              getSwapPaddingElideConditionalAttrName)
+  ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
+#undef ADD_PATTERN
   result.addTypes({pdl::OperationType::get(ctx)});
 }
 
@@ -152,6 +167,15 @@
   patterns.add<PromoteCaptureToSharedOut>(patterns.getContext());
 }
 
+static void addReassociativeReshapePatterns(RewritePatternSet &patterns) {
+  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
+}
+
+static void addEraseUnnecessaryTensorOperandsPatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
+}
+
 static void addRankReducingPatterns(RewritePatternSet &patterns) {
   populateReshapeToInterfaceTensorPatterns(patterns);
   vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
@@ -198,6 +222,9 @@
   MLIRContext *ctx = target->getContext();
   RewritePatternSet patterns(ctx);
   if (getCanonicalization()) addAllRegisteredCanonicalizationPatterns(patterns);
+  if (getEraseUnnecessaryTensorOperands())
+    addEraseUnnecessaryTensorOperandsPatterns(patterns);
+  if (getFoldReassociativeReshapes()) addReassociativeReshapePatterns(patterns);
   if (getPromoteForeachThreadCaptureToShared())
     addForeachThreadCapturePromotionPatterns(patterns);
   if (getRankReducing()) addRankReducingPatterns(patterns);
@@ -875,40 +902,32 @@
     memCpyFn = gpuComprehensiveBufferizeCopyFn;
   }
 
-  //   1. Eliminate tensor.empty, without the pass baggage.
-  WalkResult res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
-    if (failed(eliminateEmptyTensors(moduleOp.getOperation(),
-                                     getBufferizationOptions())))
-      return WalkResult::interrupt();
-    return WalkResult::advance();
-  });
-  if (res.wasInterrupted())
-    return DiagnosedSilenceableFailure::definiteFailure();
-
-  //   2. Rewrite tensor.empty to tensor.alloc, without the pass baggage.
-  RewritePatternSet patterns(getContext());
-  patterns.add<EmptyTensorLoweringPattern>(patterns.getContext());
-  TrackingListener listener(state);
-  GreedyRewriteConfig config;
-  LogicalResult result = applyPatternsAndFoldGreedily(
-      state.getTopLevel(), std::move(patterns), config, &listener);
-  LogicalResult listenerResult = listener.checkErrorState();
-  if (failed(result)) {
-    return mlir::emitDefiniteFailure(state.getTopLevel(),
-                                     "greedy pattern application failed");
+  //   1. Rewrite tensor.empty to tensor.alloc, without the pass baggage.
+  {
+    RewritePatternSet patterns(getContext());
+    patterns.add<EmptyTensorLoweringPattern>(patterns.getContext());
+    TrackingListener listener(state);
+    GreedyRewriteConfig config;
+    LogicalResult result = applyPatternsAndFoldGreedily(
+        state.getTopLevel(), std::move(patterns), config, &listener);
+    LogicalResult listenerResult = listener.checkErrorState();
+    if (failed(result)) {
+      return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                       "greedy pattern application failed");
+    }
+    if (failed(listenerResult))
+      return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                       "listener tracking failed");
   }
-  if (failed(listenerResult))
-    return mlir::emitDefiniteFailure(state.getTopLevel(),
-                                     "listener tracking failed");
 
-  //   3. Run one-shot-bufferize, without the pass baggage.
+  //   2. Run one-shot-bufferize, without the pass baggage.
   OneShotBufferizationOptions options = getBufferizationOptions();
   options.allocationFn = allocationFn;
   options.deallocationFn = deallocationFn;
   options.memCpyFn = memCpyFn;
   options.testAnalysisOnly = getTestAnalysisOnly();
   options.printConflicts = getPrintConflicts();
-  res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
+  WalkResult res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
     if (failed(runIREEOneShotBufferize(moduleOp, options)))
       return WalkResult::interrupt();
     return WalkResult::advance();
@@ -916,7 +935,7 @@
   if (res.wasInterrupted())
     return DiagnosedSilenceableFailure::definiteFailure();
 
-  //   4. Post-bufferization passes are fine.
+  //   3. Post-bufferization passes are fine.
   PassManager pm(getContext());
   addIREEPostBufferizationPasses(pm);
   res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
@@ -937,6 +956,31 @@
 }
 
 //===---------------------------------------------------------------------===//
+// IREEEliminateEmptyTensorsOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform_dialect::IREEEliminateEmptyTensorsOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  ArrayRef<Operation *> payloads = state.getPayloadOps(getTarget());
+  for (Operation *payload : payloads) {
+    if (failed(eliminateEmptyTensors(payload, getBufferizationOptions()))) {
+      getOperation()->emitError() << "failed to eliminate tensor.empty ops";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+  }
+  results.set(getOperation()->getOpResult(0), payloads);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::IREEEliminateEmptyTensorsOp::build(
+    OpBuilder &builder, OperationState &result, Value target) {
+  result.addOperands(target);
+  MLIRContext *ctx = builder.getContext();
+  result.addTypes(pdl::OperationType::get(ctx));
+}
+
+//===---------------------------------------------------------------------===//
 // ConfigExtractPart
 //===---------------------------------------------------------------------===//
 void transform_dialect::ConfigExtractPart::build(OpBuilder &builder,
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
index 11822ca..b6747bf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -29,6 +29,26 @@
 struct NumThreadsSpec;
 class TransformTypeInterface;
 }  // namespace transform
+
+namespace iree_compiler {
+namespace IREE {
+namespace transform_dialect {
+/// Selected patterns for ApplyPatternsOp.
+struct ApplyPatternsOpPatterns {
+  bool additionalIreePatterns = false;
+  bool bubbleCollapseExpand = false;
+  bool canonicalization = false;
+  bool eraseUnnecessaryTensorOperands = false;
+  bool foldReassociativeReshapes = false;
+  bool promoteForeachThreadCaptureToShared = false;
+  bool rankReducing = false;
+  bool expandMemrefStridedMetadata = false;
+  bool swapPaddingElideConditional = false;
+  bool swappingPatterns = false;
+};
+}  // namespace transform_dialect
+}  // namespace IREE
+}  // namespace iree_compiler
 }  // namespace mlir
 
 #define GET_OP_CLASSES
@@ -43,7 +63,7 @@
 
 namespace IREE {
 namespace transform_dialect {
-// Hook to register common transformations to the transform dialect.
+/// Hook to register common transformations to the transform dialect.
 class CommonExtensions
     : public transform::TransformDialectExtension<CommonExtensions> {
  public:
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 8f7cf12..e42f75b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -38,6 +38,10 @@
       down across Linalg ops.
       - canonicalization: adds all the canonicalization patterns of all
       registered dialects and ops.
+      - erase_unnecessary_tensor_operands: adds patterns that erase unnecessary
+      tensor operands.
+      - fold_reassociative_reshapes: adds patterns that fold insert_slice/
+      extract_slice ops with reassociative reshape ops.
       - promote_foreach_thread_capture_to_shared: adds patterns that rewrite
       uses of values captured by scf.foreach_thread with the matching
       shared_outs bbarg. This checks that the values captured are
@@ -75,6 +79,8 @@
                        UnitAttr:$additional_iree_patterns,
                        UnitAttr:$bubble_collapse_expand,
                        UnitAttr:$canonicalization,
+                       UnitAttr:$erase_unnecessary_tensor_operands,
+                       UnitAttr:$fold_reassociative_reshapes,
                        UnitAttr:$promote_foreach_thread_capture_to_shared,
                        UnitAttr:$rank_reducing,
                        UnitAttr:$expand_memref_strided_metadata,
@@ -87,7 +93,8 @@
 
   let builders = [
     // TODO: Some bitvector to scale better than n-bools.
-    OpBuilder<(ins "Value":$target, "bool":$rankReducing)>
+    OpBuilder<(ins "Value":$target,
+                   "const ApplyPatternsOpPatterns &":$patterns)>
   ];
 
   let extraClassDeclaration = [{
@@ -110,8 +117,8 @@
     using the following attributes:
       - target_gpu: if set, GPU allocations are emitted.
 
-    Return modes:
-    =============
+    #### Return modes
+
     This operation calls the upstream one-shot bufferization pass with extra
     registered patterns for IREE.
 
@@ -121,7 +128,7 @@
     If any of the pass on any of the ModuleOp fails, the transformation
     definitely fails. Otherwise the transformation succeeds.
 
-    No handles are consumed or produced.
+    This transform consumes the target handle and produces a result handle.
   }];
 
   let arguments = (
@@ -140,6 +147,35 @@
   ];
 }
 
+def IREEEliminateEmptyTensorsOp : Op<
+    Transform_Dialect, "iree.eliminate_empty_tensors",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    This is a pre-processing transform for iree.bufferize. It tries to remove
+    tensor.empty ops by replacing them with suitable destination tensors, which
+    can reduce the number of allocations when bufferizing.
+
+    This transform is not part of iree.bufferize because additional
+    canonicalizations are sometimes possible after eliminate_empty_tensors but
+    before iree.bufferize.
+
+    #### Return modes
+
+    This transform consumes the target handle and produces a result handle.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result);
+  let assemblyFormat = "attr-dict $target";
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+}
+
 def IREEEraseHALDescriptorTypeFromMemRefOp : Op<Transform_Dialect,
     "iree.erase_hal_descriptor_type_from_memref",
     [FunctionalStyleTransformOpTrait,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
index fb66c9d..4a10f38 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
@@ -35,7 +35,8 @@
 
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%variant_op: !pdl.operation):
-  %variant_op_2 = transform.iree.bufferize %variant_op
-  %func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize %variant_op_2
+  %func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %func
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
index 12b63ac..53c94cc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
@@ -29,8 +29,9 @@
 
   transform.structured.canonicalized_sequence failures(propagate) {
   ^bb1(%variant_op: !pdl.operation):
-    %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-    %func = transform.structured.match ops{["func.func"]} in %variant_op_2
+    %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+    %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+    %func = transform.structured.match ops{["func.func"]} in %variant_op_3
     transform.iree.erase_hal_descriptor_type_from_memref %func
   }
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir
index d5f20ac..abbb542 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir
@@ -1,6 +1,7 @@
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%variant_op: !pdl.operation):
-  %variant_op_2 = transform.iree.bufferize %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index 323a989..5ad61a5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -8,12 +8,13 @@
   %foreach_thread_2, %tiled_matmul = transform.structured.tile_to_foreach_thread_op %1 num_threads [7, 9]
   ( mapping = [#gpu.thread<x>, #gpu.thread<y>] )
 
-  %variant_op_2 = transform.iree.bufferize %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Get the function to which to apply to.
-  %2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_2
+  %2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_3
   %func = transform.get_closest_isolated_parent %2 : (!pdl.operation) -> !pdl.operation
   transform.iree.map_nested_foreach_thread_to_gpu_threads %func { workgroup_size = [10, 11]}
 }
diff --git a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
index c531af6..3efd0cb 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
@@ -13,12 +13,13 @@
 
   // Step 2. Bufferize and drop HAL decriptor from memref ops.
   // =========================================================
-  %variant_op_2 = transform.iree.bufferize %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 3. Post-bufferization mapping workgroup.
   // =========================================================
-  %func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.foreach_thread_to_workgroup %func
 }
diff --git a/tests/transform_dialect/cuda/BUILD b/tests/transform_dialect/cuda/BUILD
index 22c357e..f818876 100644
--- a/tests/transform_dialect/cuda/BUILD
+++ b/tests/transform_dialect/cuda/BUILD
@@ -29,6 +29,7 @@
         "reduction.mlir",
         "reduction_eltwise.mlir",
         "reduction_v2.mlir",
+        "reduction_v2_uneven.mlir",
         "reduction_v3.mlir",
         "softmax.mlir",
         "softmax_v2.mlir",
diff --git a/tests/transform_dialect/cuda/CMakeLists.txt b/tests/transform_dialect/cuda/CMakeLists.txt
index 897d52f..4bd5475 100644
--- a/tests/transform_dialect/cuda/CMakeLists.txt
+++ b/tests/transform_dialect/cuda/CMakeLists.txt
@@ -21,6 +21,7 @@
     "reduction.mlir"
     "reduction_eltwise.mlir"
     "reduction_v2.mlir"
+    "reduction_v2_uneven.mlir"
     "reduction_v3.mlir"
     "softmax.mlir"
     "softmax_partial.mlir"
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
index fe3042d..78ff584 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
@@ -64,13 +64,14 @@
 
   // Step 5. Bufferize and drop HAL decriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3
   %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
   %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
       { workgroup_size = [32, 2, 1] }
@@ -78,10 +79,10 @@
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
   %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
-  transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
+  transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
   ^bb0(%arg0: !pdl.operation):
     transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   }
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
index c93af38..e4c4358 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
@@ -71,13 +71,14 @@
 
   // Step 5. Bufferize and drop HAL decriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_3
   %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
   %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
       { workgroup_size = [32, 2, 1] }
@@ -85,10 +86,10 @@
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
   %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
-  transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
+  transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
   ^bb0(%arg0: !pdl.operation):
     transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   }
diff --git a/tests/transform_dialect/cuda/reduction.mlir b/tests/transform_dialect/cuda/reduction.mlir
index 7149139..29fc449 100644
--- a/tests/transform_dialect/cuda/reduction.mlir
+++ b/tests/transform_dialect/cuda/reduction.mlir
@@ -51,9 +51,10 @@
   //     CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id  z
 
   //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
+  //         CHECK: %[[ADDED:.*]] = arith.addi %[[TIDZ]], %[[workgroup_id_x]]
 
   // Distributed reduction: everyone loads then 5 xor + addf expected
-  //         CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
+  //         CHECK: vector.transfer_read %{{.*}}[%[[ADDED]], %[[TIDY]], %[[TIDX]]]
   // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
   //         CHECK: %[[RES:.*]] = arith.addf %{{.*}}
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index dc771fd..6649bc9 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -43,26 +43,28 @@
 
   // Step 5. Bufferize and drop HAL decriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes }
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2
-  %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
+  %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_3
+  %func_6 = transform.iree.foreach_thread_to_workgroup %func_5
+  %func_7 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_6
       { workgroup_size = [32, 2, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing }
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
-  transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
+  transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
   ^bb0(%arg0: !pdl.operation):
     transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   }
-  transform.iree.vector.warp_distribute %func_7
+  transform.iree.vector.warp_distribute %func_8
 }
diff --git a/tests/transform_dialect/cuda/reduction_eltwise.mlir b/tests/transform_dialect/cuda/reduction_eltwise.mlir
index 783d033..571f8fd 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise.mlir
@@ -59,9 +59,10 @@
   //     CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id  z
 
   //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
+  //         CHECK: %[[ADDED:.*]] = arith.addi %[[TIDZ]], %[[workgroup_id_x]]
 
   // Distributed reduction: everyone loads then 5 xor + addf expected
-  //         CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
+  //         CHECK: vector.transfer_read %{{.*}}[%[[ADDED]], %[[TIDY]], %[[TIDX]]]
   // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
   //         CHECK: %[[RES:.*]] = arith.addf %{{.*}}
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 8803f0e..5afa0a8 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -47,26 +47,28 @@
 
   // Step 5. Bufferize and drop HAL decriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes }
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2
-  %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
+  %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_3
+  %func_6 = transform.iree.foreach_thread_to_workgroup %func_5
+  %func_7 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_6
       { workgroup_size = [32, 2, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing }
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
-  transform.sequence %variant_op_2 : !pdl.operation failures(suppress) {
+  transform.sequence %variant_op_3 : !pdl.operation failures(suppress) {
   ^bb0(%arg0: !pdl.operation):
     transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   }
-  transform.iree.vector.warp_distribute %func_7
+  transform.iree.vector.warp_distribute %func_8
 }
diff --git a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
index 5d171d0..90c27c8 100644
--- a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
@@ -17,14 +17,14 @@
   %foreach_thread, %block_more_parallel_fill_op_2, %block_more_parallel_op_2, %block_combiner_op_2 = 
     transform.structured.tile_reduction_using_scf %grid_reduction by tile_sizes = [0, 128]
   %_1:2 =
-    transform.structured.tile_to_foreach_thread_op %block_more_parallel_op_2 num_threads [0, 32] 
+    transform.structured.tile_to_foreach_thread_op %block_more_parallel_op_2 num_threads [0, 32]
     ( mapping = [#gpu.thread<x>] )
 
   // Step 3. Second level of tiling parallelizes to threads.
   // ===========================================================================
   // 1st op is [parallel, parallel], map it to threadIdx.x by 4.
   %_2:2 =
-    transform.structured.tile_to_foreach_thread_op %block_more_parallel_fill_op_2 tile_sizes [0, 4] 
+    transform.structured.tile_to_foreach_thread_op %block_more_parallel_fill_op_2 tile_sizes [0, 4]
     ( mapping = [#gpu.thread<x>] )
   // 2nd op is [parallel, reduction] of 1x128, map the 1-dim to threadIdx.y to
   // trigger mapping of the reduction to threadIdx.x via predication via `if (x==0)`.
@@ -40,21 +40,25 @@
 
   // Step 5. Bufferize and drop HAL decriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes }
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_6 = transform.iree.apply_patterns %func_5 { erase_unnecessary_tensor_operands }
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 6. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2
-  %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
+  %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3
+  %func_8 = transform.iree.foreach_thread_to_workgroup %func_7
+  %func_9 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_8
       { workgroup_size = [32, 1, 1] }
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing }
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
-  transform.iree.vector.warp_distribute %func_7
+  transform.iree.vector.warp_distribute %func_10
 }
diff --git a/tests/transform_dialect/cuda/reduction_v2_uneven.mlir b/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
new file mode 100644
index 0000000..473ec18
--- /dev/null
+++ b/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
@@ -0,0 +1,67 @@
+!in_tensor_t = tensor<33x34567xf32>
+!out_tensor_t = tensor<33xf32>
+
+func.func @reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
+  %cst = arith.constant -0.000000e+00 : f32
+
+  %0 = tensor.empty() : !out_tensor_t
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) ->   !out_tensor_t
+  %2 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %3 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %3 : f32
+      } -> !out_tensor_t
+  return %2 : !out_tensor_t
+}
+
+// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+// RUN:     --iree-abi-transformation-pipeline \
+// RUN:     --iree-flow-transformation-pipeline  \
+// RUN:     --iree-stream-transformation-pipeline \
+// RUN:     --iree-hal-configuration-pipeline | \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
+// RUN: FileCheck %s --check-prefix=CHECK
+
+// RUN: iree-compile %s --iree-hal-target-backends=cuda \
+// RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
+// RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="33x34567xf32=1" |\
+// RUN: FileCheck %s --check-prefix=EXEC
+
+  //     CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  //     CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  //     CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+  //     CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
+  //     CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x128xf32, 3>
+  
+  //         CHECK: %[[TIDX:.]] = gpu.thread_id  x
+  //         CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
+  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<4xf32, strided<[1], offset: ?>, 3>
+  //         CHECK: gpu.barrier
+  // Local per-thread scf.for-based reduction.
+  //         CHECK: scf.for
+  //     CHECK-NOT:   memref.alloc
+  //         CHECK:   linalg.generic
+  // TODO: remove unnecessary barrier within the loop
+  //         CHECK:   gpu.barrier
+
+  //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
+  // Distributed reduction: everyone loads then 5 xor + addf expected
+  //         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
+  // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
+
+  //         CHECK: %[[RES:.*]] = arith.addf %{{.*}}
+
+  //         CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
+  //         CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
+  //         CHECK: scf.if %[[CONDXIS0]]
+  //         CHECK:   vector.transfer_write %[[RES_VEC]]
+  //         CHECK: gpu.barrier
+
+// only checking the first 6 of 33
+//      EXEC: result[0]: hal.buffer_view
+// EXEC-NEXT: 33xf32=34567 34567 34567 34567 34567 34567
diff --git a/tests/transform_dialect/cuda/reduction_v3.mlir b/tests/transform_dialect/cuda/reduction_v3.mlir
index 1d3ce90..7070b79 100644
--- a/tests/transform_dialect/cuda/reduction_v3.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3.mlir
@@ -39,7 +39,7 @@
   //     CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x1024xf32, 3>
   
   //         CHECK: %[[TIDX:.]] = gpu.thread_id  x
-  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[TIDX]]]{{.*}}to memref<f32, strided<[], offset: ?>, 3>
+  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[TIDX]]]{{.*}}to memref<1x1xf32, strided<[1024, 1], offset: ?>, 3>
   // Local per-thread scf.for-based reduction.
   //         CHECK: scf.for
   //         CHECK:   vector.transfer_read %{{.*}} : memref<f32, strided<[], offset: ?>>, vector<f32>
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
index 3f42b23..5b02ef6 100644
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -39,21 +39,25 @@
 
   // Step 4. Bufferize and drop HAL descriptor from memref ops.
   // ===========================================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes }
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_6 = transform.iree.apply_patterns %func_5 { erase_unnecessary_tensor_operands }
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 5. Post-bufferization mapping to blocks and threads.
   // ===========================================================================
-  %func_4 = transform.structured.match ops{["func.func"]} in %variant_op_2
-  %func_5 = transform.iree.foreach_thread_to_workgroup %func_4
-  %func_6 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_5
+  %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3
+  %func_8 = transform.iree.foreach_thread_to_workgroup %func_7
+  %func_9 = transform.iree.map_nested_foreach_thread_to_gpu_threads %func_8
       { workgroup_size = [1024, 1, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing }
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
-  transform.iree.vector.warp_distribute %func_7
+  transform.iree.vector.warp_distribute %func_10
 }
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index e57a3c7..1d849c6 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -75,22 +75,23 @@
 
   // Step 4. Bufferize and drop HAL decriptor from memref ops.
   // =========================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 5. Post-bufferization mapping to blocks and threads.
   // =========================================================
-  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3
   %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
   transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
     { workgroup_size = [32, 4, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
-  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %end_func_2
 }
diff --git a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
index 81a8a87..ef712a2 100644
--- a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
@@ -59,22 +59,23 @@
 
   // Step 4. Bufferize and drop HAL decriptor from memref ops.
   // =========================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 5. Post-bufferization mapping to blocks and threads.
   // =========================================================
-  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3
   %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
   transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
     { workgroup_size = [32, 4, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
-  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %end_func_2
 }
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index 449d49b..52869f9 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -66,22 +66,23 @@
 
   // Step 4. Bufferize and drop HAL decriptor from memref ops.
   // =========================================================
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 5. Post-bufferization mapping to blocks and threads.
   // =========================================================
-  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3
   %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
   transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3
     { workgroup_size = [32, 4, 1] }
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
-  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
-  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %end_func_2
 }
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir
index a71ce12..4562466 100644
--- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir
@@ -10,12 +10,13 @@
   // ===========================================================================
   %func = transform.structured.match ops{["func.func"]} in %variant_op
   transform.iree.apply_patterns %func { rank_reducing }
-  %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
-  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op
+  %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
+  %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func
 
   // Step 3. Map to GPU thread blocks.
   // ===========================================================================
-  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+  %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3
   transform.iree.foreach_thread_to_workgroup %func_2
 }