Support default configs for BATCH_MATMUL_* in MaterializeEncoding pass (#14762)

diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
index 88aea51..3ee0efd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -60,6 +60,7 @@
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
+        "//llvm-external-projects/iree-dialects:IREELinalgExtEncodingUtils",
         "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
         "//llvm-external-projects/iree-dialects:IREELinalgExtUtils",
         "@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
index d1ed3f3..2721fbd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -48,6 +48,7 @@
     ::PassHeaders
     ::PassesIncGen
     IREELinalgExtDialect
+    IREELinalgExtEncodingUtils
     IREELinalgExtTransforms
     IREELinalgExtUtils
     LLVMSupport
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index c9558d5..782fff7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -7,6 +7,7 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
 #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
 #include "iree/compiler/Codegen/Common/CPU/PassDetail.h"
 #include "iree/compiler/Codegen/Common/CPU/Passes.h"
@@ -171,7 +172,8 @@
   auto user = encoding.getUser().getValue();
   auto role = encoding.getRole().getValue();
   MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr);
-  auto encodingInfo = chooseEncodingInfoForMatmul(user, role, tileParams);
+  auto encodingInfo =
+      IREE::LinalgExt::chooseEncodingInfoForMatmul(user, role, tileParams);
   auto originalTypeAttr = encoding.getOriginalType();
   auto originalType = originalTypeAttr
                           ? originalTypeAttr.getValue().cast<RankedTensorType>()
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
index 3115938..3774a95 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
@@ -14,21 +14,10 @@
 namespace mlir {
 namespace iree_compiler {
 
-struct MatmulTileParams {
-  int64_t M = 1;
-  int64_t K = 1;
-  int64_t N = 1;
-};
-
 void adjustTileSizesToNarrowStaticShape(
     IREE::LinalgExt::MaterializeEncodingInfo &encodingInfo,
     ArrayRef<int64_t> shape);
 
-IREE::LinalgExt::MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingUser user,
-                            IREE::LinalgExt::EncodingRole role,
-                            MatmulTileParams tileParams);
-
 IREE::LinalgExt::MaterializeEncodingValueFn
 getMaterializeEncodingValueFn(IREE::HAL::ExecutableTargetAttr targetAttr);
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index b97b4d9..b5d2a1b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -269,52 +269,6 @@
 
 } // namespace
 
