Generalize matmul tensorcore strategy to work with arbitrary unaligned f32 tensor sizes (#13192)

After this change, we still bail on aligned cases for which some of the transforms do not yet compose.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index fa5aff19..5e1e1b2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -26,6 +26,8 @@
 }
 }
 
+// CHECK-LABEL: func @matmul
+
 // One of this matmul's dimensions is divisible by 64/64/16, we currently bail on such cases.
 // CHECK-NOT: transform.sequence
 
@@ -42,37 +44,6 @@
     func.func @matmul() {
       %c0 = arith.constant 0 : index
       %cst = arith.constant 0.000000e+00 : f32
-      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2049x2556xf32>>
-      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2556x2556xf32>>
-      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2049x2556xf32>>
-      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2049, 2556], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2049x2556xf32>> -> tensor<2049x2556xf32>
-      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2556, 2556], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2556x2556xf32>> -> tensor<2556x2556xf32>
-      %5 = tensor.empty() : tensor<2049x2556xf32>
-      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2049x2556xf32>) -> tensor<2049x2556xf32>
-      %7 = linalg.matmul ins(%3, %4 : tensor<2049x2556xf32>, tensor<2556x2556xf32>) outs(%6 : tensor<2049x2556xf32>) -> tensor<2049x2556xf32>
-      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2049, 2556], strides = [1, 1] : tensor<2049x2556xf32> -> !flow.dispatch.tensor<writeonly:tensor<2049x2556xf32>>
-      return
-    }
-  }
-}
-}
-
-// One of this matmul's dimensions is not divisible by 4, we currently bail on such cases.
-// CHECK-NOT: transform.sequence
-
-// -----
-
-hal.executable @matmul {
-hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
-  hal.executable.export public @matmul ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
-  ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
-    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
-    hal.return %x, %y, %z : index, index, index
-  }
-  builtin.module {
-    func.func @matmul() {
-      %c0 = arith.constant 0 : index
-      %cst = arith.constant 0.000000e+00 : f32
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2052x2556xf32>>
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2556x2052xf32>>
       %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2052x2052xf32>>
@@ -88,6 +59,8 @@
 }
 }
 
+// CHECK-LABEL: func @matmul
+
 // CHECK: transform.sequence  failures(propagate) {
 // CHECK: transform.iree.match_callback failures(propagate) "matmul"
 // CHECK: transform.iree.tile_to_forall_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<y>, #gpu.block<x>])
@@ -128,3 +101,53 @@
 // CHECK: transform.vector.lower_masks %{{.*}} : (!pdl.operation) -> !pdl.operation
 // CHECK: transform.vector.materialize_masks %{{.*}} : (!pdl.operation) -> !pdl.operation
 // CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, fold_memref_aliases, licm, tiling_canonicalization} : (!pdl.operation) -> ()
+
+
+// -----
+
+hal.executable @matmul {
+hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> {
+  hal.executable.export public @matmul ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
+  ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
+    %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
+    hal.return %x, %y, %z : index, index, index
+  }
+  builtin.module {
+    func.func @matmul() {
+      %c0 = arith.constant 0 : index
+      %cst = arith.constant 0.000000e+00 : f32
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2051x2555xf32>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2555x2050xf32>>
+      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2051x2050xf32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2051, 2555], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2051x2555xf32>> -> tensor<2051x2555xf32>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2555, 2050], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2555x2050xf32>> -> tensor<2555x2050xf32>
+      %5 = tensor.empty() : tensor<2051x2050xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2051x2050xf32>) -> tensor<2051x2050xf32>
+      %7 = linalg.matmul ins(%3, %4 : tensor<2051x2555xf32>, tensor<2555x2050xf32>) outs(%6 : tensor<2051x2050xf32>) -> tensor<2051x2050xf32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2051, 2050], strides = [1, 1] : tensor<2051x2050xf32> -> !flow.dispatch.tensor<writeonly:tensor<2051x2050xf32>>
+      return
+    }
+  }
+}
+}
+
+// CHECK-LABEL: func @matmul
+
+// CHECK: transform.sequence  failures(propagate) {
+// CHECK: transform.iree.match_callback failures(propagate) "matmul"
+// CHECK: transform.iree.tile_to_forall_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<y>, #gpu.block<x>])
+// CHECK: transform.structured.tile %{{.*}}[0, 0, 16]
+// align1
+// CHECK: transform.structured.tile_to_forall_op %{{.*}}   num_threads [8, 16] tile_sizes [](mapping = [#gpu.linear<x>, #gpu.linear<y>])
+// align2
+// CHECK: transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 64] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
+// align2
+// CHECK: transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 64] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
+// CHECK: transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
+// CHECK: transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
+// align1
+// CHECK: transform.structured.masked_vectorize %{{.*}} vector_sizes [16, 1]
+// align2
+// CHECK: transform.structured.masked_vectorize %{{.*}} vector_sizes [8, 2]
+// align2
+// CHECK: transform.structured.masked_vectorize %{{.*}} vector_sizes [64, 2]
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
index 1319685..eebefa8 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
@@ -504,9 +504,7 @@
 
   bool alignedAny64x64x16 = matmulSize[0] % 64 == 0 ||
                             matmulSize[1] % 64 == 0 || matmulSize[2] % 16 == 0;
-  bool alignedAll4x4x4 = matmulSize[0] % 4 == 0 && matmulSize[1] % 4 == 0 &&
-                         matmulSize[2] % 4 == 0;
-  if (alignedAny64x64x16 || !alignedAll4x4x4) {
+  if (alignedAny64x64x16) {
     LLVM_DEBUG(DBGS() << "alignment check failed\n");
     return failure();
   }