[GPU] Remove MMAScheduleAttr (#21884)

This attribute is deprecated and has no remaining users: remove its TableGen definition, port the last internal consumer (LLVMGPUConfigureTensorLayouts) to a plain MMASchedule struct, and strip the leftover mma_schedule entries from test IR.
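
In IR, the change is that translation_info no longer carries a nested
mma_schedule entry. For illustration, the before/after shape of the
attribute, adapted from the gpu_tensor_alloc.mlir test updated below:

  // Before:
  #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
      workgroup_size = [256, 1, 1] subgroup_size = 64,
      {mma_schedule = #iree_gpu.mma_schedule<
          intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
          subgroup_m_count = 1, subgroup_n_count = 4>}>

  // After:
  #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
      workgroup_size = [256, 1, 1] subgroup_size = 64>
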
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
index 40aa600..2f9c623 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
@@ -230,7 +230,7 @@
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-func.func @conv() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+func.func @conv() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64>} {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 857482b..1a73ac7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -2567,7 +2567,7 @@
       hal.return %x, %y, %z : index, index, index
     }
     builtin.module {
-      func.func @set_size_to_tilesize_when_divisible() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 32, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMAR3_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+      func.func @set_size_to_tilesize_when_divisible() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 32>} {
         %c0 = arith.constant 0 : index
         %c32_i64 = arith.constant 32 : i64
         %cst = arith.constant 0.000000e+00 : f16
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index 7357997..d8e5875 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -415,28 +415,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// MMA schedule
-//===----------------------------------------------------------------------===//
-
-def IREEGPU_MmaScheduleAttr : AttrDef<IREEGPU_Dialect, "MMASchedule"> {
-  let mnemonic = "mma_schedule";
-  let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
-
-  string description = [{
-    A schedule of MMA intrinsic instruction and various levels of tile sizes
-    to solve a specific contraction problem.
-  }];
-
-  let parameters = (ins
-    "::mlir::iree_compiler::IREE::Codegen::InnerTileDescAttrInterface":$intrinsic,
-    "int64_t":$subgroup_m_count,
-    "int64_t":$subgroup_n_count
-  );
-
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
-//===----------------------------------------------------------------------===//
 // iree_gpu.gpu_encoding_resolver
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index da94801..c87c495 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -199,23 +199,24 @@
   return bounds;
 }
 
-static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+struct MMASchedule {
+  IREE::Codegen::InnerTileDescAttrInterface intrinsic;
+  int64_t subgroupMCount;
+  int64_t subgroupNCount;
+};
+
+static LogicalResult setContractionAnchor(MMASchedule schedule,
                                           SmallVector<bool> promotedOperands,
                                           RewriterBase &rewriter,
                                           linalg::LinalgOp contract) {
-  // TODO: Add SIMT fallback.
-  if (!schedule) {
-    return contract->emitError("missing mma schedule for contraction");
-  }
-
   // This function should only be called on a contraction op.
   assert(linalg::isaContractionOpInterface(contract) &&
          "cannot set contraction anchor on non contraction op");
 
   SmallVector<int64_t> bounds = getIterationSpaceBounds(contract);
   auto layouts = getContractionLayout(
-      schedule.getIntrinsic(), schedule.getSubgroupMCount(),
-      schedule.getSubgroupNCount(), bounds, contract.getIndexingMapsArray());
+      schedule.intrinsic, schedule.subgroupMCount, schedule.subgroupNCount,
+      bounds, contract.getIndexingMapsArray());
   if (failed(layouts)) {
     return contract->emitError("cannot get concrete layout for contraction");
   }
@@ -230,11 +231,11 @@
   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(contract);
   auto layoutedLhs =
-      ToLayoutOp::create(rewriter, loc, lhs, aLayout, schedule.getIntrinsic());
+      ToLayoutOp::create(rewriter, loc, lhs, aLayout, schedule.intrinsic);
   auto layoutedRhs =
-      ToLayoutOp::create(rewriter, loc, rhs, bLayout, schedule.getIntrinsic());
+      ToLayoutOp::create(rewriter, loc, rhs, bLayout, schedule.intrinsic);
   auto layoutedAcc =
-      ToLayoutOp::create(rewriter, loc, acc, cLayout, schedule.getIntrinsic());
+      ToLayoutOp::create(rewriter, loc, acc, cLayout, schedule.intrinsic);
 
   // Promote matmul lhs and rhs.
   // TODO: This is a hack until layout analysis is improved. The layout analysis
@@ -257,23 +258,18 @@
 
   // Set layout for result.
   rewriter.setInsertionPointAfter(contract);
-  auto toLayout = ToLayoutOp::create(rewriter, loc, contract->getResult(0),
-                                     cLayout, schedule.getIntrinsic());
+  auto toLayout = ToLayoutOp::create(rewriter, loc, contract->getResult(0),
+                                     cLayout, schedule.intrinsic);
   rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
                                 toLayout);
 
   return success();
 }
 
-static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
+static LogicalResult setConvolutionAnchor(MMASchedule schedule,
                                           SmallVector<bool> promotedOperands,
                                           RewriterBase &rewriter,
                                           linalg::LinalgOp conv) {
-  // TODO: Add SIMT fallback.
-  if (!schedule) {
-    return conv->emitError("missing mma schedule for convolution");
-  }
-
   // This function should only be called on a convolution op.
   FailureOr<linalg::ConvolutionDimensions> convDims =
       linalg::inferConvolutionDims(conv);
@@ -298,9 +294,12 @@
   }
 
   SmallVector<int64_t> bounds = getIterationSpaceBounds(conv);
-  auto layouts = getContractionLayout(
-      schedule.getIntrinsic(), schedule.getSubgroupMCount(),
-      schedule.getSubgroupNCount(), bounds, maps);
+  FailureOr<ContractionLayout> layouts =
+      getContractionLayout(schedule.intrinsic, schedule.subgroupMCount,
+                           schedule.subgroupNCount, bounds, maps);
+  if (failed(layouts)) {
+    return conv->emitError("cannot get concrete layout for convolution");
+  }
 
   auto [aLayout, bLayout, cLayout] = *layouts;
   Location loc = conv.getLoc();
@@ -312,11 +311,11 @@
   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(conv);
   auto layoutedLhs =
-      ToLayoutOp::create(rewriter, loc, lhs, aLayout, schedule.getIntrinsic());
+      ToLayoutOp::create(rewriter, loc, lhs, aLayout, schedule.intrinsic);
   auto layoutedRhs =
-      ToLayoutOp::create(rewriter, loc, rhs, bLayout, schedule.getIntrinsic());
+      ToLayoutOp::create(rewriter, loc, rhs, bLayout, schedule.intrinsic);
   auto layoutedAcc =
-      ToLayoutOp::create(rewriter, loc, acc, cLayout, schedule.getIntrinsic());
+      ToLayoutOp::create(rewriter, loc, acc, cLayout, schedule.intrinsic);
 
   // Promote matmul lhs and rhs.
   // TODO: This is a hack until layout analysis is improved. The layout analysis
@@ -339,8 +338,8 @@
 
   // Set layout for result.
   rewriter.setInsertionPointAfter(conv);
-  auto toLayout = ToLayoutOp::create(rewriter, loc, conv->getResult(0), cLayout,
-                                     schedule.getIntrinsic());
+  auto toLayout = ToLayoutOp::create(rewriter, loc, conv->getResult(0), cLayout,
+                                     schedule.intrinsic);
   rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
                                 toLayout);
 
@@ -401,26 +400,14 @@
   contractOp->setOperand(1, lhs);
 }
 
-static IREE::GPU::MMAScheduleAttr
-transposeSchedule(RewriterBase &rewriter, IREE::GPU::MMAScheduleAttr schedule) {
-  return rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
-      schedule.getIntrinsic(), schedule.getSubgroupNCount(),
-      schedule.getSubgroupMCount());
-}
-
 static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter,
                                               linalg::LinalgOp qkMatmul,
                                               linalg::LinalgOp pvMatmul) {
 
-  IREE::GPU::MMAScheduleAttr qkSchedule =
-      rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(qkMatmul),
-                                                   getSubgroupMCount(qkMatmul),
-                                                   getSubgroupNCount(qkMatmul));
-
-  IREE::GPU::MMAScheduleAttr pvSchedule =
-      rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(pvMatmul),
-                                                   getSubgroupMCount(pvMatmul),
-                                                   getSubgroupNCount(pvMatmul));
+  MMASchedule qkSchedule = {getIntrinsic(qkMatmul), getSubgroupMCount(qkMatmul),
+                            getSubgroupNCount(qkMatmul)};
+  MMASchedule pvSchedule = {getIntrinsic(pvMatmul), getSubgroupMCount(pvMatmul),
+                            getSubgroupNCount(pvMatmul)};
 
   // Check if the intrinsic output for qkMatmul can be reused for pvMatmul.
   // We know that pvMatmul takes the result of qkMatmul as its lhs.
@@ -429,10 +416,10 @@
   bool reuseIntrinsicOutput = false;
   bool transposeIntrinsic = false;
 
-  auto qkIntrinsic = cast<IREE::Codegen::InnerTileDescAttrInterface>(
-      qkSchedule.getIntrinsic());
-  auto pvIntrinsic = cast<IREE::Codegen::InnerTileDescAttrInterface>(
-      pvSchedule.getIntrinsic());
+  auto qkIntrinsic =
+      cast<IREE::Codegen::InnerTileDescAttrInterface>(qkSchedule.intrinsic);
+  auto pvIntrinsic =
+      cast<IREE::Codegen::InnerTileDescAttrInterface>(pvSchedule.intrinsic);
   IREE::GPU::MMASingleSubgroupLayout lhsLayout =
       getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Lhs);
   IREE::GPU::MMASingleSubgroupLayout rhsLayout =
@@ -475,8 +462,8 @@
     }
     swapOperandsToTransposeIntrinsic(rewriter, qkGeneric);
     swapOperandsToTransposeIntrinsic(rewriter, pvGeneric);
-    qkSchedule = transposeSchedule(rewriter, qkSchedule);
-    pvSchedule = transposeSchedule(rewriter, pvSchedule);
+    std::swap(qkSchedule.subgroupMCount, qkSchedule.subgroupNCount);
+    std::swap(pvSchedule.subgroupMCount, pvSchedule.subgroupNCount);
 
     // Swap promoted operands.
     std::swap(promotedQKOperands[0], promotedQKOperands[1]);
@@ -581,9 +568,8 @@
     ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
 
   SmallVector<bool> promotedOperands = getPromotedOperands(candidate);
-  auto schedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
-      getIntrinsic(candidate), getSubgroupMCount(candidate),
-      getSubgroupNCount(candidate));
+  MMASchedule schedule = {getIntrinsic(candidate), getSubgroupMCount(candidate),
+                          getSubgroupNCount(candidate)};
 
   if (linalg::isaContractionOpInterface(candidate)) {
     if (succeeded(setContractionAnchor(schedule, promotedOperands, rewriter,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index ea10c44..1683f88 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -1,10 +1,6 @@
 // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-cast-type-to-fit-mma))' -mlir-print-local-scope %s | FileCheck %s
 
-func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
-      subgroup_m_count = 1, subgroup_n_count = 1>,
-    workgroup_size = [64, 1, 1]} {
+func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -28,11 +24,7 @@
 
 // -----
 
-func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
-      subgroup_m_count = 1, subgroup_n_count = 1>,
-    workgroup_size = [64, 1, 1]} {
+func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -53,11 +45,7 @@
 
 // -----
 
-func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
-      subgroup_m_count = 1, subgroup_n_count = 1>,
-    workgroup_size = [64, 1, 1]} {
+func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -77,11 +65,7 @@
 
 // -----
 
-func.func @wmmar3_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf16>) -> vector<48x32xf16> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<WMMAR3_F32_16x16x16_F16>,
-      subgroup_m_count = 1, subgroup_n_count = 1>,
-    workgroup_size = [32, 1, 1]} {
+func.func @wmmar3_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf16>, %init: vector<48x32xf16>) -> vector<48x32xf16> {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -109,11 +93,7 @@
 // "iree.amdgpu.mma" will be generated from the "intrinsic" attribute of to_layout.
 // This also shows that we can override the default intrinsics if explicitly set.
 
-func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
-    mma_schedule = #iree_gpu.mma_schedule<
-      intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
-      subgroup_m_count = 1, subgroup_n_count = 1>,
-    workgroup_size = [64, 1, 1]} {
+func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
@@ -141,7 +121,7 @@
 // it will not have mma_schedule on function attributes, but instead it will have
 // "iree.amdgpu.mma" attribute directly on vector.contract.
 
-func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {translation_info = #iree_codegen.translation_info<pipeline = None workgroup_size = [64, 1, 1] subgroup_size = 64>} {
+func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> {
     %0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index 074664c..62e48ef 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -152,10 +152,7 @@
 
 #translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
                                               workgroup_size = [64, 1, 1]
-                                              subgroup_size = 64,
-      {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-                                             subgroup_m_count = 4,
-                                             subgroup_n_count = 1>}>
+                                              subgroup_size = 64>
 
 #maps = [
   affine_map<(m, n, k) -> (m, k)>,