More IREEGPUAttrs.cpp cleanups (#19142)

Two things in this PR:
1. Make a big switch statement more concise.
2. Currently, `DataTileMMAAttr::buildMmaOperation` creates a `MMAAttr`
just to call `buildMmaOperation` on it, to reuse that implementation. In
addition to being roundabout, this required a comment explaining why we
discarded the error status, as `MMAAttr::buildMmaOperation` is fallible
but here we were already past validation and mutating IR. This PR
refactors that to let both call a shared, infallible helper.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index e011b78..0a27437 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -220,88 +220,45 @@
   Type i8 = IntegerType::get(context, 8);
   Type i32 = IntegerType::get(context, 32);
   switch (intrinsic) {
-  case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
+  case MMAIntrinsic::MFMA_F64_16x16x4_F64:
     return {f64, f64, f64};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
     return {f32, f32, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
     return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
-    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
-    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
-    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
-    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
-    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
-    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
-    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
-    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
     return {f16, f16, f16};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
+  case MMAIntrinsic::MFMA_F32_16x16x8_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
     return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
+  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
     return {bf16, bf16, bf16};
-  }
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
+    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
+    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
+    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
+    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
+  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
+  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x16_I8:
+  case MMAIntrinsic::WMMA_I32_16x16x16_I8:
     return {i8, i8, i32};
   }
-  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  }
   assert(false && "unexpected enum value");
   return {};
 }
@@ -498,11 +455,15 @@
   return IREE::GPU::getContractionLayout(contract, layout);
 }
 
-int64_t MMAAttr::getBlockSize() const {
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
   // Not supporting any block size other than 1 at the moment.
   return 1;
 }
 
+int64_t MMAAttr::getBlockSize() const {
+  return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
 static uint32_t getArchID(MMAIntrinsic intrinsic) {
   return static_cast<int>(intrinsic) & 0xFF00;
 }
@@ -704,6 +665,31 @@
   }
 }
 
+static Value createMmaOp(OpBuilder &builder, Location loc,
+                         MMAIntrinsic intrinsic, Type resultType, Value lhs,
+                         Value rhs, Value acc) {
+  auto getVecOrSingleElem = [&](Value vec) -> Value {
+    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
+    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
+  };
+  auto layout = getOpaqueMMALayout(builder.getContext(), intrinsic);
+  if (is_AMD_MFMA(intrinsic)) {
+    // MFMA intrinsics want single-element operands of element type, not vector.
+    lhs = getVecOrSingleElem(lhs);
+    rhs = getVecOrSingleElem(rhs);
+    return builder
+        .create<amdgpu::MFMAOp>(loc, resultType, layout.mSize, layout.nSize,
+                                layout.kSize, getBlockSize(intrinsic), lhs, rhs,
+                                acc)
+        .getResult();
+  }
+  if (is_AMD_WMMA(intrinsic)) {
+    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
+        .getResult();
+  }
+  return {};
+}
+
 // Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
 // type.
 FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -719,23 +705,9 @@
   if (cType != resultType) {
     return failure();
   }
-  auto getVecOrSingleElem = [&](Value vec) -> Value {
-    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
-    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
-  };
-  MMAIntrinsic intrinsic = getIntrinsic().getValue();
-  if (is_AMD_MFMA(intrinsic)) {
-    // MFMA intrinsics want single-element operands of element type, not vector.
-    lhs = getVecOrSingleElem(lhs);
-    rhs = getVecOrSingleElem(rhs);
-    auto [m, n, k] = getMNKShape();
-    return builder
-        .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
-                                rhs, acc)
-        .getResult();
-  } else if (is_AMD_WMMA(intrinsic)) {
-    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
-        .getResult();
+  if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(),
+                                resultType, lhs, rhs, acc)) {
+    return value;
   }
   return failure();
 }
@@ -1168,23 +1140,18 @@
   SmallVector<Value> intrinsicsAcc =
       distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);
 
-  // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
-  // to create the target intrinsics.
-  auto intrinsicMma = MMAAttr::get(getContext(), getIntrinsic().getValue());
-  auto [intrinsicAType, intrinsicBType, intrinsicCType] =
-      intrinsicMma.getABCVectorTypes();
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  VectorType intrinCType =
+      getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc);
 
   // Loop over the 3 unroll_{m,n,k} dimensions to create the intrinsics.
   for (int mu = 0; mu < getUnrollM(); ++mu) {
     for (int nu = 0; nu < getUnrollN(); ++nu) {
       for (int ku = 0; ku < getUnrollK(); ++ku) {
-        // Assume intrinsicMma.buildMmaOperation() success: validation should be
-        // completed prior to mutating IR.
         Value lhs = intrinsicsLhs[mu * getUnrollK() + ku];
         Value rhs = intrinsicsRhs[nu * getUnrollK() + ku];
         Value &acc = intrinsicsAcc[mu * getUnrollN() + nu];
-        acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs,
-                                              rhs, acc);
+        acc = createMmaOp(builder, loc, intrinsic, intrinCType, lhs, rhs, acc);
       }
     }
   }