[rocdl] Fix mfma accumulator base vector type (#16549)
The underlying rocdl op actually accepts 1-D vectors. We are shape
casting anyway, so just cast to the final 1-D vector shape.
This fixes crash when converting from rocdl dialect to LLVM proper, and
allows us to add some more tests to matmul suite.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index e3f963c..2d982c7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -81,11 +81,11 @@
// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<1x1x1x1x1x4xf16>
// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x1x4xf32> to vector<4x4xf32>
+// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x1x4xf32> to vector<16xf32>
// CHECK: %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
-// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<4x4xf32>
-// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<16xf32> to vector<4x1x1x4xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
// CHECK: return {{.*}} %[[R_SIMD]]
@@ -249,16 +249,16 @@
// CHECK: %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
// CHECK: %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
// CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x1x4xf32> to vector<4x4xf32>
+// CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x1x4xf32> to vector<16xf32>
// CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
-// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x1x4xf32>
// CHECK: %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
// CHECK: %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x1x4xf32> from vector<2x1x4x1x1x4xf32>
// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
-// CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x1x4xf32> to vector<4x4xf32>
+// CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x1x4xf32> to vector<16xf32>
// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
-// CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+// CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
// CHECK: %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x1x4xf32> into vector<2x1x4x1x1x4xf32>
// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x1x4xf32> -> vector<64x32xf32>
// CHECK: return {{.*}}} %[[R]]
@@ -348,7 +348,7 @@
// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]]
-// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<4x4xf32> to vector<4x1x1x4xf32>
+// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x1x4xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %{{.+}} [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
// CHECK: return {{.*}} %[[R]] : vector<32x32xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 9dc6abe..d01cfa2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -305,6 +305,10 @@
std::tuple<VectorType, VectorType, VectorType>
MFMAAttr::getABCVectorTypes() const {
+ // Check https://github.com/ROCm/amd_matrix_instruction_calculator for
+ // instruction details. Note here we are returning the number elements, while
+ // amd_matrix_instruction_calculator tells us about the number of 32-bit
+ // registers. So need to adjust accordingly. All vectors should be 1-D.
switch (getIntrinsic().getValue()) {
case MFMAIntrinsic::F16_16x16x16_F32: {
auto aType = VectorType::get({4}, getAType());
@@ -315,7 +319,7 @@
case MFMAIntrinsic::F16_32x32x8_F32: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
- auto cType = VectorType::get({4, 4}, getCType());
+ auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
}
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 151ca42..e295a3d 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -258,6 +258,10 @@
MMASchedule("F16_16x16x16_F32", 2, 2, 1, 1, 1),
MMASchedule("F16_16x16x16_F32", 2, 4, 2, 1, 2),
MMASchedule("F16_16x16x16_F32", 4, 2, 4, 2, 2),
+ MMASchedule("F16_32x32x8_F32", 1, 1, 1, 2, 2),
+ MMASchedule("F16_32x32x8_F32", 2, 2, 1, 1, 1),
+ MMASchedule("F16_32x32x8_F32", 1, 4, 2, 1, 2),
+ MMASchedule("F16_32x32x8_F32", 4, 2, 1, 2, 4),
]
infos = []