[Codegen] Add control options in pack unpack decomposition (#18469)

Some early patterns in the pack/unpack decomposition avoid generating
reshape ops for packs and unpacks with unit dims. However, the
TileAndFuse pipeline relies on these reshape ops to propagate expanded
shapes to other fusable ops. This PR adds an option to the
DecomposePackUnPackOps pass to create the reshape ops anyway in unit
dim cases.
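
With the new option, pipelines construct the pass as in the
LLVMGPU/Passes.cpp hunk below. A minimal sketch of the call (inside a
pipeline-building function, where `funcPassManager` is the function
pass manager of the surrounding pipeline):

```cpp
// Always lower packs/unpacks to pad/extract_slice + reshape + transpose,
// even when the outer dims are all unit (mirrors the TileAndFuse
// pipeline change in this PR).
funcPassManager.addPass(
    createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false,
                                     /*useOnlyReshapes=*/true,
                                     /*controlFn=*/std::nullopt));
```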

These unit dims currently show up because the iree_linalg_ext.im2col
op of a unit-batched conv has a unit dimension in the batch dim.
Ultimately, it would be good to allow batchless im2col ops, but in
general it is good to support ops that have required unit dimensions.
When prototyping new ops, it is often easiest not to support
rank-reducing cases at first (winograd ops are another example), so
these unit dims may appear again in the future.

This PR also adds an optional control function to the pass options,
which selects which pack and unpack ops get decomposed. The control
function currently requires the `useOnlyReshapes` option to be true,
since some of the upstream patterns do not take a control function
yet; adding one upstream and lifting this restriction is left as a
TODO.
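
For illustration, a hypothetical control function (nothing in this PR
installs one; all call sites pass `std::nullopt`) could restrict the
decomposition to unpack ops like so:

```cpp
// Hypothetical controlFn: decompose only tensor.unpack ops. Returning
// failure() makes the lowering patterns and the tiling walks skip the op.
std::optional<PackUnPackControlFn> controlFn =
    [](Operation *op) -> LogicalResult {
      return success(isa<tensor::UnPackOp>(op));
    };
funcPassManager.addPass(
    createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false,
                                     /*useOnlyReshapes=*/true, controlFn));
```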

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp
index e0c56a0..e8b1837 100644
--- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-codegen-decompose-pack-unpack-ops"
@@ -35,8 +36,15 @@
 struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
 
+  explicit LowerPackPattern(MLIRContext *context,
+                            std::optional<PackUnPackControlFn> controlFn)
+      : OpRewritePattern(context), controlFn(controlFn) {}
+
   LogicalResult matchAndRewrite(tensor::PackOp op,
                                 PatternRewriter &rewriter) const override {
+    if (controlFn && failed(controlFn.value()(op))) {
+      return failure();
+    }
     FailureOr<linalg::LowerPackResult> res = linalg::lowerPack(rewriter, op);
     if (failed(res)) {
       return rewriter.notifyMatchFailure(
@@ -44,6 +52,9 @@
     }
     return success();
   }
+
+private:
+  std::optional<PackUnPackControlFn> controlFn;
 };
 
 /// A wrapper pattern that calls linalg::lowerUnPack on tensor::UnPackOp. It
@@ -52,8 +63,15 @@
 struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
   using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
 
+  explicit LowerUnPackPattern(MLIRContext *context,
+                              std::optional<PackUnPackControlFn> controlFn)
+      : OpRewritePattern(context), controlFn(controlFn) {}
+
   LogicalResult matchAndRewrite(tensor::UnPackOp op,
                                 PatternRewriter &rewriter) const override {
+    if (controlFn && failed(controlFn.value()(op))) {
+      return failure();
+    }
     FailureOr<linalg::LowerUnPackOpResult> res =
         linalg::lowerUnPack(rewriter, op);
     if (failed(res)) {
@@ -62,14 +80,21 @@
     }
     return success();
   }
+
+private:
+  std::optional<PackUnPackControlFn> controlFn;
 };
 
 struct DecomposePackUnPackOpsPass final
     : impl::DecomposePackUnPackOpsPassBase<DecomposePackUnPackOpsPass> {
   using impl::DecomposePackUnPackOpsPassBase<
       DecomposePackUnPackOpsPass>::DecomposePackUnPackOpsPassBase;
-  explicit DecomposePackUnPackOpsPass(bool tileOuterToOne) {
+  explicit DecomposePackUnPackOpsPass(
+      bool tileOuterToOne, bool useOnlyReshapes,
+      std::optional<PackUnPackControlFn> controlFn) {
     this->tileOuterToOne = tileOuterToOne;
+    this->useOnlyReshapes = useOnlyReshapes;
+    this->controlFn = controlFn;
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, arith::ArithDialect, scf::SCFDialect,
@@ -77,6 +102,9 @@
   }
 
   void runOnOperation() override;
+
+private:
+  std::optional<PackUnPackControlFn> controlFn;
 };
 
 } // namespace
@@ -86,7 +114,7 @@
   auto funcOp = getOperation();
   // Generalization patterns for outer unit dims have higher priority because
   // they do not generate reshape ops.
-  {
+  if (!useOnlyReshapes) {
     RewritePatternSet patterns(ctx);
     patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern,
                  linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
@@ -102,7 +130,7 @@
   // tiled to one.
   if (!tileOuterToOne) {
     RewritePatternSet patterns(ctx);
-    patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx);
+    patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx, controlFn);
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
       funcOp.emitError(
           "failed to apply generalization patterns on pack/unpack ops for "
@@ -136,6 +164,9 @@
               return tileSizes;
             }));
     funcOp->walk([&](tensor::PackOp op) {
+      if (controlFn && failed(controlFn.value()(op))) {
+        return;
+      }
       FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
           scf::tileConsumerAndFuseProducersUsingSCF(
               rewriter, cast<TilingInterface>(op.getOperation()), packOptions);
@@ -161,6 +192,9 @@
               return tileSizes;
             });
     funcOp->walk([&](tensor::UnPackOp op) {
+      if (controlFn && failed(controlFn.value()(op))) {
+        return;
+      }
       FailureOr<scf::SCFTilingResult> tilingResult =
           scf::tileUsingSCF(rewriter, cast<TilingInterface>(op.getOperation()),
                             unpackTilingOptions);
@@ -197,8 +231,12 @@
 
   {
     RewritePatternSet patterns(ctx);
-    patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern,
-                 linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
+    if (useOnlyReshapes) {
+      patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx, controlFn);
+    } else {
+      patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern,
+                   linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
+    }
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
       return signalPassFailure();
     }
@@ -206,8 +244,10 @@
 }
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDecomposePackUnPackOpsPass(bool tileOuterToOne) {
-  return std::make_unique<DecomposePackUnPackOpsPass>(tileOuterToOne);
+createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
+                                 std::optional<PackUnPackControlFn> controlFn) {
+  return std::make_unique<DecomposePackUnPackOpsPass>(
+      tileOuterToOne, useOnlyReshapes, controlFn);
 }
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 3227190..502a3cf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -67,8 +67,16 @@
 std::unique_ptr<InterfacePass<FunctionOpInterface>>
 createConvolutionToIGEMMPass(ConfigFn configFn);
 
+using PackUnPackControlFn = std::function<LogicalResult(Operation *)>;
+/// Pass to decompose pack and unpack ops into pad/extract_slice and reshape
+/// ops. If specified, `controlFn` controls which ops get decomposed. The
+/// `controlFn` should be used with `useOnlyReshapes` set to true.
+/// TODO(Max191): Add a controlFn upstream for `GeneralizeOuterUnitDim*`
+/// patterns and remove the need to have `useOnlyReshapes = true` when using
+/// `controlFn`.
 std::unique_ptr<InterfacePass<FunctionOpInterface>>
-createDecomposePackUnPackOpsPass(bool tileOuterToOne);
+createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
+                                 std::optional<PackUnPackControlFn> controlFn);
 
 std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index eb31e94..7d912f8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -153,7 +153,9 @@
   let summary = "Decompose pack/unpack ops into vectorizable ops";
   let options = [
     Option<"tileOuterToOne", "tile-outer-to-one", "bool", "false",
-           "Always apply tiling to make outer dimension be ones">
+           "Always apply tiling to make outer dimension be ones">,
+    Option<"useOnlyReshapes", "use-only-reshapes", "bool", "false",
+           "Use decomposition into reshape ops, even when packing unit dimensions.">
   ];
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir
index e11573a..b64193c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir
@@ -1,17 +1,23 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-decompose-pack-unpack-ops))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-decompose-pack-unpack-ops))" --split-input-file %s | FileCheck %s -check-prefixes=CHECK-ALL,CHECK
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-decompose-pack-unpack-ops{use-only-reshapes=true}))" --split-input-file %s | FileCheck %s -check-prefixes=CHECK-ALL,CHECK-RESHAPE
 
 func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
   %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
   return %0 : tensor<1x1x1x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK:         %[[TRANS:.+]] = linalg.transpose ins(%[[TILE]] : tensor<32x8xf32>) outs(%[[EMPTY]] : tensor<8x32xf32>) permutation = [1, 0]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANS]] into %[[OUT]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK:         return %[[INSERT]]
+// CHECK-ALL-LABEL: func.func @simple_KCRS_to_KCRSsr
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+
+// CHECK-RESHAPE:     %[[EXPANDED:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 1, 32, 1, 8] : tensor<1x1x32x8xf32> into tensor<1x1x1x32x1x8xf32>
+// CHECK-RESHAPE:     %[[RESULT:.+]] = linalg.transpose ins(%[[EXPANDED]] : tensor<1x1x1x32x1x8xf32>) outs(%[[OUT]] : tensor<1x1x1x1x8x32xf32>) permutation = [0, 1, 2, 4, 5, 3]
+
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK:             %[[TRANS:.+]] = linalg.transpose ins(%[[TILE]] : tensor<32x8xf32>) outs(%[[EMPTY]] : tensor<8x32xf32>) permutation = [1, 0]
+// CHECK:             %[[RESULT:.+]] = tensor.insert_slice %[[TRANS]] into %[[OUT]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+
+// CHECK-ALL:         return %[[RESULT]]
 
 // -----
 
@@ -19,14 +25,14 @@
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
-// CHECK-LABEL: func.func @simple_pad_and_pack
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[PAD_VAL:[A-Za-z0-9]+]]:
-// CHECK:         %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
-// CHECK:           tensor.yield %[[PAD_VAL]]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
-// CHECK:         return %[[INSERT]]
+// CHECK-ALL-LABEL: func.func @simple_pad_and_pack
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[PAD_VAL:[A-Za-z0-9]+]]:
+// CHECK-ALL:         %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
+// CHECK-ALL:           tensor.yield %[[PAD_VAL]]
+// CHECK-ALL:         %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
+// CHECK-ALL:         return %[[INSERT]]
 
 // -----
 
@@ -34,11 +40,11 @@
   %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
   return %0 : tensor<1x1x32x8xf32>
 }
-// CHECK-LABEL: func.func @simple_NC_to_CNnc
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         return %[[INSERT]]
+// CHECK-ALL-LABEL: func.func @simple_NC_to_CNnc
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL:         %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK-ALL:         return %[[INSERT]]
 
 // -----
 
@@ -46,15 +52,15 @@
   %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x128x64xf32> -> tensor<1x1x4x8x8x32xf32>
   return %0 : tensor<1x1x4x8x8x32xf32>
 }
-// CHECK:       func.func @KCRS_to_KCRSsr
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 4, 32, 8, 8] : tensor<1x1x128x64xf32> into tensor<1x1x4x32x8x8xf32>
-// CHECK:          %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:       ins(%[[EXPAND]] : tensor<1x1x4x32x8x8xf32>)
-// CHECK-SAME:       outs(%[[OUT]] : tensor<1x1x4x8x8x32xf32>)
-// CHECK-SAME:       permutation = [0, 1, 2, 4, 5, 3]
-// CHECK:         return %[[TRANSP]]
+// CHECK-ALL:       func.func @KCRS_to_KCRSsr
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL:         %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 4, 32, 8, 8] : tensor<1x1x128x64xf32> into tensor<1x1x4x32x8x8xf32>
+// CHECK-ALL:          %[[TRANSP:.+]] = linalg.transpose
+// CHECK-ALL-SAME:       ins(%[[EXPAND]] : tensor<1x1x4x32x8x8xf32>)
+// CHECK-ALL-SAME:       outs(%[[OUT]] : tensor<1x1x4x8x8x32xf32>)
+// CHECK-ALL-SAME:       permutation = [0, 1, 2, 4, 5, 3]
+// CHECK-ALL:         return %[[TRANSP]]
 
 // -----
 
@@ -62,19 +68,19 @@
   %0 = tensor.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
   return %0 : tensor<2x8x8x2xf32>
 }
-// CHECK:       func.func @pad_and_pack
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[PAD_VAL:[A-Za-z0-9]+]]:
-// CHECK:         %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
-// CHECK:           tensor.yield %[[PAD_VAL]]
-// CHECK:         } : tensor<13x15xf32> to tensor<16x16xf32>
-// CHECK:         %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]] output_shape [2, 8, 8, 2] : tensor<16x16xf32> into tensor<2x8x8x2xf32>
-// CHECK:         %[[TRANS:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[EXPAND]] : tensor<2x8x8x2xf32>)
-// CHECK-SAME:      outs(%[[OUT:.+]] : tensor<2x8x8x2xf32>)
-// CHECK-SAME:      permutation = [0, 2, 1, 3]
-// CHECK:         return %[[TRANSP]]
+// CHECK-ALL:       func.func @pad_and_pack
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[PAD_VAL:[A-Za-z0-9]+]]:
+// CHECK-ALL:         %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
+// CHECK-ALL:           tensor.yield %[[PAD_VAL]]
+// CHECK-ALL:         } : tensor<13x15xf32> to tensor<16x16xf32>
+// CHECK-ALL:         %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]] output_shape [2, 8, 8, 2] : tensor<16x16xf32> into tensor<2x8x8x2xf32>
+// CHECK-ALL:         %[[TRANS:.+]] = linalg.transpose
+// CHECK-ALL-SAME:      ins(%[[EXPAND]] : tensor<2x8x8x2xf32>)
+// CHECK-ALL-SAME:      outs(%[[OUT:.+]] : tensor<2x8x8x2xf32>)
+// CHECK-ALL-SAME:      permutation = [0, 2, 1, 3]
+// CHECK-ALL:         return %[[TRANS]]
 
 // -----
 
@@ -82,15 +88,15 @@
   %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
   return %0 : tensor<32x4x32x8xf32>
 }
-// CHECK:       func.func @KC_to_CKck
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0, 1], [2, 3]] output_shape [4, 32, 32, 8] : tensor<128x256xf32> into tensor<4x32x32x8xf32>
-// CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[EXPAND]] : tensor<4x32x32x8xf32>)
-// CHECK-SAME:      outs(%[[OUT]] : tensor<32x4x32x8xf32>)
-// CHECK-SAME:      permutation = [2, 0, 1, 3]
-// CHECK:         return %[[TRANSP]]
+// CHECK-ALL:       func.func @KC_to_CKck
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL:         %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0, 1], [2, 3]] output_shape [4, 32, 32, 8] : tensor<128x256xf32> into tensor<4x32x32x8xf32>
+// CHECK-ALL:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-ALL-SAME:      ins(%[[EXPAND]] : tensor<4x32x32x8xf32>)
+// CHECK-ALL-SAME:      outs(%[[OUT]] : tensor<32x4x32x8xf32>)
+// CHECK-ALL-SAME:      permutation = [2, 0, 1, 3]
+// CHECK-ALL:         return %[[TRANSP]]
 
 // -----
 
@@ -98,15 +104,21 @@
   %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x1x1x8x32xf32> -> tensor<1x1x32x8xf32>
   return %0 : tensor<1x1x32x8xf32>
 }
-// CHECK-LABEL: func.func @simple_KCRSsr_to_KCRS
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[IN]]
-// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
-// CHECK:         %[[TRANS:.+]] = linalg.transpose ins(%[[TILE]] : tensor<8x32xf32>) outs(%[[EMPTY]] : tensor<32x8xf32>) permutation = [1, 0]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANS]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         return %[[INSERT]]
+// CHECK-ALL-LABEL: func.func @simple_KCRSsr_to_KCRS
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+
+// CHECK-RESHAPE:     %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x1x32x1x8xf32>
+// CHECK-RESHAPE:     %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x1x1x8x32xf32>) outs(%[[EMPTY]] : tensor<1x1x1x32x1x8xf32>) permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-RESHAPE:     %[[COLLAPSE:.+]] = tensor.collapse_shape %[[TRANS]] {{\[}}[0], [1], [2, 3], [4, 5]] : tensor<1x1x1x32x1x8xf32> into tensor<1x1x32x8xf32>
+// CHECK-RESHAPE:     %[[RESULT:.+]] = linalg.copy ins(%[[COLLAPSE]] : tensor<1x1x32x8xf32>) outs(%[[OUT]] : tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>
+
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK-SAME:          [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
+// CHECK:             %[[TRANS:.+]] = linalg.transpose ins(%[[TILE]] : tensor<8x32xf32>) outs(%[[EMPTY]] : tensor<32x8xf32>) permutation = [1, 0]
+// CHECK:             %[[RESULT:.+]] = tensor.insert_slice %[[TRANS]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK-ALL:         return %[[RESULT]]
 
 // -----
 
@@ -114,12 +126,16 @@
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
   return %0 : tensor<5x1xf32>
 }
-// CHECK-LABEL: func.func @simple_unpack_and_extract_slice
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
-// CHECK:         %[[RES:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
-// CHECK:         return %[[RES:.+]]
+// CHECK-ALL-LABEL: func.func @simple_unpack_and_extract_slice
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+
+// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
+// CHECK:             %[[RES:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
+
+// CHECK-RESHAPE:     %[[RES:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1]
+
+// CHECK-ALL:         return %[[RES]]
 
 // -----
 
@@ -127,11 +143,11 @@
   %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<1x1x32x8xf32> -> tensor<32x8xf32>
   return %0 : tensor<32x8xf32>
 }
-// CHECK-LABEL: func.func @simple_CNnc_to_NC
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         return %[[TILE]]
+// CHECK-ALL-LABEL: func.func @simple_CNnc_to_NC
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL:         %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK-ALL:         return %[[TILE]]
 
 // -----
 
@@ -139,19 +155,19 @@
   %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<13x12x4x8x8x32xf32> -> tensor<13x12x128x64xf32>
   return %0 : tensor<13x12x128x64xf32>
 }
-// CHECK:       func.func @KCRSsr_to_KCRS
-// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<13x12x4x32x8x8xf32>
-// CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[IN]] : tensor<13x12x4x8x8x32xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<13x12x4x32x8x8xf32>)
-// CHECK-SAME:      permutation = [0, 1, 2, 5, 3, 4]
-// CHECK:         %[[COLLAPSE:.+]] = tensor.collapse_shape %[[TRANSP]]
-// CHECK-SAME:      {{\[}}[0], [1], [2, 3], [4, 5]] : tensor<13x12x4x32x8x8xf32> into tensor<13x12x128x64xf32>
-// CHECK:         %[[COPY:.]] = linalg.copy ins(%[[COLLAPSE]]
-// CHECK-SAME:        outs(%[[OUT]]
-// CHECK:         return %[[COPY]]
+// CHECK-ALL:       func.func @KCRSsr_to_KCRS
+// CHECK-ALL-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL:         %[[EMPTY:.+]] = tensor.empty() : tensor<13x12x4x32x8x8xf32>
+// CHECK-ALL:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-ALL-SAME:      ins(%[[IN]] : tensor<13x12x4x8x8x32xf32>)
+// CHECK-ALL-SAME:      outs(%[[EMPTY]] : tensor<13x12x4x32x8x8xf32>)
+// CHECK-ALL-SAME:      permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-ALL:         %[[COLLAPSE:.+]] = tensor.collapse_shape %[[TRANSP]]
+// CHECK-ALL-SAME:      {{\[}}[0], [1], [2, 3], [4, 5]] : tensor<13x12x4x32x8x8xf32> into tensor<13x12x128x64xf32>
+// CHECK-ALL:         %[[COPY:.+]] = linalg.copy ins(%[[COLLAPSE]]
+// CHECK-ALL-SAME:        outs(%[[OUT]]
+// CHECK-ALL:         return %[[COPY]]
 
 // -----
 
@@ -159,34 +175,34 @@
   %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : tensor<2x8x8x2xf32> -> tensor<13x15xf32>
   return %0 : tensor<13x15xf32>
 }
-// CHECK:      func.func @unpack_and_extract_slice
-// CHECK-SAME:   %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:   %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:          %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK:          %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:       ins(%[[IN]] : tensor<2x8x8x2xf32>)
-// CHECK-SAME:       outs(%[[EMPTY]] : tensor<2x8x8x2xf32>)
-// CHECK-SAME:       permutation = [0, 2, 1, 3]
-// CHECK:          %[[COLLAPSE:.+]] = tensor.collapse_shape %[[TRANSP]]
-// CHECK-SAME:       {{\[}}[0, 1], [2, 3]] : tensor<2x8x8x2xf32> into tensor<16x16xf32>
-// CHECK:          %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]]
-// CHECK-SAME:       [0, 0] [13, 15] [1, 1] : tensor<16x16xf32> to tensor<13x15xf32>
-// CHECK:          %[[COPY:.]] = linalg.copy ins(%[[SLICE]]
-// CHECK-SAME:         outs(%[[OUT]]
-// CHECK:          return %[[COPY]]
+// CHECK-ALL:      func.func @unpack_and_extract_slice
+// CHECK-ALL-SAME:   %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:   %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL:          %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-ALL:          %[[TRANSP:.+]] = linalg.transpose
+// CHECK-ALL-SAME:       ins(%[[IN]] : tensor<2x8x8x2xf32>)
+// CHECK-ALL-SAME:       outs(%[[EMPTY]] : tensor<2x8x8x2xf32>)
+// CHECK-ALL-SAME:       permutation = [0, 2, 1, 3]
+// CHECK-ALL:          %[[COLLAPSE:.+]] = tensor.collapse_shape %[[TRANSP]]
+// CHECK-ALL-SAME:       {{\[}}[0, 1], [2, 3]] : tensor<2x8x8x2xf32> into tensor<16x16xf32>
+// CHECK-ALL:          %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]]
+// CHECK-ALL-SAME:       [0, 0] [13, 15] [1, 1] : tensor<16x16xf32> to tensor<13x15xf32>
+// CHECK-ALL:          %[[COPY:.+]] = linalg.copy ins(%[[SLICE]]
+// CHECK-ALL-SAME:         outs(%[[OUT]]
+// CHECK-ALL:          return %[[COPY]]
 // -----
 
 func.func @CKck_to_KC(%arg0: tensor<32x4x32x8xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
   %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32>
   return %0 : tensor<128x256xf32>
 }
-// CHECK:      func.func @CKck_to_KC
-// CHECK-SAME:   %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:   %[[OUT:[A-Za-z0-9]+]]:
-// CHECK:        %[[TRANSP:.+]] = linalg.transpose ins(%[[IN]]
-// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRANSP]] {{.+}} : tensor<4x32x32x8xf32> into tensor<128x256xf32>
-// CHECK:        %[[RES:.+]] = linalg.copy ins(%[[COLLAPSED]]
-// CHECK:        return %[[RES]]
+// CHECK-ALL:      func.func @CKck_to_KC
+// CHECK-ALL-SAME:   %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:   %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL:        %[[TRANSP:.+]] = linalg.transpose ins(%[[IN]]
+// CHECK-ALL:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRANSP]] {{.+}} : tensor<4x32x32x8xf32> into tensor<128x256xf32>
+// CHECK-ALL:        %[[RES:.+]] = linalg.copy ins(%[[COLLAPSED]]
+// CHECK-ALL:        return %[[RES]]
 
 // -----
 
@@ -194,23 +210,23 @@
   %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %dest : tensor<?x?xf32> -> tensor<?x?x16x1xf32>
   return %pack : tensor<?x?x16x1xf32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 16 - s1)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
-// CHECK:      func.func @pack_matmul_DYN_LHS
-// CHECK-SAME:   %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:   %[[OUT:[A-Za-z0-9]+]]:
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[IN]], %c0 : tensor<?x?xf32>
-// CHECK-DAG:    %[[H0:.+]] = affine.apply #[[MAP0]]
-// CHECK-DAG:    %[[H1:.+]] = affine.apply #[[MAP1]]
-// CHECK:        %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[%[[H0]], %[[H1]]]
-// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[PAD]]
-// CHECK-SAME:     {{\[}}[0, 1], [2, 3]]
-// CHECK:        %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:     ins(%[[EXPANDED]] : tensor<?x16x?x1xf32>)
-// CHECK-SAME:     outs(%[[OUT]] : tensor<?x?x16x1xf32>)
-// CHECK-SAME:   permutation = [0, 2, 1, 3]
-// CHECK:        return %[[TRANSP]]
+// CHECK-ALL-DAG:  #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 16 - s1)>
+// CHECK-ALL-DAG:  #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-ALL:      func.func @pack_matmul_DYN_LHS
+// CHECK-ALL-SAME:   %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:   %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-ALL-DAG:    %[[D0:.+]] = tensor.dim %[[IN]], %c0 : tensor<?x?xf32>
+// CHECK-ALL-DAG:    %[[H0:.+]] = affine.apply #[[MAP0]]
+// CHECK-ALL-DAG:    %[[H1:.+]] = affine.apply #[[MAP1]]
+// CHECK-ALL:        %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[%[[H0]], %[[H1]]]
+// CHECK-ALL:        %[[EXPANDED:.+]] = tensor.expand_shape %[[PAD]]
+// CHECK-ALL-SAME:     {{\[}}[0, 1], [2, 3]]
+// CHECK-ALL:        %[[TRANSP:.+]] = linalg.transpose
+// CHECK-ALL-SAME:     ins(%[[EXPANDED]] : tensor<?x16x?x1xf32>)
+// CHECK-ALL-SAME:     outs(%[[OUT]] : tensor<?x?x16x1xf32>)
+// CHECK-ALL-SAME:   permutation = [0, 2, 1, 3]
+// CHECK-ALL:        return %[[TRANSP]]
 
 // -----
 
@@ -218,18 +234,18 @@
   %pack = tensor.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %dest : tensor<?x?xf32> -> tensor<?x?x16x1xf32>
   return %pack : tensor<?x?x16x1xf32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 16 - s1)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
-// CHECK:      func.func @pack_matmul_DYN_RHS
-// CHECK-SAME:   %[[IN:[A-Za-z0-9]+]]:
-// CHECK-SAME:   %[[OUT:[A-Za-z0-9]+]]:
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:    %[[H0:.+]] = affine.apply #[[MAP1]]
-// CHECK-DAG:    %[[H1:.+]] = affine.apply #[[MAP0]]
-// CHECK:        %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[%[[H0]], %[[H1]]]
-// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[PAD]]
-// CHECK-SAME:     {{\[}}[0, 1], [2, 3]]
-// CHECK:        %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:     ins(%[[EXPANDED]] : tensor<?x1x?x16xf32>)
-// CHECK-SAME:     outs(%[[OUT]] : tensor<?x?x16x1xf32>)
-// CHECK-SAME:     permutation = [2, 0, 3, 1]
+// CHECK-ALL-DAG:  #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 16 - s1)>
+// CHECK-ALL-DAG:  #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-ALL:      func.func @pack_matmul_DYN_RHS
+// CHECK-ALL-SAME:   %[[IN:[A-Za-z0-9]+]]:
+// CHECK-ALL-SAME:   %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-ALL-DAG:    %[[C1:.+]] = arith.constant 1 : index
+// CHECK-ALL-DAG:    %[[H0:.+]] = affine.apply #[[MAP1]]
+// CHECK-ALL-DAG:    %[[H1:.+]] = affine.apply #[[MAP0]]
+// CHECK-ALL:        %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[%[[H0]], %[[H1]]]
+// CHECK-ALL:        %[[EXPANDED:.+]] = tensor.expand_shape %[[PAD]]
+// CHECK-ALL-SAME:     {{\[}}[0, 1], [2, 3]]
+// CHECK-ALL:        %[[TRANSP:.+]] = linalg.transpose
+// CHECK-ALL-SAME:     ins(%[[EXPANDED]] : tensor<?x1x?x16xf32>)
+// CHECK-ALL-SAME:     outs(%[[OUT]] : tensor<?x?x16x1xf32>)
+// CHECK-ALL-SAME:     permutation = [2, 0, 3, 1]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index cbe19e5..b63dfa7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -348,7 +348,9 @@
 
   // Step 3. Decompose pack and unpack ops and propagate the resulting reshapes.
   funcPassManager.addPass(
-      createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false));
+      createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false,
+                                       /*useOnlyReshapes=*/true,
+                                       /*controlFn=*/std::nullopt));
 
   // Step 3.5. Expand the inner dimensions of MultiMma ops in preparation for
   // distribution to lanes.
@@ -944,7 +946,9 @@
   funcPassManager.addPass(createCSEPass());
 
   funcPassManager.addPass(
-      createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/true));
+      createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/true,
+                                       /*useOnlyReshapes=*/false,
+                                       /*controlFn=*/std::nullopt));
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
   addGPUVectorizationPasses(funcPassManager);