[Codegen][ROCDL] Extend mfma pipeline to support a few more matmul variants (#16582)

Similar to convolution, we simply tile all M, N, and K dimensions to size 1
except the innermost one of each group. This is in preparation for some
fusion experiments.
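
As an illustration (a standalone C++ sketch, not IREE code; the struct and
variable names are made up), the snippet below mimics the new tile-size
assignment for the expanded_matmul_transpose_b test added in this change:
m = {d0, d2}, n = {d1, d3}, k = {d4}, with the F16_16x16x16_F32 schedule from
the CHECK lines (subgroup_m_count = 1, subgroup_n_count = 4,
subgroup_m_tile_count = 4, subgroup_n_tile_count = 1,
subgroup_k_tile_count = 2). The N and K formulas are assumed by analogy with
the M formula visible in the diff.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors only the schedule fields used here.
    struct Schedule {
      int64_t mWarpCount, mTileCount, mSize;
      int64_t nWarpCount, nTileCount, nSize;
      int64_t kTileCount, kSize;
    };

    int main() {
      // Loop groups of the test's linalg.generic (d0..d4).
      std::vector<int64_t> mDims = {0, 2}, nDims = {1, 3}, kDims = {4};
      // F16_16x16x16_F32 schedule taken from the test's CHECK lines.
      Schedule s{/*m*/ 1, 4, 16, /*n*/ 4, 1, 16, /*k*/ 2, 16};

      std::vector<int64_t> tiles(5, 0);
      // Tile every dimension except the innermost of each group to 1
      // (the llvm::drop_end loops in the diff).
      for (size_t i = 0; i + 1 < mDims.size(); ++i) tiles[mDims[i]] = 1;
      for (size_t i = 0; i + 1 < nDims.size(); ++i) tiles[nDims[i]] = 1;
      for (size_t i = 0; i + 1 < kDims.size(); ++i) tiles[kDims[i]] = 1;
      // Innermost M tile as in the diff; N and K assumed analogous.
      tiles[mDims.back()] = s.mWarpCount * s.mTileCount * s.mSize;  // 64
      tiles[nDims.back()] = s.nWarpCount * s.nTileCount * s.nSize;  // 64
      tiles[kDims.back()] = s.kTileCount * s.kSize;                 // 32

      for (int64_t t : tiles) std::cout << t << " ";  // 1 1 64 64 32
      std::cout << "\n";
      return 0;
    }

Running this prints 1 1 64 64 32, matching the lowering_config tile sizes
checked in config_vector_distribute.mlir below.
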
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 48fa797..066fa32 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -402,8 +402,7 @@
     workgroupTileSizes[batch] = 1;
   }
   // Tile all output image dimensions with unit size except the last one.
-  for (int64_t oi :
-       ArrayRef<unsigned int>(convolutionDims->outputImage).drop_back()) {
+  for (int64_t oi : llvm::drop_end(convolutionDims->outputImage)) {
     workgroupTileSizes[oi] = 1;
   }
   // Compute the M/N dimension tile size by multiply subgroup information.
@@ -463,15 +462,23 @@
       mlir::linalg::inferContractionDims(op);
   assert(succeeded(contractionDims) && "Could not infer contraction dims");
 
-  // TODO: Relax this condition to strictly alignment requirements.
-  if (contractionDims->k.size() != 1 || contractionDims->m.size() != 1 ||
-      contractionDims->n.size() != 1) {
+  if (contractionDims->k.size() < 1 || contractionDims->m.size() < 1 ||
+      contractionDims->n.size() < 1) {
     return failure();
   }
 
-  int64_t mDim = contractionDims->m[0];
-  int64_t nDim = contractionDims->n[0];
-  int64_t kDim = contractionDims->k[0];
+  // For now we are not trying to reshape dimensions for better intrinsic
+  // usage; instead we tile all dimensions except the innermost m, n, and k
+  // dimensions to 1.
+  int64_t mDim = contractionDims->m.back();
+  int64_t nDim = contractionDims->n.back();
+
+  // Bail out on matvec-like cases.
+  if (bounds[mDim] == 1 || bounds[nDim] == 1) {
+    return failure();
+  }
+
+  int64_t kDim = contractionDims->k.back();
 
   Value lhs = op.getDpsInputOperand(0)->get();
   Value rhs = op.getDpsInputOperand(1)->get();
@@ -520,6 +527,19 @@
   for (int64_t batch : contractionDims->batch) {
     workgroupTileSizes[batch] = 1;
   }
+
+  // Tile all m, n, and k dimensions to 1 except the innermost. Unit dims
+  // from this tiling are folded before vectorization.
+  for (int64_t m : llvm::drop_end(contractionDims->m)) {
+    workgroupTileSizes[m] = 1;
+  }
+  for (int64_t n : llvm::drop_end(contractionDims->n)) {
+    workgroupTileSizes[n] = 1;
+  }
+  for (int64_t k : llvm::drop_end(contractionDims->k)) {
+    workgroupTileSizes[k] = 1;
+  }
+
   // Compute the M/N dimension tile size by multiply subgroup information.
   workgroupTileSizes[mDim] =
       schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
@@ -555,7 +575,7 @@
 setVectorDistributionConfig(mlir::FunctionOpInterface entryPoint,
                             linalg::LinalgOp linalgOp,
                             const TargetInfo &targetInfo) {
-  if (isMatmulOrBatchMatmul(linalgOp)) {
+  if (linalg::isaContractionOpInterface(linalgOp)) {
     return setMatmulVectorDistributionConfig(entryPoint, linalgOp, targetInfo);
   }
   if (isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index 169e7c1..ae5d3bb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -5,6 +5,74 @@
 // to be migrated to the rocdl heuristics, but for now is just physically
 // located here.
 
+// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[1, 1, 64, 64, 32]{{\]}}
+// CHECK:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
+// CHECK-SAME:   intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>
+// CHECK-SAME:   subgroup_m_count = 1, subgroup_n_count = 4,
+// CHECK-SAME:   subgroup_m_tile_count = 4, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @expanded_matmul_transpose_b_executable {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
+      target_arch = "gfx940",
+      mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
+                        #iree_gpu.mfma_layout<F16_32x32x8_F32>]
+  }>) {
+  hal.executable.export @expanded_matmul_transpose_b layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @expanded_matmul_transpose_b() {
+        %c0 = arith.constant 0 : index
+        %cst = arith.constant 0.000000e+00 : f16
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+          : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+          : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+          : !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 2048], strides = [1, 1, 1]
+          : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>> -> tensor<2x64x2048xf16>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 64, 2048], strides = [1, 1, 1]
+          : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>> -> tensor<10x64x2048xf16>
+
+        %5 = tensor.empty() : tensor<2x10x64x64xf16>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+        %7 = linalg.generic {
+          indexing_maps = [
+            affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+            affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+            affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+          ],
+          iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+        } ins(%3, %4 : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
+        ^bb0(%lhs: f16, %rhs: f16, %out: f16):
+          %mul = arith.mulf %lhs, %rhs : f16
+          %add = arith.addf %mul, %out : f16
+          linalg.yield %add : f16
+        } -> tensor<2x10x64x64xf16>
+
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 64, 64], strides = [1, 1, 1, 1]
+          : tensor<2x10x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: hal.executable public @expanded_matmul_transpose_b
+// CHECK: linalg.generic {{.*}}lowering_config = #[[$TILE_SIZES]]
+
+// -----
+
 // CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[1, 1, 64, 128, 1, 1, 32]{{\]}}
 // CHECK:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
 // CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index 31bdc44..ddcc660 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -111,7 +111,7 @@
 //          CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x1x4xf16>)
 //          CHECK:     arith.extf %[[ARG]] : vector<2x4x1x1x1x4xf16> to vector<2x4x1x1x1x4xf32>
 // CHECK-COUNT-16:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-//          CHECK:     %[[TRUNC:.+]] = arith.truncf %157 : vector<2x4x1x1x1x4xf32> to vector<2x4x1x1x1x4xf16>
+//          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<2x4x1x1x1x4xf32> to vector<2x4x1x1x1x4xf16>
 //          CHECK:     scf.yield %[[TRUNC]] : vector<2x4x1x1x1x4xf16>
 //  CHECK-COUNT-8:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
 
@@ -123,6 +123,75 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
+hal.executable @expanded_matmul_transpose_b_executable {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
+      target_arch = "gfx940",
+      mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
+                        #iree_gpu.mfma_layout<F16_32x32x8_F32>]
+  }>) {
+  hal.executable.export @expanded_matmul_transpose_b layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @expanded_matmul_transpose_b() {
+        %c0 = arith.constant 0 : index
+        %cst = arith.constant 0.000000e+00 : f16
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
+          : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
+          : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
+          : !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 2048], strides = [1, 1, 1]
+          : !flow.dispatch.tensor<readonly:tensor<2x64x2048xf16>> -> tensor<2x64x2048xf16>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 64, 2048], strides = [1, 1, 1]
+          : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>> -> tensor<10x64x2048xf16>
+
+        %5 = tensor.empty() : tensor<2x10x64x64xf16>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+        %7 = linalg.generic {
+          indexing_maps = [
+            affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+            affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+            affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+          ],
+          iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+        } ins(%3, %4 : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
+        ^bb0(%lhs: f16, %rhs: f16, %out: f16):
+          %mul = arith.mulf %lhs, %rhs : f16
+          %add = arith.addf %mul, %out : f16
+          linalg.yield %add : f16
+        } -> tensor<2x10x64x64xf16>
+
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 64, 64], strides = [1, 1, 1, 1]
+          : tensor<2x10x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x64x64xf16>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: hal.executable.export public @expanded_matmul_transpose_b
+//  CHECK-SAME:    workgroup_size = [256 : index, 1 : index, 1 : index]
+
+//    CHECK-LABEL: func @expanded_matmul_transpose_b
+//          CHECK:   scf.for {{.*}} = %c0 to %c2048 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x1x4xf16>)
+//          CHECK:     arith.extf %[[ARG]] : vector<4x1x1x1x1x4xf16> to vector<4x1x1x1x1x4xf32>
+// CHECK-COUNT-8:      amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+//          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x1x4xf32> to vector<4x1x1x1x1x4xf16>
+//          CHECK:     scf.yield %[[TRUNC]] : vector<4x1x1x1x1x4xf16>
+//  CHECK-COUNT-4:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<2x10x64x64xf16, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
 hal.executable @conv_nhwc_dispatch_0 {
 hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
       target_arch = "gfx940",