[VMVX] Enable ukernels for batch_matmul using static shape path. (#15733)

Different from matmul, it does not use dynamic shapes. The reason is
that it will need codegen.query_tile_sizes ukernel to support 3d cases.
The revision enables data-tiling batch_matmul on vmvx path using static
shapes.

Fixes https://github.com/openxla/iree/issues/15314
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index d262217..445a54e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -36,12 +36,10 @@
 // narrow-N cases are handled by transposition in chooseMatmulTile.
 static SmallVector<TileMxNxK>
 enumerateMatmulTilesVMVX(EncodingUser user, ExecutableTargetAttr target) {
-  if (hasUkernel(target)) {
-    // TODO(#15314): Remove the check once it is supported. vmvx + ukernel
-    // does not support batch_matmul atm.
-    if (user == EncodingUser::BATCH_MATMUL) {
-      return {};
-    }
+  // TODO(hanchung): The ukernel path does not support 3d
+  // codegen.query_tile_sizes op, so we disable dynamic tile shapes for
+  // batch_matmul.
+  if (hasUkernel(target) && user != EncodingUser::BATCH_MATMUL) {
     // VMVX+ukernel uses dynamic tile shapes.
     return {TileMxNxK{ShapedType::kDynamic, ShapedType::kDynamic,
                       ShapedType::kDynamic}};
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
index 8488c48..221dcea 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
@@ -58,8 +58,11 @@
                                 bool enableUKernels) {
   addTileAndDistributePasses(passManager);
 
+  OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
   if (enableUKernels) {
-    passManager.nest<ModuleOp>().addPass(
+    nestedModulePM.addNestedPass<func::FuncOp>(
+        createDecomposeBatchMmt4DOpsPass());
+    nestedModulePM.addPass(
         createCPULowerToUKernelsPass(clSkipIntermediateRoundings));
   }
 
@@ -72,7 +75,6 @@
   }
 
   // Lower to buffers.
-  OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
   addCPUBufferizePasses(nestedModulePM);
 
   // Cleanup the IR that may now have unused loops.
diff --git a/tests/e2e/tosa_ops/BUILD.bazel b/tests/e2e/tosa_ops/BUILD.bazel
index d3dcf3f..a554c08 100644
--- a/tests/e2e/tosa_ops/BUILD.bazel
+++ b/tests/e2e/tosa_ops/BUILD.bazel
@@ -181,9 +181,6 @@
     name = "check_vmvx_local-sync_microkernels",
     srcs = VMVX_MICROKERNELS_SRCS,
     compiler_flags = [
-        # TODO(15314): Remove the flag once vmvx supports batch_matmul on
-        # ukernel path.
-        "--iree-opt-data-tiling=false",
         "--iree-vmvx-enable-microkernels",
     ],
     # Sync has more strict runtime error checking for mis-compiled programs.
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 2999061..f3f86ff 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -165,7 +165,6 @@
   DRIVER
     "local-sync"
   COMPILER_FLAGS
-    "--iree-opt-data-tiling=false"
     "--iree-vmvx-enable-microkernels"
   INPUT_TYPE
     "tosa"