[spirv] Fix matmul vectorization corner cases (#7137)

* We don't support element types other than 16/32 bit yet.
* Don't vectorize for odd K sizes; we cannot perform vector loads there.
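
Both guards are sketched below. This is a minimal standalone illustration,
not IREE code: powerOf2Floor mirrors llvm::PowerOf2Floor, and
hasVectorizableElementBits / pickKTileSize are hypothetical helpers that
isolate the logic from the diff that follows.

  // Minimal standalone sketch of the two bail-out conditions (illustration
  // only; the real logic lives in setMatmulOpConfig in the diff below).
  #include <cassert>
  #include <cstdint>

  // Mirrors llvm::PowerOf2Floor: largest power of two <= value (value >= 1).
  static int64_t powerOf2Floor(int64_t value) {
    int64_t p = 1;
    while (p * 2 <= value) p *= 2;
    return p;
  }

  // Guard 1: only 16- and 32-bit element types are vectorizable for now.
  static bool hasVectorizableElementBits(int64_t elementBits) {
    return elementBits == 16 || elementBits == 32;
  }

  // Guard 2: the K tile size must be a power of two >= 2 dividing dimK so
  // that vector loads line up; returns 0 when no such factor exists (e.g.
  // odd K), in which case the caller falls back to scalar code.
  static int64_t pickKTileSize(int64_t dimK, int64_t residualTilingFactor) {
    for (int64_t t = powerOf2Floor(residualTilingFactor); t >= 2; t >>= 1)
      if (dimK % t == 0) return t;
    return 0;
  }

  int main() {
    assert(!hasVectorizableElementBits(8));  // i8 matmul: bail out
    assert(hasVectorizableElementBits(32));  // f32 matmul: OK
    assert(pickKTileSize(32, 8) == 8);       // even K: vector load with t=8
    assert(pickKTileSize(3, 8) == 0);        // odd K = 3: no valid factor
    return 0;
  }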
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index fa57109..cf8c39d 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -208,6 +208,10 @@
 LogicalResult setMatmulOpConfig(linalg::LinalgOp op,
                                 std::array<int64_t, 2> bestWorkgroupSizeXY,
                                 std::array<int64_t, 3> bestThreadTileSizeMNK) {
+  auto lhsType = op.inputs()[0].getType().cast<ShapedType>();
+  auto elementBits = lhsType.getElementType().getIntOrFloatBitWidth();
+  if (elementBits != 16 && elementBits != 32) return success();
+
   ArrayRef<int64_t> lhsShape = getUntiledShape(op.inputs()[0]);
   ArrayRef<int64_t> rhsShape = getUntiledShape(op.inputs()[1]);
   if (llvm::any_of(lhsShape, ShapedType::isDynamic)) return success();
@@ -282,12 +286,13 @@
 
   // Deduce the configuration for the K dimension. We need some power of two
   // here so that we can do vector load.
-  for (int64_t t = llvm::PowerOf2Floor(residualTilingFactor); t >= 1; t >>= 1) {
+  for (int64_t t = llvm::PowerOf2Floor(residualTilingFactor); t >= 2; t >>= 1) {
     if (dimK % t == 0) {
       workgroupTileSizes[2 + isBM] = invocationTileSizes[2 + isBM] = t;
       break;
     }
   }
+  if (workgroupTileSizes[2 + isBM] == 0) return success();
 
   auto pipeline = IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize;
   TileSizesListType tileSizes;
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD
index 5f847d4..75967bc 100644
--- a/iree/compiler/Codegen/SPIRV/test/BUILD
+++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -21,6 +21,7 @@
         [
             "config_adreno_conv.mlir",
             "config_adreno_matmul.mlir",
+            "config_default_matmul.mlir",
             "config_linalg_ext_ops.mlir",
             "config_linalg_ops.mlir",
             "config_mali_conv.mlir",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 0d805b1..b42fcef 100644
--- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -16,6 +16,7 @@
   SRCS
     "config_adreno_conv.mlir"
     "config_adreno_matmul.mlir"
+    "config_default_matmul.mlir"
     "config_linalg_ext_ops.mlir"
     "config_linalg_ops.mlir"
     "config_mali_conv.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
new file mode 100644
index 0000000..6cbdd91
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -0,0 +1,162 @@
+// RUN: iree-opt -split-input-file -mlir-print-local-scope -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true}))' %s | IreeFileCheck %s
+
+// Odd K that forbids vectorization.
+
+hal.executable @batch_matmul_1x3x32 {
+  hal.interface public @io {
+    hal.interface.binding public @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.variant public @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+      spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
+        max_compute_shared_memory_size = 16384 : i32,
+        max_compute_workgroup_invocations = 128 : i32,
+        max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>,
+        subgroup_size = 4 : i32}>
+    }> {
+    hal.executable.entry_point public @batch_matmul_1x3x32 attributes {interface = @io, ordinal = 0 : index}
+    builtin.module  {
+      func @batch_matmul_1x3x32() {
+        %c0 = constant 0 : index
+        %c32 = constant 32 : index
+        %c3 = constant 3 : index
+        %c1 = constant 1 : index
+        %cst = constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:1x3x3xf32>
+        %1 = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c0] : !flow.dispatch.tensor<readonly:1x3x32xf32>
+        %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:1x3x32xf32>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_size_z = hal.interface.workgroup.size[2] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %workgroup_id_z = hal.interface.workgroup.id[2] : index
+        %workgroup_count_z = hal.interface.workgroup.count[2] : index
+        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
+        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
+        scf.for %arg0 = %3 to %c1 step %4 {
+          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+          scf.for %arg1 = %5 to %c3 step %6 {
+            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+            scf.for %arg2 = %7 to %c32 step %8 {
+              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_z]
+              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_y]
+              %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [%9, %10, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x3x3xf32> -> tensor<?x?x3xf32>
+              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_z]
+              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
+              %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, %arg2], sizes = [%12, 3, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x3x32xf32> -> tensor<?x3x?xf32>
+              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_z]
+              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_y]
+              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
+              %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 1, s0)>(%arg0)[%workgroup_size_z]
+              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 3, s0)>(%arg1)[%workgroup_size_y]
+              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
+              %21 = linalg.init_tensor [%18, %19, %20] : tensor<?x?x?xf32>
+              %22 = linalg.fill(%cst, %21) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+              %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %14 : tensor<?x?x3xf32>, tensor<?x3x?xf32>) outs(%22 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+              flow.dispatch.tensor.store %23, %2, offsets = [%arg0, %arg1, %arg2], sizes = [%15, %16, %17], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x3x32xf32>
+            }
+          }
+        }
+        return
+      }
+      hal.interface private @io {
+        hal.interface.binding public @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+//          CHECK-LABEL: hal.executable.entry_point public @batch_matmul_1x3x32
+//           CHECK-SAME:   translation.info = {passPipeline = "SPIRVDistribute", workloadPerWorkgroup = [4, 1, 1]}
+//           CHECK-SAME:   workgroup_size = [4 : index, 1 : index, 1 : index]
+//           CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
+//           CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%[[X]]]
+//           CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y]], %[[Z]]
+
+//                CHECK: func @batch_matmul_1x3x32()
+//                CHECK:   linalg.batch_matmul
+//  CHECK-SAME{LITERAL}:     lowering.config = {tileSizes = [[1, 1, 4], [], [1, 1, 1]]}
+
+// -----
+
+// Non-16 / non-32 bit types cannot be vectorized right now.
+
+hal.executable private @matmul_64x16 {
+  hal.interface public @io {
+    hal.interface.binding public @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.variant public @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
+      spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
+        max_compute_shared_memory_size = 16384 : i32,
+        max_compute_workgroup_invocations = 128 : i32,
+        max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>,
+        subgroup_size = 4 : i32}>
+  }> {
+    hal.executable.entry_point public @matmul_64x16 attributes {interface = @io, ordinal = 0 : index}
+    builtin.module  {
+      func @matmul_64x16() {
+        %c0 = constant 0 : index
+        %c16 = constant 16 : index
+        %c64 = constant 64 : index
+        %c0_i32 = constant 0 : i32
+        %0 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:64x32xi8>
+        %1 = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c0] : !flow.dispatch.tensor<readonly:32x16xi8>
+        %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : !flow.dispatch.tensor<writeonly:64x16xi32>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+        scf.for %arg0 = %3 to %c64 step %4 {
+          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+          scf.for %arg1 = %5 to %c16 step %6 {
+            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg0)[%workgroup_size_y]
+            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:64x32xi8> -> tensor<?x32xi8>
+            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg1)[%workgroup_size_x]
+            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [32, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:32x16xi8> -> tensor<32x?xi8>
+            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg0)[%workgroup_size_y]
+            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg1)[%workgroup_size_x]
+            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 64, s0)>(%arg0)[%workgroup_size_y]
+            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg1)[%workgroup_size_x]
+            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xi32>
+            %16 = linalg.fill(%c0_i32, %15) : i32, tensor<?x?xi32> -> tensor<?x?xi32>
+            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x32xi8>, tensor<32x?xi8>) outs(%16 : tensor<?x?xi32>) -> tensor<?x?xi32>
+            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:64x16xi32>
+          }
+        }
+        return
+      }
+      hal.interface private @io {
+        hal.interface.binding public @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+//          CHECK-LABEL: hal.executable.entry_point public @matmul_64x16
+//           CHECK-SAME:   translation.info = {passPipeline = "SPIRVDistribute", workloadPerWorkgroup = [4, 1]}
+//           CHECK-SAME:   workgroup_size = [4 : index, 1 : index, 1 : index]
+//           CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
+//           CHECK-NEXT:   %[[ONE:.+]] = constant 1 : index
+//           CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%[[X]]]
+//           CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y]], %[[ONE]]
+
+//                CHECK: func @matmul_64x16()
+//                CHECK:   linalg.matmul
+//  CHECK-SAME{LITERAL}:     lowering.config = {tileSizes = [[1, 4], [], [1, 1]]}
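
As a quick sanity check of the workgroup-count math the CHECK lines verify:
with a workloadPerWorkgroup of 4 along X, the entry-point region computes
`s0 ceildiv 4` workgroups. A plain C++ spot check (hypothetical ceilDiv
helper, not IREE code):

  #include <cassert>
  #include <cstdint>

  // Mirrors affine_map<()[s0] -> (s0 ceildiv 4)> from the CHECK lines:
  // one workgroup per 4 output elements along X, rounding up.
  static int64_t ceilDiv(int64_t s0, int64_t d) { return (s0 + d - 1) / d; }

  int main() {
    assert(ceilDiv(32, 4) == 8);  // batch_matmul_1x3x32: X workload 32 -> 8
    assert(ceilDiv(16, 4) == 4);  // matmul_64x16: X workload 16 -> 4
    return 0;
  }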