[rocdl] Fix mfma accumulator base vector type (#16549)

The underlying rocdl op actually accepts 1-D vectors. We are shape
casting anyway, so just cast to the final 1-D vector shape.

This fixes crash when converting from rocdl dialect to LLVM proper, and
allows us to add some more tests to matmul suite.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index e3f963c..2d982c7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -81,11 +81,11 @@
 //       CHECK:   %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
 //       CHECK:   %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
 //       CHECK:   %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x1x4xf32> to vector<4x4xf32>
+//       CHECK:   %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x1x4xf32> to vector<16xf32>
 //       CHECK:   %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
 //  CHECK-SAME:     {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none
-//  CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<4x4xf32>
-//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+//  CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<16xf32>
+//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<16xf32> to vector<4x1x1x4xf32>
 //       CHECK:   %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
 //       CHECK:   %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
 //       CHECK:   return {{.*}} %[[R_SIMD]]
@@ -249,16 +249,16 @@
 //       CHECK:   %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
 //       CHECK:   %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
 //       CHECK:   %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x1x4xf32> to vector<4x4xf32>
+//       CHECK:   %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x1x4xf32> to vector<16xf32>
 //       CHECK:   %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
-//       CHECK:   %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+//       CHECK:   %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x1x4xf32>
 //       CHECK:   %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
 //       CHECK:   %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
 //       CHECK:   %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
 //       CHECK:   %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
-//       CHECK:   %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x1x4xf32> to vector<4x4xf32>
+//       CHECK:   %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x1x4xf32> to vector<16xf32>
 //       CHECK:   %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
-//       CHECK:   %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+//       CHECK:   %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
 //       CHECK:   %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
 //       CHECK:   %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x1x4xf32> -> vector<64x32xf32>
 //       CHECK:   return {{.*}}} %[[R]]
@@ -348,7 +348,7 @@
 //       CHECK:   %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
 //       CHECK:   %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
 //       CHECK:   %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]]
-//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+//       CHECK:   %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
 //       CHECK:   %[[INSERT:.+]] = vector.insert %[[R_CAST]], %{{.+}} [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
 //       CHECK:   %[[R:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
 //       CHECK:   return {{.*}} %[[R]] : vector<32x32xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 9dc6abe..d01cfa2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -305,6 +305,10 @@
 
 std::tuple<VectorType, VectorType, VectorType>
 MFMAAttr::getABCVectorTypes() const {
+  // Check https://github.com/ROCm/amd_matrix_instruction_calculator for
+  // instruction details. Note here we are returning the number elements, while
+  // amd_matrix_instruction_calculator tells us about the number of 32-bit
+  // registers. So need to adjust accordingly. All vectors should be 1-D.
   switch (getIntrinsic().getValue()) {
   case MFMAIntrinsic::F16_16x16x16_F32: {
     auto aType = VectorType::get({4}, getAType());
@@ -315,7 +319,7 @@
   case MFMAIntrinsic::F16_32x32x8_F32: {
     auto aType = VectorType::get({4}, getAType());
     auto bType = VectorType::get({4}, getBType());
-    auto cType = VectorType::get({4, 4}, getCType());
+    auto cType = VectorType::get({16}, getCType());
     return std::make_tuple(aType, bType, cType);
   }
   }
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 151ca42..e295a3d 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -258,6 +258,10 @@
         MMASchedule("F16_16x16x16_F32", 2, 2, 1, 1, 1),
         MMASchedule("F16_16x16x16_F32", 2, 4, 2, 1, 2),
         MMASchedule("F16_16x16x16_F32", 4, 2, 4, 2, 2),
+        MMASchedule("F16_32x32x8_F32", 1, 1, 1, 2, 2),
+        MMASchedule("F16_32x32x8_F32", 2, 2, 1, 1, 1),
+        MMASchedule("F16_32x32x8_F32", 1, 4, 2, 1, 2),
+        MMASchedule("F16_32x32x8_F32", 4, 2, 1, 2, 4),
     ]
 
     infos = []