Support default configs for BATCH_MATMUL_* in MaterializeEncoding pass (#14762)
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel index 88aea51..3ee0efd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -60,6 +60,7 @@ "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//llvm-external-projects/iree-dialects:IREELinalgExtDialect", + "//llvm-external-projects/iree-dialects:IREELinalgExtEncodingUtils", "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms", "//llvm-external-projects/iree-dialects:IREELinalgExtUtils", "@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt index d1ed3f3..2721fbd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -48,6 +48,7 @@ ::PassHeaders ::PassesIncGen IREELinalgExtDialect + IREELinalgExtEncodingUtils IREELinalgExtTransforms IREELinalgExtUtils LLVMSupport
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp index c9558d5..782fff7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -7,6 +7,7 @@ #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" +#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h" #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/Codegen/Common/CPU/PassDetail.h" #include "iree/compiler/Codegen/Common/CPU/Passes.h" @@ -171,7 +172,8 @@ auto user = encoding.getUser().getValue(); auto role = encoding.getRole().getValue(); MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr); - auto encodingInfo = chooseEncodingInfoForMatmul(user, role, tileParams); + auto encodingInfo = + IREE::LinalgExt::chooseEncodingInfoForMatmul(user, role, tileParams); auto originalTypeAttr = encoding.getOriginalType(); auto originalType = originalTypeAttr ? originalTypeAttr.getValue().cast<RankedTensorType>()
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h index 3115938..3774a95 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
@@ -14,21 +14,10 @@ namespace mlir { namespace iree_compiler { -struct MatmulTileParams { - int64_t M = 1; - int64_t K = 1; - int64_t N = 1; -}; - void adjustTileSizesToNarrowStaticShape( IREE::LinalgExt::MaterializeEncodingInfo &encodingInfo, ArrayRef<int64_t> shape); -IREE::LinalgExt::MaterializeEncodingInfo -chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingUser user, - IREE::LinalgExt::EncodingRole role, - MatmulTileParams tileParams); - IREE::LinalgExt::MaterializeEncodingValueFn getMaterializeEncodingValueFn(IREE::HAL::ExecutableTargetAttr targetAttr);
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp index b97b4d9..b5d2a1b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -269,52 +269,6 @@ } // namespace -IREE::LinalgExt::MaterializeEncodingInfo -chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, - MatmulTileParams tileParams) { - // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. - int64_t matmulDimBase = 0; - switch (user) { - case EncodingUser::BATCH_MATMUL_F32F32F32: - case EncodingUser::BATCH_MATMUL_F16F16F32: - case EncodingUser::BATCH_MATMUL_F16F16F16: - case EncodingUser::BATCH_MATMUL_BF16BF16F32: - case EncodingUser::BATCH_MATMUL_BF16BF16BF16: - case EncodingUser::BATCH_MATMUL_I8I8I32: - matmulDimBase = 1; - break; - default: - break; - } - - MaterializeEncodingInfo encodingInfo; - encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1}; - switch (role) { - case (EncodingRole::LHS): { - encodingInfo.innerTileSizes = {tileParams.M, tileParams.K}; - break; - } - case (EncodingRole::RHS): { - encodingInfo.innerTileSizes = {tileParams.N, tileParams.K}; - encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase}; - encodingInfo.outerDimsPerm = - llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase)); - encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1); - encodingInfo.outerDimsPerm.push_back(matmulDimBase); - break; - } - case (EncodingRole::RESULT): { - encodingInfo.innerTileSizes = {tileParams.M, tileParams.N}; - break; - } - default: { - assert(false); - return {}; - } - } - return encodingInfo; -} - void adjustTileSizesToNarrowStaticShape(MaterializeEncodingInfo &encodingInfo, ArrayRef<int64_t> shape) { for (size_t i = 0; i < encodingInfo.innerDimsPos.size(); i++) {
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel index 1f8a3fd..a2945ed 100644 --- a/llvm-external-projects/iree-dialects/BUILD.bazel +++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -312,12 +312,13 @@ cc_library( name = "IREELinalgExtUtils", - srcs = glob([ - "lib/Dialect/LinalgExt/Utils/*.cpp", - ]), - hdrs = glob([ - "include/iree-dialects/Dialect/LinalgExt/Utils/*.h", - ]), + srcs = [ + "lib/Dialect/LinalgExt/Utils/Utils.cpp", + ], + hdrs = [ + "include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h", + "include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h", + ], includes = ["include"], deps = [ "@llvm-project//llvm:Support", @@ -332,6 +333,21 @@ ) cc_library( + name = "IREELinalgExtEncodingUtils", + srcs = [ + "lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp", + ], + hdrs = [ + "include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h", + ], + includes = ["include"], + deps = [ + ":IREELinalgExtDialect", + ":IREELinalgExtUtils", + ], +) + +cc_library( name = "IREELinalgExtDialect", srcs = glob([ "lib/Dialect/LinalgExt/IR/*.cpp", @@ -437,6 +453,7 @@ deps = [ ":IREEInputDialect", ":IREELinalgExtDialect", + ":IREELinalgExtEncodingUtils", ":IREELinalgExtPassIncGen", ":IREELinalgExtUtils", "@llvm-project//llvm:Support",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h new file mode 100644 index 0000000..700dce5 --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
@@ -0,0 +1,33 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_ +#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_ + +#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace LinalgExt { + +struct MatmulTileParams { + int64_t M = 1; + int64_t K = 1; + int64_t N = 1; +}; + +MaterializeEncodingInfo +chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, + MatmulTileParams tileParams); + +} // namespace LinalgExt +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt index 27a2e58..5c921da 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
@@ -16,6 +16,7 @@ LINK_LIBS PUBLIC IREEInputDialect IREELinalgExtDialect + IREELinalgExtEncodingUtils IREELinalgExtUtils MLIRAffineDialect MLIRIR
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp index e510ebf..769541b 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -7,6 +7,7 @@ #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" +#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h" #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -89,19 +90,23 @@ auto encoding = getEncodingAttr(tensorType); if (!encoding) return failure(); + + auto user = encoding.getUser().getValue(); auto role = encoding.getRole().getValue(); - switch (role) { - case EncodingRole::LHS: - return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}}; - break; - case EncodingRole::RHS: - return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}}; - break; - case EncodingRole::RESULT: - return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}}; - break; - default: - return failure(); + switch (user) { + case EncodingUser::MATMUL_F32F32F32: + case EncodingUser::MATMUL_F16F16F32: + case EncodingUser::MATMUL_F16F16F16: + case EncodingUser::MATMUL_BF16BF16F32: + case EncodingUser::MATMUL_BF16BF16BF16: + case EncodingUser::MATMUL_I8I8I32: + case EncodingUser::BATCH_MATMUL_F32F32F32: + case EncodingUser::BATCH_MATMUL_F16F16F32: + case EncodingUser::BATCH_MATMUL_F16F16F16: + case EncodingUser::BATCH_MATMUL_BF16BF16F32: + case EncodingUser::BATCH_MATMUL_BF16BF16BF16: + case EncodingUser::BATCH_MATMUL_I8I8I32: + return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8}); } }
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt index 85bdc21..fcff056 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -11,3 +11,11 @@ MLIRTensorDialect MLIRMemRefDialect ) + +add_mlir_library(IREELinalgExtEncodingUtils + EncodingUtils.cpp + + LINK_LIBS PUBLIC + IREELinalgExtDialect + IREELinalgExtUtils +)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp new file mode 100644 index 0000000..cae238d --- /dev/null +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
@@ -0,0 +1,63 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace LinalgExt { + +MaterializeEncodingInfo +chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, + MatmulTileParams tileParams) { + // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. + int64_t matmulDimBase = 0; + switch (user) { + case EncodingUser::BATCH_MATMUL_F32F32F32: + case EncodingUser::BATCH_MATMUL_F16F16F32: + case EncodingUser::BATCH_MATMUL_F16F16F16: + case EncodingUser::BATCH_MATMUL_BF16BF16F32: + case EncodingUser::BATCH_MATMUL_BF16BF16BF16: + case EncodingUser::BATCH_MATMUL_I8I8I32: + matmulDimBase = 1; + break; + default: + break; + } + + MaterializeEncodingInfo encodingInfo; + encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1}; + switch (role) { + case (EncodingRole::LHS): { + encodingInfo.innerTileSizes = {tileParams.M, tileParams.K}; + break; + } + case (EncodingRole::RHS): { + encodingInfo.innerTileSizes = {tileParams.N, tileParams.K}; + encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase}; + encodingInfo.outerDimsPerm = + llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase)); + encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1); + encodingInfo.outerDimsPerm.push_back(matmulDimBase); + break; + } + case (EncodingRole::RESULT): { + encodingInfo.innerTileSizes = {tileParams.M, tileParams.N}; + break; + } + default: { + assert(false); + return {}; + } + } + return encodingInfo; +} + +} // namespace LinalgExt +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir index 60cc9be..7f757aa 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -163,3 +163,80 @@ // CHECK-SAME: outs(%[[FILL]] : // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]] // CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_unpack_batch_matmul_lhs(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>> + %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>> -> tensor<?x?x?xf32> + return %1 : tensor<?x?x?xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK: func @pack_unpack_batch_matmul_lhs( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] +// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D2]]] +// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x4xf32> +// CHECK: %[[PACK:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[PACK_DEST]] +// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32> +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[UNPACK_DEST]] +// CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_unpack_batch_matmul_rhs(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>> + %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>> -> tensor<?x?x?xf32> + return %1 : tensor<?x?x?xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK: func @pack_unpack_batch_matmul_rhs( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D2]]] +// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D1]]] +// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x4xf32> +// CHECK: %[[PACK:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[PACK_DEST]] +// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32> +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]] +// CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_unpack_batch_matmul_result(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> + %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32> + return %1 : tensor<?x?x?xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK: func @pack_unpack_batch_matmul_result( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] +// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]] +// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x8xf32> +// CHECK: %[[PACK:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[PACK_DEST]] +// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32> +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]] +// CHECK: return %[[UNPACK]]