[LLVMGPU][ROCm] Plumb through i8, i8 -> i32 MFMA intrinsics (#17764)
Add tests to make sure these intrinsics are generated in the vector
distribution pipeline, and add e2e correctness tests.
I also tested this manually on random inputs against golden outputs from
numpy.
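
Roughly, the manual check looked like this (a sketch: the shape, seed,
and the iree_result placeholder are illustrative, not the exact script):

    import numpy as np

    rng = np.random.default_rng(0)
    # Random i8 operands; accumulate in i32, matching the new intrinsics.
    lhs = rng.integers(-128, 128, size=(256, 256), dtype=np.int8)
    rhs = rng.integers(-128, 128, size=(256, 256), dtype=np.int8)
    golden = lhs.astype(np.int32) @ rhs.astype(np.int32)
    # iree_result stands in for the output of the compiled module.
    np.testing.assert_array_equal(golden, iree_result)
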
This contains one cherry-pick for llvm-project.
---------
Co-authored-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Lei Zhang <antiagainst@gmail.com>
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 0f427c5..c801f9b 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
-// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>],
+// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>
// GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>]
+// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 8068200..cbca100 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -205,6 +205,10 @@
MMAIntrinsic type) {
Type f16 = Float16Type::get(context);
Type f32 = Float32Type::get(context);
+
+ Type i8 = IntegerType::get(context, 8);
+ Type i32 = IntegerType::get(context, 32);
+
switch (type) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
@@ -212,6 +216,12 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+ return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
+ }
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+ return OpaqueMmaLayout{32, 32, 16, i8, i8, i32};
+ }
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
@@ -280,13 +290,47 @@
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+ // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
+ // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+ // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+
+ auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+ auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 8});
+ auto aMLayout = outer;
+ auto aKLayout = inner;
+ auto bKLayout = inner;
+ auto bNLayout = outer;
+ auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
+ auto cNLayout = outer;
+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+ bNLayout, cMLayout, cNLayout};
+ }
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
+ // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 8]>
+ // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+ // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+
+ auto outer = PerDimLayoutAttr::get(context, {laneX}, {32});
+ auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 8});
+ auto aMLayout = outer;
+ auto aKLayout = inner;
+ auto bKLayout = inner;
+ auto bNLayout = outer;
+ auto cMLayout =
+ PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4});
+ auto cNLayout = outer;
+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+ bNLayout, cMLayout, cNLayout};
+ }
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
- // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
+ // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>
- // #layout_c = #iree_vector_ext.layout<#inner, #outer>
auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
@@ -372,6 +416,18 @@
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
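+ // Both i8 intrinsics consume 8 i8 elements per lane per operand; the
+ // 16x16x32 variant produces 4 i32 accumulator elements per lane, the
+ // 32x32x16 variant produces 16.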
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+ auto aType = VectorType::get({8}, getAType());
+ auto bType = VectorType::get({8}, getBType());
+ auto cType = VectorType::get({4}, getCType());
+ return std::make_tuple(aType, bType, cType);
+ }
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+ auto aType = VectorType::get({8}, getAType());
+ auto bType = VectorType::get({8}, getBType());
+ auto cType = VectorType::get({16}, getCType());
+ return std::make_tuple(aType, bType, cType);
+ }
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
auto aType = VectorType::get({16}, getAType());
@@ -396,6 +452,8 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return 1;
@@ -408,7 +466,9 @@
int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
- case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+ case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return 64;
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
@@ -430,6 +490,14 @@
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 4}};
}
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+ return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
+ /*element=*/{1, 8}};
+ }
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+ return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
+ /*element=*/{1, 8}};
+ }
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 16},
@@ -449,6 +517,14 @@
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+ return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
+ /*element=*/{8, 1}};
+ }
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+ return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
+ /*element=*/{8, 1}};
+ }
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
@@ -460,11 +536,13 @@
MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
- case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+ case MMAIntrinsic::MFMA_F16_16x16x16_F32:
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
- case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+ case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
@@ -494,7 +572,9 @@
}
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
- case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+ case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+ case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+ case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
auto [m, n, k] = getMNKShape();
return builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index a7abbb6..0423c2f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -101,14 +101,18 @@
// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
+def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
+def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
-def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 2>;
-def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 3>;
+def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
+def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;
def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
+ MFMA_I8_16x16x32_I32,
+ MFMA_I8_32x32x16_I32,
WMMA_F16_16x16x16_F32,
WMMA_F16_16x16x16_F16
]>;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index b6f0aeb..993963c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -124,6 +124,8 @@
static const MMAIntrinsic cdna3MMAOps[] = {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
+ MMAIntrinsic::MFMA_I8_16x16x32_I32,
+ MMAIntrinsic::MFMA_I8_32x32x16_I32,
};
static const WgpDetails cdna3Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index cbf9c6f..f0c974e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -141,6 +141,7 @@
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 791c6e1..d834a04 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -183,6 +183,7 @@
iree::compiler::Codegen::Utils
iree::compiler::Codegen::Utils::VectorOpUtils
iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::Flow::Transforms
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::LinalgExt::Transforms
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 5858759..ed9175e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -20,6 +20,7 @@
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
@@ -27,6 +28,7 @@
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -444,6 +446,15 @@
Type rhsElemType = getElementTypeOrSelf(rhs);
Type initElemType = getElementTypeOrSelf(init);
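+ // If an operand is produced by a dequantization-like op, deduce the
+ // element type from that op's input so intrinsic selection sees the
+ // original quantized type (e.g. i8) rather than the dequantized one.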
+ if (auto lhsOp = lhs.getDefiningOp<linalg::GenericOp>()) {
+ if (IREE::Flow::isDequantizationLikeOp(lhsOp))
+ lhsElemType = getElementTypeOrSelf(lhsOp.getDpsInputs()[0]);
+ }
+ if (auto rhsOp = rhs.getDefiningOp<linalg::GenericOp>()) {
+ if (IREE::Flow::isDequantizationLikeOp(rhsOp))
+ rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
+ }
+
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index d0254d0..25fcc8e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -1,7 +1,12 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
-// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
-// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s --check-prefix=RDNA3
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
+// RUN: --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN: %s | FileCheck %s
+
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 \
+// RUN: --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN: %s | FileCheck %s --check-prefix=RDNA3
// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
// to be migrated to the rocdl heuristics, but for now is just physically
@@ -174,6 +179,107 @@
// -----
+// Basic i8, i8 -> i32 matmul.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+hal.executable @matmul_256x256x256_i8_i32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @matmul_256x256x256_i8_i32 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @matmul_256x256x256_i8_i32() {
+ %cst = arith.constant 0 : i32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+ %5 = tensor.empty() : tensor<256x256xi32>
+ %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
+ %7 = linalg.matmul ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+ return
+ }
+ }
+}
+}
+
+// Make sure it generates the mfma instructions we expect for integer inputs.
+
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I8_16x16x32_I32>,
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
+// CHECK-SAME: prefetch_shared_memory
+
+// CHECK-LABEL: func.func @matmul_256x256x256_i8_i32()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 mfma ops.
+// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
+// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xi32>, memref<256x256xi32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+// Basic i8, i8 -> i32 matmul_transpose_b.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+hal.executable @matmul_transpose_b_256x256x256_i8_i32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @matmul_transpose_b_256x256x256_i8_i32 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @matmul_transpose_b_256x256x256_i8_i32() {
+ %cst = arith.constant 0 : i32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+ %5 = tensor.empty() : tensor<256x256xi32>
+ %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
+ %7 = linalg.matmul_transpose_b ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+ return
+ }
+ }
+}
+}
+
+// Make sure it generates the mfma instructions we expect for integer inputs.
+
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I8_16x16x32_I32>,
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
+// CHECK-SAME: prefetch_shared_memory
+
+// CHECK-LABEL: func.func @matmul_transpose_b_256x256x256_i8_i32()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 mfma ops.
+// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
+// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xi32>, memref<256x256xi32, #hal.descriptor_type<storage_buffer>>
+
+
+// -----
+
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 43aca63..9347cc2 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2222,6 +2222,35 @@
"requires-gpu-cdna3"
)
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_rocm_i8_large_cdna3_mfma_tb
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=i8"
+ "--acc_type=i32"
+ "--transpose_rhs"
+ "--shapes=gpu_large_aligned"
+ "--compilation_info=LLVMGPUVectorDistributeMFMA"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
unset(IREE_HIP_TEST_COMPILER_FLAGS)
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 15f0e33..003f3de 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -246,7 +246,9 @@
return tile_workgroup_size_pairs
-def get_rocm_test_compilation_infos(compilation_info_id: CompilationInfoId):
+def get_rocm_test_compilation_infos(
+ compilation_info_id: CompilationInfoId, lhs_rhs_type: MatrixElemTypeId
+):
intrinsic = ""
if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeMFMA:
intrinsic = "MFMA"
@@ -269,6 +271,14 @@
MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
+ MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1),
+ MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2),
+ MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1),
+ MMASchedule("MFMA_I8_16x16x32_I32", 4, 2, 4, 2, 1),
+ MMASchedule("MFMA_I8_32x32x16_I32", 1, 1, 1, 1, 1),
+ MMASchedule("MFMA_I8_32x32x16_I32", 2, 2, 1, 1, 2),
+ MMASchedule("MFMA_I8_32x32x16_I32", 4, 1, 1, 2, 2),
+ MMASchedule("MFMA_I8_32x32x16_I32", 4, 2, 2, 2, 2),
]
elif intrinsic == "WMMA":
schedules = [
@@ -287,6 +297,11 @@
infos = []
for schedule in schedules:
+ # Skip schedules whose intrinsic element type does not match
+ # the requested lhs/rhs element type.
+ if lhs_rhs_type.value.upper() not in schedule.intrinsic:
+ continue
+
if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
@@ -295,6 +310,14 @@
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 8
+ elif schedule.intrinsic == "MFMA_I8_16x16x32_I32":
+ wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+ wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+ wg_tile_k = schedule.k_tile_count * 32
+ elif schedule.intrinsic == "MFMA_I8_32x32x16_I32":
+ wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
+ wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
+ wg_tile_k = schedule.k_tile_count * 16
elif schedule.intrinsic == "WMMA_F16_16x16x16_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
@@ -328,7 +351,7 @@
CompilationInfoId.LLVMGPUVectorDistributeMFMA,
CompilationInfoId.LLVMGPUVectorDistributeWMMA,
]:
- return get_rocm_test_compilation_infos(compilation_info_id)
+ return get_rocm_test_compilation_infos(compilation_info_id, lhs_rhs_type)
software_pipeline_depth = 0
if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt:
diff --git a/third_party/llvm-project b/third_party/llvm-project
index c83d9e9..1e498cb 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit c83d9e99b0663cf8e7e81bd552d42a0c4298ab2c
+Subproject commit 1e498cbf26917713a562ae9551a549dfbfed3add