Support default configs for BATCH_MATMUL_* in MaterializeEncoding pass (#14762)
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
index 88aea51..3ee0efd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -60,6 +60,7 @@
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtEncodingUtils",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgExtUtils",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
index d1ed3f3..2721fbd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -48,6 +48,7 @@
::PassHeaders
::PassesIncGen
IREELinalgExtDialect
+ IREELinalgExtEncodingUtils
IREELinalgExtTransforms
IREELinalgExtUtils
LLVMSupport
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index c9558d5..782fff7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Codegen/Common/CPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
@@ -171,7 +172,8 @@
auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr);
- auto encodingInfo = chooseEncodingInfoForMatmul(user, role, tileParams);
+ auto encodingInfo =
+ IREE::LinalgExt::chooseEncodingInfoForMatmul(user, role, tileParams);
auto originalTypeAttr = encoding.getOriginalType();
auto originalType = originalTypeAttr
? originalTypeAttr.getValue().cast<RankedTensorType>()
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
index 3115938..3774a95 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
@@ -14,21 +14,10 @@
namespace mlir {
namespace iree_compiler {
-struct MatmulTileParams {
- int64_t M = 1;
- int64_t K = 1;
- int64_t N = 1;
-};
-
void adjustTileSizesToNarrowStaticShape(
IREE::LinalgExt::MaterializeEncodingInfo &encodingInfo,
ArrayRef<int64_t> shape);
-IREE::LinalgExt::MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingUser user,
- IREE::LinalgExt::EncodingRole role,
- MatmulTileParams tileParams);
-
IREE::LinalgExt::MaterializeEncodingValueFn
getMaterializeEncodingValueFn(IREE::HAL::ExecutableTargetAttr targetAttr);
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index b97b4d9..b5d2a1b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -269,52 +269,6 @@
} // namespace
-IREE::LinalgExt::MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
- MatmulTileParams tileParams) {
- // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
- int64_t matmulDimBase = 0;
- switch (user) {
- case EncodingUser::BATCH_MATMUL_F32F32F32:
- case EncodingUser::BATCH_MATMUL_F16F16F32:
- case EncodingUser::BATCH_MATMUL_F16F16F16:
- case EncodingUser::BATCH_MATMUL_BF16BF16F32:
- case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
- case EncodingUser::BATCH_MATMUL_I8I8I32:
- matmulDimBase = 1;
- break;
- default:
- break;
- }
-
- MaterializeEncodingInfo encodingInfo;
- encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
- switch (role) {
- case (EncodingRole::LHS): {
- encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
- break;
- }
- case (EncodingRole::RHS): {
- encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
- encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase};
- encodingInfo.outerDimsPerm =
- llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase));
- encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1);
- encodingInfo.outerDimsPerm.push_back(matmulDimBase);
- break;
- }
- case (EncodingRole::RESULT): {
- encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
- break;
- }
- default: {
- assert(false);
- return {};
- }
- }
- return encodingInfo;
-}
-
void adjustTileSizesToNarrowStaticShape(MaterializeEncodingInfo &encodingInfo,
ArrayRef<int64_t> shape) {
for (size_t i = 0; i < encodingInfo.innerDimsPos.size(); i++) {
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 1f8a3fd..a2945ed 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -312,12 +312,13 @@
cc_library(
name = "IREELinalgExtUtils",
- srcs = glob([
- "lib/Dialect/LinalgExt/Utils/*.cpp",
- ]),
- hdrs = glob([
- "include/iree-dialects/Dialect/LinalgExt/Utils/*.h",
- ]),
+ srcs = [
+ "lib/Dialect/LinalgExt/Utils/Utils.cpp",
+ ],
+ hdrs = [
+ "include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h",
+ "include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h",
+ ],
includes = ["include"],
deps = [
"@llvm-project//llvm:Support",
@@ -332,6 +333,21 @@
)
cc_library(
+ name = "IREELinalgExtEncodingUtils",
+ srcs = [
+ "lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp",
+ ],
+ hdrs = [
+ "include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h",
+ ],
+ includes = ["include"],
+ deps = [
+ ":IREELinalgExtDialect",
+ ":IREELinalgExtUtils",
+ ],
+)
+
+cc_library(
name = "IREELinalgExtDialect",
srcs = glob([
"lib/Dialect/LinalgExt/IR/*.cpp",
@@ -437,6 +453,7 @@
deps = [
":IREEInputDialect",
":IREELinalgExtDialect",
+ ":IREELinalgExtEncodingUtils",
":IREELinalgExtPassIncGen",
":IREELinalgExtUtils",
"@llvm-project//llvm:Support",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
new file mode 100644
index 0000000..700dce5
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
@@ -0,0 +1,33 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+struct MatmulTileParams {
+ int64_t M = 1;
+ int64_t K = 1;
+ int64_t N = 1;
+};
+
+MaterializeEncodingInfo
+chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
+ MatmulTileParams tileParams);
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
index 27a2e58..5c921da 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
@@ -16,6 +16,7 @@
LINK_LIBS PUBLIC
IREEInputDialect
IREELinalgExtDialect
+ IREELinalgExtEncodingUtils
IREELinalgExtUtils
MLIRAffineDialect
MLIRIR
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index e510ebf..769541b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -89,19 +90,23 @@
auto encoding = getEncodingAttr(tensorType);
if (!encoding)
return failure();
+
+ auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
- switch (role) {
- case EncodingRole::LHS:
- return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}};
- break;
- case EncodingRole::RHS:
- return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}};
- break;
- case EncodingRole::RESULT:
- return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}};
- break;
- default:
- return failure();
+ switch (user) {
+ case EncodingUser::MATMUL_F32F32F32:
+ case EncodingUser::MATMUL_F16F16F32:
+ case EncodingUser::MATMUL_F16F16F16:
+ case EncodingUser::MATMUL_BF16BF16F32:
+ case EncodingUser::MATMUL_BF16BF16BF16:
+ case EncodingUser::MATMUL_I8I8I32:
+ case EncodingUser::BATCH_MATMUL_F32F32F32:
+ case EncodingUser::BATCH_MATMUL_F16F16F32:
+ case EncodingUser::BATCH_MATMUL_F16F16F16:
+ case EncodingUser::BATCH_MATMUL_BF16BF16F32:
+ case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
+ case EncodingUser::BATCH_MATMUL_I8I8I32:
+ return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
}
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt
index 85bdc21..fcff056 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -11,3 +11,11 @@
MLIRTensorDialect
MLIRMemRefDialect
)
+
+add_mlir_library(IREELinalgExtEncodingUtils
+ EncodingUtils.cpp
+
+ LINK_LIBS PUBLIC
+ IREELinalgExtDialect
+ IREELinalgExtUtils
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
new file mode 100644
index 0000000..cae238d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
@@ -0,0 +1,63 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+MaterializeEncodingInfo
+chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
+ MatmulTileParams tileParams) {
+ // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
+ int64_t matmulDimBase = 0;
+ switch (user) {
+ case EncodingUser::BATCH_MATMUL_F32F32F32:
+ case EncodingUser::BATCH_MATMUL_F16F16F32:
+ case EncodingUser::BATCH_MATMUL_F16F16F16:
+ case EncodingUser::BATCH_MATMUL_BF16BF16F32:
+ case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
+ case EncodingUser::BATCH_MATMUL_I8I8I32:
+ matmulDimBase = 1;
+ break;
+ default:
+ break;
+ }
+
+ MaterializeEncodingInfo encodingInfo;
+ encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
+ switch (role) {
+ case (EncodingRole::LHS): {
+ encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
+ break;
+ }
+ case (EncodingRole::RHS): {
+ encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
+ encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase};
+ encodingInfo.outerDimsPerm =
+ llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase));
+ encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1);
+ encodingInfo.outerDimsPerm.push_back(matmulDimBase);
+ break;
+ }
+ case (EncodingRole::RESULT): {
+ encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
+ break;
+ }
+ default: {
+ assert(false);
+ return {};
+ }
+ }
+ return encodingInfo;
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 60cc9be..7f757aa 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -163,3 +163,80 @@
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_unpack_batch_matmul_lhs(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>> -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_unpack_batch_matmul_lhs(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D2]]]
+// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[PACK_DEST]]
+// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_unpack_batch_matmul_rhs(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>> -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_unpack_batch_matmul_rhs(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
+// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[PACK_DEST]]
+// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_unpack_batch_matmul_result(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
+ %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK: func @pack_unpack_batch_matmul_result(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
+// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x8xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[PACK_DEST]]
+// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]]
+// CHECK: return %[[UNPACK]]