-IREE::LinalgExt::MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
-                            MatmulTileParams tileParams) {
-  // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
-  int64_t matmulDimBase = 0;
-  switch (user) {
-  case EncodingUser::BATCH_MATMUL_F32F32F32:
-  case EncodingUser::BATCH_MATMUL_F16F16F32:
-  case EncodingUser::BATCH_MATMUL_F16F16F16:
-  case EncodingUser::BATCH_MATMUL_BF16BF16F32:
-  case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
-  case EncodingUser::BATCH_MATMUL_I8I8I32:
-    matmulDimBase = 1;
-    break;
-  default:
-    break;
-  }
-
-  MaterializeEncodingInfo encodingInfo;
-  encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
-  switch (role) {
-  case (EncodingRole::LHS): {
-    encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
-    break;
-  }
-  case (EncodingRole::RHS): {
-    encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
-    encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase};
-    encodingInfo.outerDimsPerm =
-        llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase));
-    encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1);
-    encodingInfo.outerDimsPerm.push_back(matmulDimBase);
-    break;
-  }
-  case (EncodingRole::RESULT): {
-    encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
-    break;
-  }
-  default: {
-    assert(false);
-    return {};
-  }
-  }
-  return encodingInfo;
-}
-
 void adjustTileSizesToNarrowStaticShape(MaterializeEncodingInfo &encodingInfo,
                                         ArrayRef<int64_t> shape) {
   for (size_t i = 0; i < encodingInfo.innerDimsPos.size(); i++) {
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 1f8a3fd..a2945ed 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -312,12 +312,13 @@
 
 cc_library(
     name = "IREELinalgExtUtils",
-    srcs = glob([
-        "lib/Dialect/LinalgExt/Utils/*.cpp",
-    ]),
-    hdrs = glob([
-        "include/iree-dialects/Dialect/LinalgExt/Utils/*.h",
-    ]),
+    srcs = [
+        "lib/Dialect/LinalgExt/Utils/Utils.cpp",
+    ],
+    hdrs = [
+        "include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h",
+        "include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h",
+    ],
     includes = ["include"],
     deps = [
         "@llvm-project//llvm:Support",
@@ -332,6 +333,21 @@
 )
 
 cc_library(
+    name = "IREELinalgExtEncodingUtils",
+    srcs = [
+        "lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp",
+    ],
+    hdrs = [
+        "include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":IREELinalgExtDialect",
+        ":IREELinalgExtUtils",
+    ],
+)
+
+cc_library(
     name = "IREELinalgExtDialect",
     srcs = glob([
         "lib/Dialect/LinalgExt/IR/*.cpp",
@@ -437,6 +453,7 @@
     deps = [
         ":IREEInputDialect",
         ":IREELinalgExtDialect",
+        ":IREELinalgExtEncodingUtils",
         ":IREELinalgExtPassIncGen",
         ":IREELinalgExtUtils",
         "@llvm-project//llvm:Support",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
new file mode 100644
index 0000000..700dce5
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
@@ -0,0 +1,45 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+/// Inner tile sizes for the M, K and N dimensions of a matmul. Every size
+/// defaults to 1.
+struct MatmulTileParams {
+  int64_t M = 1;
+  int64_t K = 1;
+  int64_t N = 1;
+};
+
+/// Builds the MaterializeEncodingInfo for one operand of a matmul-like op.
+///
+/// `role` selects which pair of sizes from `tileParams` becomes the inner
+/// tile sizes: {M, K} for LHS, {N, K} for RHS and {M, N} for RESULT. For RHS
+/// the two matrix dims are also swapped, both in the inner dims and in the
+/// outer dims permutation. For the BATCH_MATMUL_* users the matrix dims are
+/// dims 1 and 2 rather than 0 and 1.
+///
+/// Any other role trips an assert; if asserts are compiled out, a
+/// `{}`-initialized MaterializeEncodingInfo is returned.
+MaterializeEncodingInfo
+chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
+                            MatmulTileParams tileParams);
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
index 27a2e58..5c921da 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
@@ -16,6 +16,7 @@
   LINK_LIBS PUBLIC
   IREEInputDialect
   IREELinalgExtDialect
+  IREELinalgExtEncodingUtils
   IREELinalgExtUtils
   MLIRAffineDialect
   MLIRIR
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index e510ebf..769541b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -7,6 +7,7 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
 #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
 #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -89,19 +90,29 @@
   auto encoding = getEncodingAttr(tensorType);
   if (!encoding)
     return failure();
+
+  auto user = encoding.getUser().getValue();
   auto role = encoding.getRole().getValue();
-  switch (role) {
-  case EncodingRole::LHS:
-    return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
-    break;
-  case EncodingRole::RHS:
-    return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
-    break;
-  case EncodingRole::RESULT:
-    return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
-    break;
-  default:
-    return failure();
+  // All matmul and batch-matmul users share the default tile sizes M=8, K=4,
+  // N=8; chooseEncodingInfoForMatmul derives the per-role layout and offsets
+  // the tiled dims by one for the batch users.
+  switch (user) {
+  case EncodingUser::MATMUL_F32F32F32:
+  case EncodingUser::MATMUL_F16F16F32:
+  case EncodingUser::MATMUL_F16F16F16:
+  case EncodingUser::MATMUL_BF16BF16F32:
+  case EncodingUser::MATMUL_BF16BF16BF16:
+  case EncodingUser::MATMUL_I8I8I32:
+  case EncodingUser::BATCH_MATMUL_F32F32F32:
+  case EncodingUser::BATCH_MATMUL_F16F16F32:
+  case EncodingUser::BATCH_MATMUL_F16F16F16:
+  case EncodingUser::BATCH_MATMUL_BF16BF16F32:
+  case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
+  case EncodingUser::BATCH_MATMUL_I8I8I32:
+    return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
   }
+  // The switch has no default label, so give every other `user` value an
+  // explicit failure instead of running off the end of the function.
+  return failure();
 }
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt
index 85bdc21..fcff056 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -11,3 +11,11 @@
   MLIRTensorDialect
   MLIRMemRefDialect
 )
+
+add_mlir_library(IREELinalgExtEncodingUtils
+  EncodingUtils.cpp
+
+  LINK_LIBS PUBLIC
+  IREELinalgExtDialect
+  IREELinalgExtUtils
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
new file mode 100644
index 0000000..cae238d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
@@ -0,0 +1,72 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+/// Returns true when `user` is one of the BATCH_MATMUL_* encoding users.
+static bool isBatchMatmulUser(EncodingUser user) {
+  switch (user) {
+  case EncodingUser::BATCH_MATMUL_F32F32F32:
+  case EncodingUser::BATCH_MATMUL_F16F16F32:
+  case EncodingUser::BATCH_MATMUL_F16F16F16:
+  case EncodingUser::BATCH_MATMUL_BF16BF16F32:
+  case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
+  case EncodingUser::BATCH_MATMUL_I8I8I32:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Chooses the pack layout (inner dims, inner tile sizes, outer dims
+/// permutation) for one operand of a matmul or batch matmul.
+///
+/// LHS is tiled as {M, K} and RESULT as {M, N} on the matrix dims in order;
+/// RHS is tiled as {N, K} with the matrix dims swapped. Any other role asserts
+/// and, when asserts are disabled, yields a `{}`-initialized info.
+MaterializeEncodingInfo
+chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
+                            MatmulTileParams tileParams) {
+  // Start dim of the 2D matrix: batch users have one leading batch dim.
+  const int64_t rowDim = isBatchMatmulUser(user) ? 1 : 0;
+  const int64_t colDim = rowDim + 1;
+
+  MaterializeEncodingInfo info;
+  if (role == EncodingRole::LHS) {
+    info.innerDimsPos = {rowDim, colDim};
+    info.innerTileSizes = {tileParams.M, tileParams.K};
+    return info;
+  }
+  if (role == EncodingRole::RESULT) {
+    info.innerDimsPos = {rowDim, colDim};
+    info.innerTileSizes = {tileParams.M, tileParams.N};
+    return info;
+  }
+  if (role == EncodingRole::RHS) {
+    // Swap the two matrix dims, both for the inner tiles and for the outer
+    // dims; any leading batch dim keeps its position.
+    info.innerDimsPos = {colDim, rowDim};
+    info.innerTileSizes = {tileParams.N, tileParams.K};
+    info.outerDimsPerm = llvm::to_vector(llvm::seq<int64_t>(0, rowDim));
+    info.outerDimsPerm.push_back(colDim);
+    info.outerDimsPerm.push_back(rowDim);
+    return info;
+  }
+
+  // No layout is defined for any other role.
+  assert(false);
+  return {};
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 60cc9be..7f757aa 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -163,3 +163,87 @@
 // CHECK-SAME:       outs(%[[FILL]] :
 //      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
 //      CHECK:   return %[[UNPACK]]
+
+// -----
+
+// LHS of a batch matmul: the leading (batch) dim is left alone and dims 1 and
+// 2 are packed with 8x4 inner tiles.
+func.func @pack_unpack_batch_matmul_lhs(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>> -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//      CHECK: func @pack_unpack_batch_matmul_lhs(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+//  CHECK-DAG:   %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+//  CHECK-DAG:   %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D2]]]
+//      CHECK:   %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x4xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[PACK_DEST]]
+//      CHECK:   %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
+//      CHECK:   return %[[UNPACK]]
+
+// -----
+
+// RHS of a batch matmul: the leading (batch) dim is left alone; dims 1 and 2
+// are swapped (outer_dims_perm = [0, 2, 1], inner_dims_pos = [2, 1]) and packed
+// with 8x4 inner tiles.
+func.func @pack_unpack_batch_matmul_rhs(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>> -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//      CHECK: func @pack_unpack_batch_matmul_rhs(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+//  CHECK-DAG:   %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
+//  CHECK-DAG:   %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+//      CHECK:   %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x4xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG0]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[PACK_DEST]]
+//      CHECK:   %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[PACK]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
+//      CHECK:   return %[[UNPACK]]
+
+// -----
+
+// RESULT of a batch matmul: the leading (batch) dim is left alone and dims 1
+// and 2 are packed with 8x8 inner tiles.
+func.func @pack_unpack_batch_matmul_result(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//      CHECK: func @pack_unpack_batch_matmul_result(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+//  CHECK-DAG:   %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+//  CHECK-DAG:   %[[OUTER_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
+//      CHECK:   %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x8xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack
+// CHECK-SAME:     %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[PACK_DEST]]
+//      CHECK:   %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]]
+//      CHECK:   return %[[UNPACK]]