[Codegen][GPU] Add support for WMMA_I32_16x16x16_I8 (#18372)
This adds support for the signed variant of the I8 WMMA intrinsic on
RDNA3. The same hardware instruction also supports unsigned and
mixed-signedness operands, so the integer intrinsics will need to be
refactored away from assuming signed inputs in the future.
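For reference, the new intrinsic is exposed as an `#iree_gpu.mma_layout` kind on `iree_gpu.multi_mma`. The sketch below mirrors the single-tile case from the `concretize_mma_shapes.mlir` test added in this change (the function name is illustrative): i8 LHS/RHS operands accumulate into an i32 result.

```mlir
// Minimal single-tile example using the new intrinsic.
#contraction_accesses = [
  affine_map<() -> ()>,
  affine_map<() -> ()>,
  affine_map<() -> ()>
]
func.func @wmma_i32_16x16x16_i8(%lhs: tensor<16x16xi8>, %rhs: tensor<16x16xi8>,
                                %acc: tensor<16x16xi32>) -> tensor<16x16xi32> {
  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
    indexing_maps = #contraction_accesses,
    iterator_types = [],
    kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>
  } : tensor<16x16xi8>, tensor<16x16xi8> into tensor<16x16xi32>
  return %0 : tensor<16x16xi32>
}
```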
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 070a735..20a2504 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -16,7 +16,7 @@
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
-// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>]
+// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]
// GFX1100-SAME: subgroup_size_choices = [32, 64]
// GFX941: target = #iree_gpu.target<arch = "gfx941",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 3ac2379..cb12c92 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -237,6 +237,9 @@
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f16};
}
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+ return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
+ }
}
llvm_unreachable("unhandled mfma layout type");
return OpaqueMmaLayout{};
@@ -353,38 +356,25 @@
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
- case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>
+ int64_t vecYShape = type == MMAIntrinsic::WMMA_F16_16x16x16_F16 ? 16 : 8;
+ int64_t laneYShape = type == MMAIntrinsic::WMMA_F16_16x16x16_F16 ? 1 : 2;
+
auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
- auto cMLayout =
- PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {8, 2, 1});
- auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
- return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
- bNLayout, cMLayout, cNLayout};
- }
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
- // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
- // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
- // #layout_a = #iree_vector_ext.layout<#outer, #inner>
- // #layout_b = #iree_vector_ext.layout<#inner, #outer>
-
- auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
- auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
- auto aMLayout = outer;
- auto aKLayout = inner;
- auto bKLayout = inner;
- auto bNLayout = outer;
- auto cMLayout =
- PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {16, 1, 1});
+ auto cMLayout = PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX},
+ {vecYShape, laneYShape, 1});
auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16});
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
@@ -480,7 +470,8 @@
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
- case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
auto aType = VectorType::get({16}, getAType());
auto bType = VectorType::get({16}, getBType());
auto cType = VectorType::get({8}, getCType());
@@ -514,7 +505,8 @@
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
- case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return 1;
}
}
@@ -533,7 +525,8 @@
return 64;
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return 32;
}
}
@@ -565,7 +558,8 @@
/*element=*/{1, 8}};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0},
/*element=*/{1, 16}};
}
@@ -597,7 +591,8 @@
/*element=*/{8, 1}};
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{16, 1}};
}
@@ -619,7 +614,8 @@
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
- case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return {/*outer=*/{8, 1}, /*thread=*/{2, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
@@ -670,7 +666,8 @@
.getResult();
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+ case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
.getResult();
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 60501b7..e4e753d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -109,6 +109,10 @@
def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 6>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;
+// TODO: The actual I8 instruction allows specifying (mixed) signedness.
+// This will need to become its own class of MMA attribute.
+def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 8>;
+
def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
@@ -118,7 +122,8 @@
MFMA_I32_16x16x32_I8,
MFMA_I32_32x32x16_I8,
WMMA_F32_16x16x16_F16,
- WMMA_F16_16x16x16_F16
+ WMMA_F16_16x16x16_F16,
+ WMMA_I32_16x16x16_I8
]>;
def MMA_LHS : I32EnumAttrCase<"Lhs", 0>;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index d6f3867..d87887e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -192,6 +192,7 @@
static const MMAIntrinsic rdna3MMAOps[] = {
MMAIntrinsic::WMMA_F32_16x16x16_F16,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
+ MMAIntrinsic::WMMA_I32_16x16x16_I8,
};
static const WgpDetails rdna3Wgp = {allComputeBits,
allStorageBits,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
index 1032fce..088a0f9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
@@ -243,3 +243,34 @@
// CHECK-RESULT-SAME: : tensor<16x16xf16>, tensor<16x16xf16> into tensor<16x1x16xf16>
// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0, 1], [2]]
// CHECK-RESULT: return %[[COLLAPSED]]
+
+// -----
+
+#contraction_accesses = [
+ affine_map<() -> ()>,
+ affine_map<() -> ()>,
+ affine_map<() -> ()>
+]
+func.func @concretize_WMMA_I32_16x16x16_I8(%lhs: tensor<16x16xi8>, %rhs: tensor<16x16xi8>, %acc: tensor<16x16xi32>) -> tensor<16x16xi32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [],
+ kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>
+ } : tensor<16x16xi8>, tensor<16x16xi8> into tensor<16x16xi32>
+ return %0 : tensor<16x16xi32>
+}
+
+// CHECK-LABEL: func @concretize_WMMA_I32_16x16x16_I8
+// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<16x16xi8>
+// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x16xi8>
+// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<16x16xi32>
+
+// CHECK-INPUTS-NOT: tensor.expand_shape
+// CHECK-INPUTS: %[[MMA:.+]] = iree_gpu.multi_mma
+// CHECK-INPUTS: return %[[MMA]]
+
+// CHECK-RESULT: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0, 1], [2]] output_shape [8, 2, 16]
+// CHECK-RESULT: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME: : tensor<16x16xi8>, tensor<16x16xi8> into tensor<8x2x16xi32>
+// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0, 1], [2]]
+// CHECK-RESULT: return %[[COLLAPSED]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
index 60efae3..aa97fc9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
@@ -239,7 +239,6 @@
}
// CHECK-DAG: #[[$XMAP:.+]] = affine_map<(d0) -> (d0 mod 16)>
-// CHECK-DAG: #[[$YMAP:.+]] = affine_map<() -> ()>
// CHECK-LABEL: func @distribute_WMMA_F16_16x16x16_F16
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<16x16xf16>
@@ -254,3 +253,44 @@
// CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<16x1x1xf16>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDX]]] [16, 1, 1]
// CHECK: mapping = [#iree_gpu.lane_id<0>]
+
+// -----
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+module {
+ func.func @matmul_wmma_i32_16x16x16_i8(%arg0: tensor<2x8x16x16xi8>, %arg1: tensor<8x2x16x16xi8>, %arg2: tensor<2x2x8x2x16xi32>) -> tensor<2x2x8x2x16xi32> {
+ %mm = iree_gpu.multi_mma %arg0, %arg1, %arg2 {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>,
+ rhs_permutation = array<i64: 1, 0>
+ } : tensor<2x8x16x16xi8>, tensor<8x2x16x16xi8> into tensor<2x2x8x2x16xi32>
+ return %mm : tensor<2x2x8x2x16xi32>
+ }
+}
+
+// CHECK-DAG: #[[$XMAP:.+]] = affine_map<(d0) -> (d0 mod 16)>
+// CHECK-DAG: #[[$YMAP:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 2)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @matmul_wmma_i32_16x16x16_i8
+// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x8x16x16xi8>
+// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x16x16xi8>
+// CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x8x2x16xi32>)
+// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$XMAP]](%[[LANEID]])
+// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IDX]], 0] [2, 8, 1, 16]
+// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IDX]], 0] [8, 2, 1, 16]
+// CHECK-DAG: %[[IDY:.+]] = affine.apply #[[$YMAP]](%[[LANEID]])
+// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[IDX]]] [2, 2, 8, 1, 1]
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>
+// CHECK-SAME: : tensor<2x8x1x16xi8>, tensor<8x2x1x16xi8> into tensor<2x2x8x1x1xi32>
+// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[IDX]]] [2, 2, 8, 1, 1]
+// CHECK: mapping = [#iree_gpu.lane_id<0>]
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index f117ba1..03dfdf3 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2510,6 +2510,35 @@
iree_generated_e2e_runner_test(
NAME
+ e2e_matmul_rocm_i8_large_rdna3_wmma_tb
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=i8"
+ "--acc_type=i32"
+ "--transpose_rhs"
+ "--shapes=gpu_large_aligned"
+ "--compilation_info=LLVMGPUVectorDistributeWMMA"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-rdna3"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
e2e_matmul_rdna3_experimental_dt_f32_f32
TEST_TYPE
matmul
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 6cb2702..c1e8907 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -299,6 +299,13 @@
MMASchedule("WMMA_F32_16x16x16_F16", 2, 2, 1, 1, 1),
MMASchedule("WMMA_F32_16x16x16_F16", 2, 4, 2, 1, 2),
MMASchedule("WMMA_F32_16x16x16_F16", 4, 2, 4, 2, 2),
+ MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 1, 1, 1),
+ MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 1, 1, 2),
+ MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 1, 2, 1),
+ MMASchedule("WMMA_I32_16x16x16_I8", 1, 1, 2, 1, 1),
+ MMASchedule("WMMA_I32_16x16x16_I8", 2, 2, 1, 1, 1),
+ MMASchedule("WMMA_I32_16x16x16_I8", 2, 4, 2, 1, 2),
+ MMASchedule("WMMA_I32_16x16x16_I8", 4, 2, 4, 2, 2),
]
else:
raise NotImplementedError("unhandled intrinsic case")
@@ -342,6 +349,10 @@
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
+ elif schedule.intrinsic == "WMMA_I32_16x16x16_I8":
+ wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+ wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+ wg_tile_k = schedule.k_tile_count * 16
else:
raise NotImplementedError("unhandled intrinsic case")