[VMVX] Enable ukernels for batch_matmul using static shape path. (#15733)
Unlike matmul, batch_matmul does not use dynamic shapes on this path,
because that would require the codegen.query_tile_sizes ukernel to support
3-D cases. This revision instead enables data-tiling for batch_matmul on the
VMVX path using static shapes.
Fixes https://github.com/openxla/iree/issues/15314
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index d262217..445a54e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -36,12 +36,10 @@
// narrow-N cases are handled by transposition in chooseMatmulTile.
static SmallVector<TileMxNxK>
enumerateMatmulTilesVMVX(EncodingUser user, ExecutableTargetAttr target) {
- if (hasUkernel(target)) {
- // TODO(#15314): Remove the check once it is supported. vmvx + ukernel
- // does not support batch_matmul atm.
- if (user == EncodingUser::BATCH_MATMUL) {
- return {};
- }
+ // TODO(hanchung): The ukernel path does not yet support 3-D cases in the
+ // codegen.query_tile_sizes op, so dynamic tile shapes are disabled for
+ // batch_matmul.
+ if (hasUkernel(target) && user != EncodingUser::BATCH_MATMUL) {
// VMVX+ukernel uses dynamic tile shapes.
return {TileMxNxK{ShapedType::kDynamic, ShapedType::kDynamic,
ShapedType::kDynamic}};
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
index 8488c48..221dcea 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
@@ -58,8 +58,11 @@
bool enableUKernels) {
addTileAndDistributePasses(passManager);
+ OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
if (enableUKernels) {
- passManager.nest<ModuleOp>().addPass(
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ createDecomposeBatchMmt4DOpsPass());
+ nestedModulePM.addPass(
createCPULowerToUKernelsPass(clSkipIntermediateRoundings));
}
@@ -72,7 +75,6 @@
}
// Lower to buffers.
- OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
addCPUBufferizePasses(nestedModulePM);
// Cleanup the IR that may now have unused loops.
diff --git a/tests/e2e/tosa_ops/BUILD.bazel b/tests/e2e/tosa_ops/BUILD.bazel
index d3dcf3f..a554c08 100644
--- a/tests/e2e/tosa_ops/BUILD.bazel
+++ b/tests/e2e/tosa_ops/BUILD.bazel
@@ -181,9 +181,6 @@
name = "check_vmvx_local-sync_microkernels",
srcs = VMVX_MICROKERNELS_SRCS,
compiler_flags = [
- # TODO(15314): Remove the flag once vmvx supports batch_matmul on
- # ukernel path.
- "--iree-opt-data-tiling=false",
"--iree-vmvx-enable-microkernels",
],
# Sync has more strict runtime error checking for mis-compiled programs.
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 2999061..f3f86ff 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -165,7 +165,6 @@
DRIVER
"local-sync"
COMPILER_FLAGS
- "--iree-opt-data-tiling=false"
"--iree-vmvx-enable-microkernels"
INPUT_TYPE
"tosa"