[Codegen][GPU] Add iree_gpu.multi_mma op to PartitionableLoopsInterface (#18653)

This PR registers the `iree_gpu.multi_mma` op with the
`OuterParallelAsPartitionableLoops` external model of
`PartitionableLoopsInterface`. This enables workgroup tiling and
distribution for GPU data-tiling codegen, since
`TileAndDistributeToWorkgroupsPass` drives distribution through this
interface.
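
For context, a hedged sketch of how a distribution pass can query the
interface once this registration is in place. The method name
`getPartitionableLoops` follows the declaration in
`PartitionableLoopsInterface.h`; the local cap constant and the helper name
are illustrative only, so verify against the current headers.

```cpp
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"

namespace mlir::iree_compiler {

// Sketch: collect the loop dimensions of an op (e.g. iree_gpu.multi_mma)
// that may be distributed to workgroups. With the
// OuterParallelAsPartitionableLoops model attached, this yields the
// outermost parallel iteration dimensions of the op.
static llvm::SmallVector<unsigned> getLoopsToDistribute(Operation *op) {
  auto partitionableOp = dyn_cast<PartitionableLoopsInterface>(op);
  if (!partitionableOp)
    return {};
  // Cap on distributed dimensions; IREE conventionally maps at most three
  // loops to the workgroup grid. Local constant here for illustration.
  constexpr unsigned kMaxWorkgroupDims = 3;
  return partitionableOp.getPartitionableLoops(kMaxWorkgroupDims);
}

} // namespace mlir::iree_compiler
```

Without this registration the `dyn_cast` above fails for `multi_mma`, so
the op is never tiled or distributed by the workgroup pass.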

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
index 5c11c61..a99c8d4 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
@@ -160,6 +160,7 @@
     ],
     deps = [
         ":PartitionableLoopsInterfaceGen",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:DialectUtils",
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index 53e91f5..7a842a5 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -127,6 +127,7 @@
     MLIRLinalgStructuredOpsIncGenLib
     MLIRSupport
     MLIRTensorDialect
+    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
     iree::compiler::Dialect::LinalgExt::IR
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
index b695e1e..94be437 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
@@ -6,6 +6,8 @@
 
 #include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
 
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/ADT/SmallVector.h"
@@ -268,6 +270,11 @@
     tensor::UnPackOp::attachInterface<
         OuterParallelAsPartitionableLoops<tensor::UnPackOp>>(*ctx);
   });
+  registry.addExtension(
+      +[](MLIRContext *ctx, IREE::GPU::IREEGPUDialect *dialect) {
+        IREE::GPU::MultiMmaOp::attachInterface<
+            OuterParallelAsPartitionableLoops<IREE::GPU::MultiMmaOp>>(*ctx);
+      });
 }
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 4b44314..9945e2d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -787,12 +787,12 @@
       func.func @multi_mma_data_tiled_unrolled_MFMA_F32_16x16x4_F32()
         attributes {translation_info = #translation_info} {
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x8x4x16x4xf32>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x4x2x4x16x4xf32>>
-        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<1x1x8x4x2x4x16x4xf32>>
-        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 8, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x8x4x16x4xf32>> -> tensor<1x1x8x4x16x4xf32>
-        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0, 0, 0, 0], sizes = [1, 1, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4x2x4x16x4xf32>> -> tensor<1x1x4x2x4x16x4xf32>
-        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, 1, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x1x8x4x2x4x16x4xf32>> -> tensor<1x1x8x4x2x4x16x4xf32>
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x1x8x4x16x4xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x1x4x2x4x16x4xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0, 0], sizes = [4, 1, 8, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x1x8x4x16x4xf32>> -> tensor<4x1x8x4x16x4xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0, 0, 0, 0], sizes = [4, 1, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x1x4x2x4x16x4xf32>> -> tensor<4x1x4x2x4x16x4xf32>
+        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [4, 4, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>> -> tensor<4x4x8x4x2x4x16x4xf32>
         %6 = iree_gpu.multi_mma %3, %4, %5 {
           lowering_config = #config,
           indexing_maps = [
@@ -809,8 +809,8 @@
             unroll_n = 2,
             unroll_n_to_subgroups = 4,
             unroll_k = 4>}
-          : tensor<1x1x8x4x16x4xf32>, tensor<1x1x4x2x4x16x4xf32> into tensor<1x1x8x4x2x4x16x4xf32>
-        flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, 1, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<1x1x8x4x2x4x16x4xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x1x8x4x2x4x16x4xf32>>
+          : tensor<4x1x8x4x16x4xf32>, tensor<4x1x4x2x4x16x4xf32> into tensor<4x4x8x4x2x4x16x4xf32>
+        flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [4, 4, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<4x4x8x4x2x4x16x4xf32> -> !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>>
         return
       }
     }