[LLVMGPU][ROCm] Plumb through i8, i8 -> i32 MFMA intrinsics (#17764)

Add tests to make sure these intrinsics are generated in the vector
distribution pipeline, and add e2e correctness tests.

I also tested this manually on random inputs against golden outputs from
numpy.
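
A minimal sketch of that check (the shapes and the way the compiled
result is read back are placeholders, not the exact script):

```python
import numpy as np

# Random i8 inputs; products are accumulated in i32, matching the
# MFMA_I8_*_I32 output type.
m, n, k = 256, 256, 256
lhs = np.random.randint(-128, 128, size=(m, k), dtype=np.int8)
rhs = np.random.randint(-128, 128, size=(k, n), dtype=np.int8)

# Golden output: widen to i32 *before* multiplying so the products
# don't wrap in i8 arithmetic.
golden = lhs.astype(np.int32) @ rhs.astype(np.int32)

# Result produced by the compiled module (np.load is a stand-in for
# however the output was captured).
result = np.load("result.npy")
assert np.array_equal(result, golden)
```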

This contains one cherry-pick for llvm-project.

---------

Co-authored-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Lei Zhang <antiagainst@gmail.com>
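
Two pieces of arithmetic in the diff below are worth spelling out. First,
the `CHECK-COUNT-32` in the new lit tests: each subgroup covers a 2 x 2 grid
of result tiles, and each tile accumulates 8 times along K (256 / 32 for
`MFMA_I8_16x16x32_I32`). Second, the workgroup tile sizes the test generator
derives for the new i8 schedules. A sketch of both (the `MMASchedule` field
order is an assumption inferred from how the generator reads the fields):

```python
# 1) Expected mfma count in the lit tests (see CHECK-COUNT-32 below).
m_tiles, n_tiles = 2, 2  # result tiles per subgroup
k_steps = 256 // 32      # K / intrinsic k for MFMA_I8_16x16x32_I32
assert m_tiles * n_tiles * k_steps == 32

# 2) Workgroup tile sizes for an i8 schedule, mirroring the new branches
# in get_rocm_test_compilation_infos. Example schedule:
# MMASchedule("MFMA_I8_16x16x32_I32", 4, 2, 4, 2, 1), read here as
# (m_count, n_count, m_tile_count, n_tile_count, k_tile_count).
m_count, n_count, m_tile_count, n_tile_count, k_tile_count = 4, 2, 4, 2, 1
wg_tile_m = m_count * m_tile_count * 16  # 256
wg_tile_n = n_count * n_tile_count * 16  # 64
wg_tile_k = k_tile_count * 32            # 32
```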
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 0f427c5..c801f9b 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
 // GFX942: target = #iree_gpu.target<arch = "gfx942",
 // GFX942-SAME: wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8,
 // GFX942-SAME:         subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32,
-// GFX942-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>],
+// GFX942-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 // GFX942-SAME:         subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
 // GFX942-SAME:         max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
 // GFX942-SAME: chip = <wgp_count = 304>>
 
 // GFX940: target = #iree_gpu.target<arch = "gfx940",
-// GFX940-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>]
+// GFX940-SAME:         mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 
 // GFX1100: target = #iree_gpu.target<arch = "gfx1100",
 // GFX1100-SAME:        mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 8068200..cbca100 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -205,6 +205,10 @@
                                            MMAIntrinsic type) {
   Type f16 = Float16Type::get(context);
   Type f32 = Float32Type::get(context);
+
+  Type i8 = IntegerType::get(context, 8);
+  Type i32 = IntegerType::get(context, 32);
+
   switch (type) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
     return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
@@ -212,6 +216,12 @@
   case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
     return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
   }
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+    return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
+  }
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+    return OpaqueMmaLayout{32, 32, 16, i8, i8, i32};
+  }
   case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
     return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
   }
@@ -280,13 +290,47 @@
     return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
                              bNLayout,     cMLayout, cNLayout};
   }
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+    // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
+    // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
+    // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+    // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+
+    auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
+    auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 8});
+    auto aMLayout = outer;
+    auto aKLayout = inner;
+    auto bKLayout = inner;
+    auto bNLayout = outer;
+    auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
+    auto cNLayout = outer;
+    return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+                             bNLayout,     cMLayout, cNLayout};
+  }
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+    // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
+    // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 8]>
+    // #layout_a = #iree_vector_ext.layout<#outer, #inner>
+    // #layout_b = #iree_vector_ext.layout<#inner, #outer>
+
+    auto outer = PerDimLayoutAttr::get(context, {laneX}, {32});
+    auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 8});
+    auto aMLayout = outer;
+    auto aKLayout = inner;
+    auto bKLayout = inner;
+    auto bNLayout = outer;
+    auto cMLayout =
+        PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4});
+    auto cNLayout = outer;
+    return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
+                             bNLayout,     cMLayout, cNLayout};
+  }
   case MMAIntrinsic::WMMA_F16_16x16x16_F32:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
     // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
-    // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
+    // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
     // #layout_a = #iree_vector_ext.layout<#outer, #inner>
     // #layout_b = #iree_vector_ext.layout<#inner, #outer>
-    // #layout_c = #iree_vector_ext.layout<#inner, #outer>
 
     auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
     auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
@@ -372,6 +416,18 @@
     auto cType = VectorType::get({16}, getCType());
     return std::make_tuple(aType, bType, cType);
   }
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+    auto aType = VectorType::get({8}, getAType());
+    auto bType = VectorType::get({8}, getBType());
+    auto cType = VectorType::get({4}, getCType());
+    return std::make_tuple(aType, bType, cType);
+  }
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+    auto aType = VectorType::get({8}, getAType());
+    auto bType = VectorType::get({8}, getBType());
+    auto cType = VectorType::get({16}, getCType());
+    return std::make_tuple(aType, bType, cType);
+  }
   case MMAIntrinsic::WMMA_F16_16x16x16_F32:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
     auto aType = VectorType::get({16}, getAType());
@@ -396,6 +452,8 @@
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
   case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16:
   case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
     return 1;
@@ -408,7 +466,9 @@
 int64_t MMAAttr::getSubgroupSize() const {
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-  case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+  case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
     return 64;
   }
   case MMAIntrinsic::WMMA_F16_16x16x16_F32:
@@ -430,6 +490,14 @@
     return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
             /*element=*/{1, 4}};
   }
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+    return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
+            /*element=*/{1, 8}};
+  }
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+    return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
+            /*element=*/{1, 8}};
+  }
   case MMAIntrinsic::WMMA_F16_16x16x16_F32:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
     return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 16},
@@ -449,6 +517,14 @@
     return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
             /*element=*/{4, 1}};
   }
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+    return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
+            /*element=*/{8, 1}};
+  }
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+    return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
+            /*element=*/{8, 1}};
+  }
   case MMAIntrinsic::WMMA_F16_16x16x16_F32:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
     return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
@@ -460,11 +536,13 @@
 
 MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
   switch (getIntrinsic().getValue()) {
-  case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+  case MMAIntrinsic::MFMA_F16_16x16x16_F32:
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
     return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
             /*element=*/{4, 1}};
   }
-  case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+  case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
     return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
             /*element=*/{4, 1}};
   }
@@ -494,7 +572,9 @@
   }
   switch (getIntrinsic().getValue()) {
   case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-  case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+  case MMAIntrinsic::MFMA_F16_32x32x8_F32:
+  case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+  case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
     auto [m, n, k] = getMNKShape();
     return builder
         .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index a7abbb6..0423c2f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -101,14 +101,18 @@
 // Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
 def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
 def MFMA_F16_32x32x8_F32  : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
+def MFMA_I8_16x16x32_I32  : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
+def MFMA_I8_32x32x16_I32  : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
 // TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
-def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 2>;
-def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 3>;
+def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
+def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;
 
 def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
     "Descriptor for different MMA intrinsics", [
       MFMA_F16_16x16x16_F32,
       MFMA_F16_32x32x8_F32,
+      MFMA_I8_16x16x32_I32,
+      MFMA_I8_32x32x16_I32,
       WMMA_F16_16x16x16_F32,
       WMMA_F16_16x16x16_F16
     ]>;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index b6f0aeb..993963c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -124,6 +124,8 @@
   static const MMAIntrinsic cdna3MMAOps[] = {
       MMAIntrinsic::MFMA_F16_16x16x16_F32,
       MMAIntrinsic::MFMA_F16_32x32x8_F32,
+      MMAIntrinsic::MFMA_I8_16x16x32_I32,
+      MMAIntrinsic::MFMA_I8_32x32x16_I32,
   };
   static const WgpDetails cdna3Wgp = {
       allComputeBits,   allStorageBits,          allSubgroupOps,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index cbf9c6f..f0c974e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -141,6 +141,7 @@
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
+        "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 791c6e1..d834a04 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -183,6 +183,7 @@
     iree::compiler::Codegen::Utils
     iree::compiler::Codegen::Utils::VectorOpUtils
     iree::compiler::Dialect::Flow::IR
+    iree::compiler::Dialect::Flow::Transforms
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::LinalgExt::Transforms
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 5858759..ed9175e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -20,6 +20,7 @@
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/ADT/STLExtras.h"
@@ -27,6 +28,7 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -444,6 +446,15 @@
   Type rhsElemType = getElementTypeOrSelf(rhs);
   Type initElemType = getElementTypeOrSelf(init);
 
+  if (auto lhsOp = lhs.getDefiningOp<linalg::GenericOp>()) {
+    if (IREE::Flow::isDequantizationLikeOp(lhsOp))
+      lhsElemType = getElementTypeOrSelf(lhsOp.getDpsInputs()[0]);
+  }
+  if (auto rhsOp = rhs.getDefiningOp<linalg::GenericOp>()) {
+    if (IREE::Flow::isDequantizationLikeOp(rhsOp))
+      rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
+  }
+
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType,  rhsElemType,  initElemType};
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index d0254d0..25fcc8e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -1,7 +1,12 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
-// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
-// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" %s | FileCheck %s --check-prefix=RDNA3
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
+// RUN:   --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN:   %s | FileCheck %s
+
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 \
+// RUN:   --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN:   %s | FileCheck %s --check-prefix=RDNA3
 
 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs
 // to be migrated to the rocdl heuristics, but for now is just physically
@@ -174,6 +179,107 @@
 
 // -----
 
+// Basic i8, i8 -> i32 matmul.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @matmul_256x256x256_i8_i32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @matmul_256x256x256_i8_i32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+  builtin.module {
+    func.func @matmul_256x256x256_i8_i32() {
+      %cst = arith.constant 0 : i32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+      %5 = tensor.empty() : tensor<256x256xi32>
+      %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
+      %7 = linalg.matmul ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+      return
+    }
+  }
+}
+}
+
+// Make sure it generates the mfma instructions we expect for integer inputs.
+
+//       CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
+//  CHECK-SAME:   mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I8_16x16x32_I32>,
+//  CHECK-SAME:     subgroup_m_count = 2, subgroup_n_count = 2>
+//  CHECK-SAME:   prefetch_shared_memory
+
+//    CHECK-LABEL: func.func @matmul_256x256x256_i8_i32()
+//     CHECK-SAME:    translation_info = #[[$TRANSLATION]]
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension, so we expect 32 mfma ops in total.
+// CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<8xi8>, vector<8xi8>, vector<4xi32>
+//  CHECK-COUNT-4:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xi32>, memref<256x256xi32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+// Basic i8, i8 -> i32 matmul_transpose_b.
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @matmul_transpose_b_256x256x256_i8_i32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @matmul_transpose_b_256x256x256_i8_i32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+  builtin.module {
+    func.func @matmul_transpose_b_256x256x256_i8_i32() {
+      %cst = arith.constant 0 : i32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xi8>>
+      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
+      %5 = tensor.empty() : tensor<256x256xi32>
+      %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
+      %7 = linalg.matmul_transpose_b ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
+      return
+    }
+  }
+}
+}
+
+// Make sure it generates the mfma instructions we expect for integer inputs.
+
+//       CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
+//  CHECK-SAME:   mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I8_16x16x32_I32>,
+//  CHECK-SAME:     subgroup_m_count = 2, subgroup_n_count = 2>
+//  CHECK-SAME:   prefetch_shared_memory
+
+//    CHECK-LABEL: func.func @matmul_transpose_b_256x256x256_i8_i32()
+//     CHECK-SAME:    translation_info = #[[$TRANSLATION]]
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension, so we expect 32 mfma ops in total.
+// CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<8xi8>, vector<8xi8>, vector<4xi32>
+//  CHECK-COUNT-4:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xi32>, memref<256x256xi32, #hal.descriptor_type<storage_buffer>>
+
+
+// -----
+
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 43aca63..9347cc2 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -2222,6 +2222,35 @@
     "requires-gpu-cdna3"
 )
 
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_rocm_i8_large_cdna3_mfma_tb
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=i8"
+    "--acc_type=i32"
+    "--transpose_rhs"
+    "--shapes=gpu_large_aligned"
+    "--compilation_info=LLVMGPUVectorDistributeMFMA"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
 elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")
 
 unset(IREE_HIP_TEST_COMPILER_FLAGS)
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 15f0e33..003f3de 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -246,7 +246,9 @@
     return tile_workgroup_size_pairs
 
 
-def get_rocm_test_compilation_infos(compilation_info_id: CompilationInfoId):
+def get_rocm_test_compilation_infos(
+    compilation_info_id: CompilationInfoId, lhs_rhs_type: MatrixElemTypeId
+):
     intrinsic = ""
     if compilation_info_id == CompilationInfoId.LLVMGPUVectorDistributeMFMA:
         intrinsic = "MFMA"
@@ -269,6 +271,14 @@
             MMASchedule("MFMA_F16_32x32x8_F32", 2, 2, 1, 1, 1),
             MMASchedule("MFMA_F16_32x32x8_F32", 1, 4, 2, 1, 2),
             MMASchedule("MFMA_F16_32x32x8_F32", 4, 2, 1, 2, 4),
+            MMASchedule("MFMA_I8_16x16x32_I32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_I8_16x16x32_I32", 2, 2, 1, 1, 2),
+            MMASchedule("MFMA_I8_16x16x32_I32", 4, 1, 4, 1, 1),
+            MMASchedule("MFMA_I8_16x16x32_I32", 4, 2, 4, 2, 1),
+            MMASchedule("MFMA_I8_32x32x16_I32", 1, 1, 1, 1, 1),
+            MMASchedule("MFMA_I8_32x32x16_I32", 2, 2, 1, 1, 2),
+            MMASchedule("MFMA_I8_32x32x16_I32", 4, 1, 1, 2, 2),
+            MMASchedule("MFMA_I8_32x32x16_I32", 4, 2, 2, 2, 2),
         ]
     elif intrinsic == "WMMA":
         schedules = [
@@ -287,6 +297,11 @@
 
     infos = []
     for schedule in schedules:
+        # Skip schedules with an intrinsic whose element type does not
+        # match the requested one.
+        if lhs_rhs_type.value.upper() not in schedule.intrinsic:
+            continue
+
         if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
@@ -295,6 +310,14 @@
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
             wg_tile_k = schedule.k_tile_count * 8
+        elif schedule.intrinsic == "MFMA_I8_16x16x32_I32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
+            wg_tile_k = schedule.k_tile_count * 32
+        elif schedule.intrinsic == "MFMA_I8_32x32x16_I32":
+            wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
+            wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
+            wg_tile_k = schedule.k_tile_count * 16
         elif schedule.intrinsic == "WMMA_F16_16x16x16_F32":
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
@@ -328,7 +351,7 @@
         CompilationInfoId.LLVMGPUVectorDistributeMFMA,
         CompilationInfoId.LLVMGPUVectorDistributeWMMA,
     ]:
-        return get_rocm_test_compilation_infos(compilation_info_id)
+        return get_rocm_test_compilation_infos(compilation_info_id, lhs_rhs_type)
 
     software_pipeline_depth = 0
     if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt:
diff --git a/third_party/llvm-project b/third_party/llvm-project
index c83d9e9..1e498cb 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit c83d9e99b0663cf8e7e81bd552d42a0c4298ab2c
+Subproject commit 1e498cbf26917713a562ae9551a549dfbfed3add