[ROCM] Port MLIR ukernels to ukernel descriptor lowering flow (#21683)

This copies all ukernels from the tuning spec to the ukernel-descriptor-based
lowering flow and adds the corresponding PDL matching patterns. It does not yet
remove the ukernels from the tuning spec, as that requires updating every use of
`--iree-codegen-enable-default-tuning-specs=true` to
`--iree-hip-enable-tensor-ukernels`, which in my opinion is better done in a
separate PR.

The ukernels and matching patterns being copied in this PR:
- pingpong_large_f8_expanded
- pingpong_large_f16
- pingpong_medium_f16_expanded
- pingpong_large_f16_expanded
- pingpong_large_bf16
- pingpong_medium_bf16_expanded
- pingpong_large_bf16_expanded

Note that the mmt_2048x1280x5120_f16_f16_f32 matching and annotation are not
ported; they appear to be unreachable because pingpong_large_f16 matches the
same case and takes precedence. A sketch of the annotation produced by a
successful match follows.
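
For reference, a minimal sketch of what a match produces, assuming hypothetical
f16 operand shapes and eliding the `compilation_info` attribute (lowering_config
and translation_info) that the patterns also attach; the
`iree_codegen.ukernel` descriptor is the one the new tests check for:

```mlir
// Illustrative sketch only: the matched contraction is annotated with a
// ukernel descriptor naming the builtin kernel to lower to (compilation_info
// elided, operand shapes hypothetical).
%2 = linalg.generic {indexing_maps = [#map1, #map2, #map3],
                     iterator_types = ["parallel", "parallel", "reduction"]}
       ins(%arg0, %arg1 : tensor<1024x4096xf16>, tensor<1024x4096xf16>)
       outs(%1 : tensor<1024x1024xf32>)
       attrs = {iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>} {
^bb0(%in: f16, %in_0: f16, %out: f32):
  %3 = arith.extf %in : f16 to f32
  %4 = arith.extf %in_0 : f16 to f32
  %5 = arith.mulf %3, %4 : f32
  %6 = arith.addf %out, %5 : f32
  linalg.yield %6 : f32
} -> tensor<1024x1024xf32>
```

The descriptor name is then resolved by the ukernel descriptor lowering to the
matching builtin body (e.g. `util.func private @pingpong_large_f16`), as the
driver tests below check.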

---------

Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
Co-authored-by: Quinn Dawkins <quinn.dawkins@gmail.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
index 8a54cf6..f562c11 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
@@ -16,6 +16,7 @@
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Utils/ShapeUtils.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -61,7 +62,8 @@
     return rewriter.notifyMatchFailure(rootOp,
                                        "expected StringAttr for attr name.");
   }
-  rootOp->setAttr(strName.strref(), annotation);
+  rewriter.modifyOpInPlace(
+      rootOp, [&]() { rootOp->setAttr(strName.strref(), annotation); });
   return success();
 }
 
@@ -75,7 +77,6 @@
     return rewriter.notifyMatchFailure(rootOp,
                                        "not a contraction like linalg op");
   }
-
   if (linalgOp.getIndexingMaps() != indexingMaps) {
     return rewriter.notifyMatchFailure(rootOp, "indexing maps mismatch");
   }
@@ -102,6 +103,9 @@
   if (!dim) {
     return failure();
   }
+  if (dimInt.getInt() >= shapedType.getRank()) {
+    return failure();
+  }
   auto divisorInt = dyn_cast<IntegerAttr>(divisor);
   if (!divisor) {
     return failure();
@@ -140,6 +144,9 @@
   if (!dimInt) {
     return failure();
   }
+  if (dimInt.getInt() >= shapedType.getRank()) {
+    return failure();
+  }
   if (auto lowerBoundInt = dyn_cast<IntegerAttr>(lowerBound)) {
     FailureOr<int64_t> constantLb =
         ValueBoundsConstraintSet::computeConstantBound(
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
index 6959765..18c5013 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
@@ -133,3 +133,77 @@
 // CHECK-LABEL: @negative_matmul_f8_dynamic_lower_bound
 // CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
 // CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_f8_expanded", tensor>
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_f16(%arg0: tensor<256x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<256x1024xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<256x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<256x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<256x1024xf32>) {
+    ^bb0(%in: f16, %in_4: f16, %out: f32):
+      %12 = arith.extf %in : f16 to f32
+      %13 = arith.extf %in_4 : f16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<256x1024xf32>
+  return %2 : tensor<256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f16
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_bf16(%arg0: tensor<256x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<256x1024xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<256x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<256x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<256x1024xf32>) {
+    ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+      %12 = arith.extf %in : bf16 to f32
+      %13 = arith.extf %in_4 : bf16 to f32
+      %14 = arith.mulf %12, %13 : f32
+      %15 = arith.addf %out, %14 : f32
+      linalg.yield %15 : f32
+    } -> tensor<256x1024xf32>
+  return %2 : tensor<256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+// The dynamic dimension is a multiple of 256, but doesn't have a lower bound of 256.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_bf16_dynamic_lower_bound(%arg0: index) -> tensor<1x256x1024xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = util.assume.int %arg0<umin = 128, udiv = 256> : index
+  %1 = tensor.empty(%0) : tensor<1x256x?xbf16>
+  %2 = tensor.empty(%0) : tensor<1024x?xbf16>
+  %3 = tensor.empty() : tensor<1x256x1024xf32>
+  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+  %5 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%1, %2 : tensor<1x256x?xbf16>, tensor<1024x?xbf16>) outs(%4 : tensor<1x256x1024xf32>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+    %6 = arith.extf %in : bf16 to f32
+    %7 = arith.extf %in_0 : bf16 to f32
+    %8 = arith.mulf %6, %7 : f32
+    %9 = arith.addf %out, %8 : f32
+    linalg.yield %9 : f32
+  } -> tensor<1x256x1024xf32>
+  return %5 : tensor<1x256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16_dynamic_lower_bound
+// CHECK-NOT:     compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT:     iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
index 7ca6748..e56b7ec 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
@@ -17,7 +17,7 @@
 module attributes {
   hal.executable.target = #executable_target_rocm_hsaco_fb
 } {
-  func.func @matmul_f8(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
+  func.func @matmul_f8_medium_expanded(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
     %cst = arith.constant 0.000000e+00 : f32
     %0 = tensor.empty() : tensor<1x128x1024xf32>
     %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
@@ -34,3 +34,297 @@
 }
 // CHECK-LABEL: util.func private @pingpong_medium_f8_expanded
 // CHECK:         iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+                                               wgp = <compute = fp16, storage =  b16,
+                                               subgroup =  none,
+                                               subgroup_size_choices = [64],
+                                               max_workgroup_sizes = [1024, 1024, 1024],
+                                               max_thread_count_per_workgroup = 1024,
+                                               max_workgroup_memory_bytes = 65536,
+                                               max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+   ukernels = "none"}>
+module attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  func.func @matmul_f8_large_expanded(%arg0: tensor<1x256x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x256x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1x256x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%1 : tensor<1x256x1024xf32>) {
+      ^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
+        %12 = arith.extf %in : f8E4M3FNUZ to f32
+        %13 = arith.extf %in_4 : f8E4M3FNUZ to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1x256x1024xf32>
+    return %2 : tensor<1x256x1024xf32>
+  }
+}
+// CHECK-LABEL: @matmul_f8_large_expanded
+// CHECK:         linalg.generic
+// CHECK-SAME:      compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME:      lowering_config =
+// CHECK-SAME:      translation_info =
+// CHECK-SAME:      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f8_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_f8_expanded
+// CHECK:         iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+                                               wgp = <compute = fp16, storage =  b16,
+                                               subgroup =  none,
+                                               subgroup_size_choices = [64],
+                                               max_workgroup_sizes = [1024, 1024, 1024],
+                                               max_thread_count_per_workgroup = 1024,
+                                               max_workgroup_memory_bytes = 65536,
+                                               max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+   ukernels = "none"}>
+module attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  func.func @matmul_f16_large(%arg0: tensor<1024x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1024x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1024x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1024x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1024x1024xf32>) {
+      ^bb0(%in: f16, %in_4: f16, %out: f32):
+        %12 = arith.extf %in : f16 to f32
+        %13 = arith.extf %in_4 : f16 to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1024x1024xf32>
+    return %2 : tensor<1024x1024xf32>
+  }
+}
+// CHECK-LABEL: @matmul_f16_large
+// CHECK:         linalg.generic
+// CHECK-SAME:      compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME:      lowering_config =
+// CHECK-SAME:      translation_info =
+// CHECK-SAME:      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_f16
+// CHECK:         iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+                                               wgp = <compute = fp16, storage =  b16,
+                                               subgroup =  none,
+                                               subgroup_size_choices = [64],
+                                               max_workgroup_sizes = [1024, 1024, 1024],
+                                               max_thread_count_per_workgroup = 1024,
+                                               max_workgroup_memory_bytes = 65536,
+                                               max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+   ukernels = "none"}>
+module attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  func.func @matmul_f16_medium_expanded(%arg0: tensor<1x128x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1x128x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1x128x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1x128x1024xf32>) {
+      ^bb0(%in: f16, %in_4: f16, %out: f32):
+        %12 = arith.extf %in : f16 to f32
+        %13 = arith.extf %in_4 : f16 to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1x128x1024xf32>
+    return %2 : tensor<1x128x1024xf32>
+  }
+}
+// CHECK-LABEL: @matmul_f16_medium_expanded
+// CHECK:         linalg.generic
+// CHECK-SAME:      compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME:      lowering_config =
+// CHECK-SAME:      translation_info =
+// CHECK-SAME:      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_f16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_medium_f16_expanded
+// CHECK:         iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+                                               wgp = <compute = fp16, storage =  b16,
+                                               subgroup =  none,
+                                               subgroup_size_choices = [64],
+                                               max_workgroup_sizes = [1024, 1024, 1024],
+                                               max_thread_count_per_workgroup = 1024,
+                                               max_workgroup_memory_bytes = 65536,
+                                               max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+   ukernels = "none"}>
+module attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  func.func @matmul_f16_large_expanded(%arg0: tensor<1x256x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1x256x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1x256x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1x256x1024xf32>) {
+      ^bb0(%in: f16, %in_4: f16, %out: f32):
+        %12 = arith.extf %in : f16 to f32
+        %13 = arith.extf %in_4 : f16 to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1x256x1024xf32>
+    return %2 : tensor<1x256x1024xf32>
+  }
+}
+// CHECK-LABEL: @matmul_f16_large_expanded
+// CHECK:         linalg.generic
+// CHECK-SAME:      compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME:      lowering_config =
+// CHECK-SAME:      translation_info =
+// CHECK-SAME:      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_f16_expanded
+// CHECK:         iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+                                               wgp = <compute = fp16, storage =  b16,
+                                               subgroup =  none,
+                                               subgroup_size_choices = [64],
+                                               max_workgroup_sizes = [1024, 1024, 1024],
+                                               max_thread_count_per_workgroup = 1024,
+                                               max_workgroup_memory_bytes = 65536,
+                                               max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+   ukernels = "none"}>
+module attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  func.func @matmul_bf16_large(%arg0: tensor<1024x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1024x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1024x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1024x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1024x1024xf32>) {
+      ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+        %12 = arith.extf %in : bf16 to f32
+        %13 = arith.extf %in_4 : bf16 to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1024x1024xf32>
+    return %2 : tensor<1024x1024xf32>
+  }
+}
+// CHECK-LABEL: @matmul_bf16_large
+// CHECK:         linalg.generic
+// CHECK-SAME:      compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME:      lowering_config =
+// CHECK-SAME:      translation_info =
+// CHECK-SAME:      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_bf16
+// CHECK:         iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+                                               wgp = <compute = fp16, storage =  b16,
+                                               subgroup =  none,
+                                               subgroup_size_choices = [64],
+                                               max_workgroup_sizes = [1024, 1024, 1024],
+                                               max_thread_count_per_workgroup = 1024,
+                                               max_workgroup_memory_bytes = 65536,
+                                               max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+   ukernels = "none"}>
+module attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  func.func @matmul_bf16_expanded_large(%arg0: tensor<1x256x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1x256x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1x256x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1x256x1024xf32>) {
+      ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+        %12 = arith.extf %in : bf16 to f32
+        %13 = arith.extf %in_4 : bf16 to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1x256x1024xf32>
+    return %2 : tensor<1x256x1024xf32>
+  }
+}
+// CHECK-LABEL: @matmul_bf16_expanded_large
+// CHECK:         linalg.generic
+// CHECK-SAME:      compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME:      lowering_config =
+// CHECK-SAME:      translation_info =
+// CHECK-SAME:      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_bf16_expanded
+// CHECK:         iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+  {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+                                               wgp = <compute = fp16, storage =  b16,
+                                               subgroup =  none,
+                                               subgroup_size_choices = [64],
+                                               max_workgroup_sizes = [1024, 1024, 1024],
+                                               max_thread_count_per_workgroup = 1024,
+                                               max_workgroup_memory_bytes = 65536,
+                                               max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+   ukernels = "none"}>
+module attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  func.func @matmul_bf16_expanded_medium(%arg0: tensor<1x128x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1x128x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1x128x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1x128x1024xf32>) {
+      ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+        %12 = arith.extf %in : bf16 to f32
+        %13 = arith.extf %in_4 : bf16 to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1x128x1024xf32>
+    return %2 : tensor<1x128x1024xf32>
+  }
+}
+// CHECK-LABEL: @matmul_bf16_expanded_medium
+// CHECK:         linalg.generic
+// CHECK-SAME:      compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME:      lowering_config =
+// CHECK-SAME:      translation_info =
+// CHECK-SAME:      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_bf16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_medium_bf16_expanded
+// CHECK:         iree_codegen.inner_tiled
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel
index c9749cd..a5c4e18 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel
@@ -55,6 +55,8 @@
 iree_c_embed_data(
     name = "iree_mlir_ukernels_amdgpu",
     srcs = [
+        "iree_uk_amdgpu_matmul_bf16.mlir",
+        "iree_uk_amdgpu_matmul_f16.mlir",
         "iree_uk_amdgpu_matmul_f8.mlir",
     ],
     c_file_output = "iree_mlir_ukernels_amdgpu.c",
@@ -67,6 +69,8 @@
 iree_lit_test_suite(
     name = "verify_mlir_ukernels_amdgpu",
     srcs = [
+        "iree_uk_amdgpu_matmul_bf16.mlir",
+        "iree_uk_amdgpu_matmul_f16.mlir",
         "iree_uk_amdgpu_matmul_f8.mlir",
     ],
     cfg = "//compiler:lit.cfg.py",
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt
index 0df959c..73815c9 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt
@@ -40,6 +40,8 @@
   NAME
     iree_mlir_ukernels_amdgpu
   SRCS
+    "iree_uk_amdgpu_matmul_bf16.mlir"
+    "iree_uk_amdgpu_matmul_f16.mlir"
     "iree_uk_amdgpu_matmul_f8.mlir"
   C_FILE_OUTPUT
     "iree_mlir_ukernels_amdgpu.c"
@@ -53,6 +55,8 @@
   NAME
     verify_mlir_ukernels_amdgpu
   SRCS
+    "iree_uk_amdgpu_matmul_bf16.mlir"
+    "iree_uk_amdgpu_matmul_f16.mlir"
     "iree_uk_amdgpu_matmul_f8.mlir"
   TOOLS
     iree-opt
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_bf16.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_bf16.mlir
new file mode 100644
index 0000000..9a138e2
--- /dev/null
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_bf16.mlir
@@ -0,0 +1,659 @@
+//  RUN: iree-opt %s
+
+!bf16_in_ty = tensor<256x?xbf16>
+!bf16_exp_in_ty = tensor<1x256x?xbf16>
+!bf16_block_in = tensor<256x64xbf16>
+!bf16_exp_block_in = tensor<1x256x64xbf16>
+!bf16_flat_shared = memref<16384xbf16, #gpu.address_space<workgroup>>
+!bf16_shared = memref<256x64xbf16, #gpu.address_space<workgroup>>
+!bf16_shared_exp = memref<16x16x4x16xbf16, #gpu.address_space<workgroup>>
+
+!in_ty_bf16 = tensor<256x?xbf16>
+!exp_in_ty_bf16 = tensor<1x256x?xbf16>
+!block_in_bf16 = tensor<256x64xbf16>
+!exp_block_in_bf16 = tensor<1x256x64xbf16>
+!flat_shared_bf16 = memref<16384xbf16, #gpu.address_space<workgroup>>
+!shared_bf16 = memref<256x64xbf16, #gpu.address_space<workgroup>>
+!shared_exp_bf16 = memref<16x16x4x16xbf16, #gpu.address_space<workgroup>>
+
+!mexp_in_ty_bf16 = tensor<1x128x?xbf16>
+!mexp_block_in_bf16 = tensor<1x128x64xbf16>
+!mflat_shared_bf16 = memref<8192xbf16, #gpu.address_space<workgroup>>
+!mshared_bf16 = memref<128x64xbf16, #gpu.address_space<workgroup>>
+!mshared_exp_bf16 = memref<8x16x4x16xbf16, #gpu.address_space<workgroup>>
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+
+util.func private @pingpong_large_bf16(%lhs_base: !bf16_in_ty, %rhs_base: !bf16_in_ty, %unused_acc: tensor<256x256xf32>) -> tensor<256x256xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %cst = arith.constant 0.0 : bf16
+  %lhs_shared_base = memref.alloc() : !bf16_flat_shared
+  %rhs_shared_base = memref.alloc() : !bf16_flat_shared
+
+  %dim = tensor.dim %lhs_base, %c1 : !bf16_in_ty
+  %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
+  %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
+
+  %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+  %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+
+  %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+  %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+
+  %lhs_init = tensor.extract_slice %lhs [0, 0] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+  %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+    %lhs_thread_local = tensor.extract_slice %lhs_init [%delin#0, %vec] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+
+  %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+  %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+
+  %0 = tensor.empty() : tensor<16x16x16x16xf32>
+  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<16x16x16x16xf32> {
+    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+    %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw>: index
+    %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw>: index
+    %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw>: index
+    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+    // Inner 64 loads 8 threads x 8 elements.
+    %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw>: index
+    // Each subgroup loads 32 contiguous rows out of 256.
+    %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw>: index
+    // Base index is remaining outer 8 lanes + subgroup base.
+    %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw>: index
+    %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw>: index
+    %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw>: index
+    %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw>: index
+
+    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+    %cmp0 = arith.cmpi slt, %id, %c256 : index
+    %cmp1 = arith.cmpi sge, %id, %c256 : index
+    scf.if %cmp0 {
+      rocdl.s.barrier
+    }
+    %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+      // Global loads of lhs.
+      %lhs_block = tensor.extract_slice %lhs [0, %i] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+      %lhs_thread_0 = tensor.extract_slice %lhs_block [%glb0, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %lhs_thread_1 = tensor.extract_slice %lhs_block [%glb1, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %lhs_thread_2 = tensor.extract_slice %lhs_block [%glb2, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %lhs_thread_3 = tensor.extract_slice %lhs_block [%glb3, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+      %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      // Global loads of rhs.
+      %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+      %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+      %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+      vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      scf.yield %dot3 : vector<8x4x1x4xf32>
+    }
+    scf.if %cmp1 {
+      rocdl.s.barrier
+    }
+
+    // Epilogue
+    %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+    %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+    %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+    %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+    %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+    %empty = tensor.empty() : tensor<8x1x4x4xf32>
+    %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<8x1x4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %out[%m_outer_id, %ids#3, %n_outer_id, %inner_id] [8, 1, 4, 4] [1, 1, 1, 1] : tensor<8x1x4x4xf32> into tensor<16x16x16x16xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  %collapse = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<16x16x16x16xf32> into tensor<256x256xf32>
+  util.return %collapse : tensor<256x256xf32>
+}
+
+// Expanded variants of BF16
+
+util.func private @pingpong_medium_bf16_expanded(%lhs_base: !mexp_in_ty_bf16, %rhs_base: !in_ty_bf16, %unused_acc: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %cst = arith.constant 0.0 : bf16
+  %lhs_shared_base = memref.alloc() : !mflat_shared_bf16
+  %rhs_shared_base = memref.alloc() : !flat_shared_bf16
+
+  %dim = tensor.dim %rhs_base, %c1 : !in_ty_bf16
+  %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty_bf16
+  %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_bf16
+
+  %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !mflat_shared_bf16
+  %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared_bf16
+
+  %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [128, 64] : !mflat_shared_bf16 into !mshared_bf16
+  %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared_bf16 into !shared_bf16
+
+  %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 128, 64] [1, 1, 1] : !mexp_in_ty_bf16 to !mexp_block_in_bf16
+  %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty_bf16 to !block_in_bf16
+
+  scf.forall (%id) in (1024) {
+    %delin:2 = affine.delinearize_index %id into (128, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !mexp_block_in_bf16 to tensor<1x1x8xbf16>
+    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !mshared_bf16
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+
+  %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [8, 16, 4, 16] : !mshared_bf16 into !mshared_exp_bf16
+  %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared_bf16 into !shared_exp_bf16
+
+  %0 = tensor.empty() : tensor<1x8x16x16x16xf32>
+  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x8x16x16x16xf32> {
+    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+    %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw>: index
+    %m_outer_id = arith.muli %ids#0, %c4 overflow<nsw, nuw>: index
+    %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw>: index
+    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+    // Inner 64 loads 8 threads x 8 elements.
+    %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw>: index
+    // RHS indexing. Each subgroup loads 32 contiguous rows out of 256.
+    %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw>: index
+    // Base index is remaining outer 8 lanes + subgroup base.
+    %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw>: index
+    %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw>: index
+    %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw>: index
+    %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw>: index
+    // LHS indexing.
+    %bpo_lhs = arith.muli %wt#0, %c16 overflow<nsw, nuw>: index
+    %glb0_lhs = arith.addi %wt#1, %bpo_lhs overflow<nsw, nuw>: index
+    %glb1_lhs = arith.addi %glb0_lhs, %c8 overflow<nsw, nuw>: index
+
+    %2 = arith.constant dense<0.0> : vector<4x4x1x4xf32>
+
+    %cmp0 = arith.cmpi slt, %id, %c256 : index
+    %cmp1 = arith.cmpi sge, %id, %c256 : index
+    scf.if %cmp0 {
+      rocdl.s.barrier
+    }
+    %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<4x4x1x4xf32> {
+
+      %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+      %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+      %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+      %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+      rocdl.sched.barrier 0
+
+      // Global loads of rhs.
+      %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty_bf16 to !block_in_bf16
+      %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+      %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+      %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+      %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+      %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+      rocdl.sched.barrier 0
+
+      %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+      %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+      %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+      %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+      rocdl.sched.barrier 0
+
+      // Global loads of lhs.
+      %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 128, 64] [1, 1, 1] : !mexp_in_ty_bf16 to !mexp_block_in_bf16
+      %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in_bf16 to tensor<1x1x8xbf16>
+      %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+      %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in_bf16 to tensor<1x1x8xbf16>
+      %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%iter) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+      vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+      vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+      vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+
+      vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !mshared_bf16
+      vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !mshared_bf16
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      scf.yield %dot2 : vector<4x4x1x4xf32>
+    }
+    scf.if %cmp1 {
+      rocdl.s.barrier
+    }
+
+    // Epilogue
+    %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+    %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+    %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+    %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+    %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%3) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+    %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+    %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+    %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+    %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+    %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+    %tp = vector.transpose %dot2, [0, 2, 1, 3] : vector<4x4x1x4xf32> to vector<4x1x4x4xf32>
+    %empty = tensor.empty() : tensor<1x4x1x4x4xf32>
+    %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<4x1x4x4xf32>, tensor<1x4x1x4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 4, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x4x1x4x4xf32> into tensor<1x8x16x16x16xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x8x16x16x16xf32> into tensor<1x128x256xf32>
+  util.return %collapse : tensor<1x128x256xf32>
+}
+
+util.func private @pingpong_large_bf16_expanded(%lhs_base: !bf16_exp_in_ty, %rhs_base: !bf16_in_ty, %unused_acc: tensor<1x256x256xf32>) -> tensor<1x256x256xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %cst = arith.constant 0.0 : bf16
+  %lhs_shared_base = memref.alloc() : !bf16_flat_shared
+  %rhs_shared_base = memref.alloc() : !bf16_flat_shared
+
+  %dim = tensor.dim %rhs_base, %c1 : !bf16_in_ty
+  %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !bf16_exp_in_ty
+  %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
+
+  %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+  %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+
+  %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+  %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+
+  %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 256, 64] [1, 1, 1] : !bf16_exp_in_ty to !bf16_exp_block_in
+  %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+
+  %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+  %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+
+  %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
+  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
+    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+    %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+    %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+    %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+    // Inner 64 loads 8 threads x 8 elements.
+    %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+    // Each subgroup loads 32 contiguous rows out of 256.
+    %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+    // Base index is remaining outer 8 lanes + subgroup base.
+    %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+    %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+    %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+    %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+
+    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+    %cmp0 = arith.cmpi slt, %id, %c256 : index
+    %cmp1 = arith.cmpi sge, %id, %c256 : index
+    scf.if %cmp0 {
+      rocdl.s.barrier
+    }
+    %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+      // Global loads of lhs.
+      %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 256, 64] [1, 1, 1] : !bf16_exp_in_ty to !bf16_exp_block_in
+      %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+      %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+      %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+      %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+      %lhs_thread_2 = tensor.extract_slice %lhs_block [0, %glb2, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+      %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+      %lhs_thread_3 = tensor.extract_slice %lhs_block [0, %glb3, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+      %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+
+      %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      // Global loads of rhs.
+      %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+      %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+      %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+      %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+      %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+      %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+      vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+      vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+      } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      scf.yield %dot3 : vector<8x4x1x4xf32>
+    }
+    scf.if %cmp1 {
+      rocdl.s.barrier
+    }
+
+    // Epilogue
+    %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+    %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+    %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+    %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+    %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+    %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+    } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+    %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+    %empty = tensor.empty() : tensor<1x8x1x4x4xf32>
+    %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<1x8x1x4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 8, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x8x1x4x4xf32> into tensor<1x16x16x16x16xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x16x16x16x16xf32> into tensor<1x256x256xf32>
+  util.return %collapse : tensor<1x256x256xf32>
+}
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f16.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f16.mlir
new file mode 100644
index 0000000..480b849
--- /dev/null
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f16.mlir
@@ -0,0 +1,649 @@
+// RUN: iree-opt %s
+
+!in_ty = tensor<256x?xf16>
+!exp_in_ty = tensor<1x256x?xf16>
+!block_in = tensor<256x64xf16>
+!exp_block_in = tensor<1x256x64xf16>
+!flat_shared = memref<16384xf16, #gpu.address_space<workgroup>>
+!shared = memref<256x64xf16, #gpu.address_space<workgroup>>
+!shared_exp = memref<16x16x4x16xf16, #gpu.address_space<workgroup>>
+
+!mexp_in_ty = tensor<1x128x?xf16>
+!mexp_block_in = tensor<1x128x64xf16>
+!mflat_shared = memref<8192xf16, #gpu.address_space<workgroup>>
+!mshared = memref<128x64xf16, #gpu.address_space<workgroup>>
+!mshared_exp = memref<8x16x4x16xf16, #gpu.address_space<workgroup>>
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+
+util.func private @pingpong_large_f16(%lhs_base: !in_ty, %rhs_base: !in_ty, %unused_acc: tensor<256x256xf32>) -> tensor<256x256xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %cst = arith.constant 0.0 : f16
+  %lhs_shared_base = memref.alloc() : !flat_shared
+  %rhs_shared_base = memref.alloc() : !flat_shared
+
+  %dim = tensor.dim %lhs_base, %c1 : !in_ty
+  %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !in_ty
+  %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
+
+  %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+  %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+
+  %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+  %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+
+  %lhs_init = tensor.extract_slice %lhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+  %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %lhs_thread_local = tensor.extract_slice %lhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+
+  %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+  %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+
+  %0 = tensor.empty() : tensor<16x16x16x16xf32>
+  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<16x16x16x16xf32> {
+    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+    %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+    %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+    %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+    // Inner 64 loads 8 threads x 8 elements.
+    %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+    // Each subgroup loads 32 contiguous rows out of 256.
+    %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+    // Base index is remaining outer 8 lanes + subgroup base.
+    %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+    %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+    %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+    %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+
+    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+    %cmp0 = arith.cmpi slt, %id, %c256 : index
+    %cmp1 = arith.cmpi sge, %id, %c256 : index
+    scf.if %cmp0 {
+      rocdl.s.barrier
+    }
+    %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+      // Global loads of lhs.
+      %lhs_block = tensor.extract_slice %lhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+      %lhs_thread_0 = tensor.extract_slice %lhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %lhs_thread_1 = tensor.extract_slice %lhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %lhs_thread_2 = tensor.extract_slice %lhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %lhs_thread_3 = tensor.extract_slice %lhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+      %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      // Global loads of rhs.
+      %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+      %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+      %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+      vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      scf.yield %dot3 : vector<8x4x1x4xf32>
+    }
+    scf.if %cmp1 {
+      rocdl.s.barrier
+    }
+
+    // Epilogue
+    %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+    %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+    %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+    %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+    %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+    %empty = tensor.empty() : tensor<8x1x4x4xf32>
+    %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<8x1x4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %out[%m_outer_id, %ids#3, %n_outer_id, %inner_id] [8, 1, 4, 4] [1, 1, 1, 1] : tensor<8x1x4x4xf32> into tensor<16x16x16x16xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  %collapse = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<16x16x16x16xf32> into tensor<256x256xf32>
+  util.return %collapse : tensor<256x256xf32>
+}
+
+util.func private @pingpong_medium_f16_expanded(%lhs_base: !mexp_in_ty, %rhs_base: !in_ty, %unused_acc: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %cst = arith.constant 0.0 : f16
+  %lhs_shared_base = memref.alloc() : !mflat_shared
+  %rhs_shared_base = memref.alloc() : !flat_shared
+
+  %dim = tensor.dim %rhs_base, %c1 : !in_ty
+  %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty
+  %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
+
+  %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !mflat_shared
+  %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+
+  %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [128, 64] : !mflat_shared into !mshared
+  %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+
+  %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 128, 64] [1, 1, 1] : !mexp_in_ty to !mexp_block_in
+  %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+
+  scf.forall (%id) in (1024) {
+    %delin:2 = affine.delinearize_index %id into (128, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !mexp_block_in to tensor<1x1x8xf16>
+    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !mshared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+
+  %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [8, 16, 4, 16] : !mshared into !mshared_exp
+  %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+
+  %0 = tensor.empty() : tensor<1x8x16x16x16xf32>
+  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x8x16x16x16xf32> {
+    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+    %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+    %m_outer_id = arith.muli %ids#0, %c4 overflow<nsw, nuw> : index
+    %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+    // Inner 64 loads 8 threads x 8 elements.
+    %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+    // RHS indexing. Each subgroup loads 32 contiguous rows out of 256.
+    %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+    // Base index is remaining outer 8 lanes + subgroup base.
+    %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+    %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+    %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+    %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+    // LHS indexing.
+    %bpo_lhs = arith.muli %wt#0, %c16 overflow<nsw, nuw> : index
+    %glb0_lhs = arith.addi %wt#1, %bpo_lhs overflow<nsw, nuw> : index
+    %glb1_lhs = arith.addi %glb0_lhs, %c8 overflow<nsw, nuw> : index
+
+    %2 = arith.constant dense<0.0> : vector<4x4x1x4xf32>
+
+    %cmp0 = arith.cmpi slt, %id, %c256 : index
+    %cmp1 = arith.cmpi sge, %id, %c256 : index
+    scf.if %cmp0 {
+      rocdl.s.barrier
+    }
+    %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<4x4x1x4xf32> {
+
+      %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+      %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+      %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+      %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+      rocdl.sched.barrier 0
+
+      // Global loads of rhs.
+      %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+      %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+      rocdl.sched.barrier 0
+
+      %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+      %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+      %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+      %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+      rocdl.sched.barrier 0
+
+      // Global loads of lhs.
+      %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 128, 64] [1, 1, 1] : !mexp_in_ty to !mexp_block_in
+      %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in to tensor<1x1x8xf16>
+      %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+      %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in to tensor<1x1x8xf16>
+      %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%iter) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+      vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !mshared
+      vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !mshared
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      scf.yield %dot2 : vector<4x4x1x4xf32>
+    }
+    scf.if %cmp1 {
+      rocdl.s.barrier
+    }
+
+    // Epilogue
+    %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+    %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+    %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+    %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+    %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%3) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+    %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+    %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+    %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+    %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+    %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+    %tp = vector.transpose %dot2, [0, 2, 1, 3] : vector<4x4x1x4xf32> to vector<4x1x4x4xf32>
+    %empty = tensor.empty() : tensor<1x4x1x4x4xf32>
+    %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<4x1x4x4xf32>, tensor<1x4x1x4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 4, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x4x1x4x4xf32> into tensor<1x8x16x16x16xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x8x16x16x16xf32> into tensor<1x128x256xf32>
+  util.return %collapse : tensor<1x128x256xf32>
+}
+
+util.func private @pingpong_large_f16_expanded(%lhs_base: !exp_in_ty, %rhs_base: !in_ty, %unused_acc: tensor<1x256x256xf32>) -> tensor<1x256x256xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %cst = arith.constant 0.0 : f16
+  %lhs_shared_base = memref.alloc() : !flat_shared
+  %rhs_shared_base = memref.alloc() : !flat_shared
+
+  %dim = tensor.dim %rhs_base, %c1 : !in_ty
+  %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !exp_in_ty
+  %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
+
+  %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+  %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+
+  %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+  %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+
+  %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 256, 64] [1, 1, 1] : !exp_in_ty to !exp_block_in
+  %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+
+  %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+  %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+
+  %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
+  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
+    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+    %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+    %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+    %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+    // Inner 64 loads 8 threads x 8 elements.
+    %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+    // Each subgroup loads 32 contiguous rows out of 256.
+    %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+    // Base index is remaining outer 8 lanes + subgroup base.
+    %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+    %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+    %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+    %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+
+    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+    %cmp0 = arith.cmpi slt, %id, %c256 : index
+    %cmp1 = arith.cmpi sge, %id, %c256 : index
+    scf.if %cmp0 {
+      rocdl.s.barrier
+    }
+    %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+      // Global loads of lhs.
+      %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 256, 64] [1, 1, 1] : !exp_in_ty to !exp_block_in
+      %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+      %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+      %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+      %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+      %lhs_thread_2 = tensor.extract_slice %lhs_block [0, %glb2, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+      %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+      %lhs_thread_3 = tensor.extract_slice %lhs_block [0, %glb3, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+      %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+
+      %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      // Global loads of rhs.
+      %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+      %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+      %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+      %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+      %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+      %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+      vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+      vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+      } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      scf.yield %dot3 : vector<8x4x1x4xf32>
+    }
+    scf.if %cmp1 {
+      rocdl.s.barrier
+    }
+
+    // Epilogue
+    %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+    %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+    %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+    %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+    %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+    %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+    } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
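+    // Transpose the accumulator into output order and write this thread's 8x1x4x4 result tile into the workgroup output.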
+    %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+    %empty = tensor.empty() : tensor<1x8x1x4x4xf32>
+    %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<1x8x1x4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 8, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x8x1x4x4xf32> into tensor<1x16x16x16x16xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x16x16x16x16xf32> into tensor<1x256x256xf32>
+  util.return %collapse : tensor<1x256x256xf32>
+}
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir
index b5dadcd..d808b97 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir
@@ -209,3 +209,226 @@
   %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x8x16x16x16xf32> into tensor<1x128x256xf32>
   util.return %collapse : tensor<1x128x256xf32>
 }
+
+util.func private @pingpong_large_f8_expanded(%lhs_base: !exp_in_ty_f8, %rhs_base: !in_ty_f8, %unused_acc: tensor<1x256x256xf32>) -> tensor<1x256x256xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  %cst = arith.constant 0.0 : f8E4M3FNUZ
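+  // Allocate LDS for one 256x128 tile of each operand.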
+  %lhs_shared_base = memref.alloc() : !flat_shared_f8
+  %rhs_shared_base = memref.alloc() : !flat_shared_f8
+
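+  // Cast the global operands to buffer resources, cache-swizzled on the dynamic K extent.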
+  %dim = tensor.dim %rhs_base, %c1 : !in_ty_f8
+  %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !exp_in_ty_f8
+  %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_f8
+
+  %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8
+  %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8
+
+  %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 128] : !flat_shared_f8 into !shared_f8
+  %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 128] : !flat_shared_f8 into !shared_f8
+
+  %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 256, 128] [1, 1, 1] : !exp_in_ty_f8 to !exp_block_in_f8
+  %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 128] [1, 1] : !in_ty_f8 to !block_in_f8
+
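+  // Prologue: copy the first 256x128 K-block of each operand into shared
+  // memory, 16 contiguous f8 elements per forall iteration.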
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
+    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  scf.forall (%id) in (2048) {
+    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+    %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
+    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+
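+  // Expand the 256x128 shared tiles to 16x16x4x32 views indexed as
+  // (outer tile, lane row, K sub-block, K inner offset) for the MFMA reads.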
+  %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 32] : !shared_f8 into !shared_exp_f8
+  %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 32] : !shared_f8 into !shared_exp_f8
+
+  %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
+  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
+    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+    %inner_id = arith.muli %ids#2, %c8 overflow<nsw, nuw> : index
+    %inner_id_acc = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+    %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+    %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+    // Inner 64 loads 8 threads x 16 elements.
+    %gko = arith.muli %wt#2, %c16 overflow<nsw, nuw> : index
+    // Each subgroup loads 32 contiguous rows out of 256.
+    %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+    // Base index is remaining outer 8 lanes + subgroup base.
+    %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+    %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+    %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+    %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+
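+    // Zero-initialize the accumulator for the 8x4 grid of MFMA tiles.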
+    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
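+    // Ping-pong stagger: the first half of the threads hits an extra s.barrier
+    // before the main loop, the second half after it.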
+    %cmp0 = arith.cmpi slt, %id, %c256 : index
+    %cmp1 = arith.cmpi sge, %id, %c256 : index
+    scf.if %cmp0 {
+      rocdl.s.barrier
+    }
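+    // Main K loop: prefetch the next global K-block while the MFMAs consume
+    // the tiles already resident in shared memory.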
+    %3 = scf.for %i = %c128 to %dim step %c128 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+      // Global loads of lhs.
+      %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 256, 128] [1, 1, 1] : !exp_in_ty_f8 to !exp_block_in_f8
+      %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+      %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+      %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+      %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+      %lhs_thread_2 = tensor.extract_slice %lhs_block [0, %glb2, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+      %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+      %lhs_thread_3 = tensor.extract_slice %lhs_block [0, %glb3, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+      %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+
+      %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+      %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+      } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      // Global loads of rhs.
+      %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 128] [1, 1] : !in_ty_f8 to !block_in_f8
+      %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+      %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+      %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+      %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+      %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+      %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+      %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+      %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+
+      %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+      %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+      } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+      %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+      %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+      %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+      } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
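+      // Write the prefetched global tiles to shared memory for the next K iteration.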
+      vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+      vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+      vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+      vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+
+      vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+      vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+      vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+      vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+
+      gpu.barrier
+      rocdl.sched.barrier 0
+      rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+      %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+        indexing_maps = #contraction_accesses,
+        iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+        kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+      } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+      rocdl.s.setprio 0
+      gpu.barrier
+      rocdl.sched.barrier 0
+
+      scf.yield %dot3 : vector<8x4x1x4xf32>
+    }
+    scf.if %cmp1 {
+      rocdl.s.barrier
+    }
+
+    // Epilogue: multiply the final K-block that is still resident in shared memory.
+    %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+    %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+    %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+    } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+    %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+    %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+    %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+    } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+    %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+    %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+    %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+    } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+    %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+    %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+    %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+      indexing_maps = #contraction_accesses,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+    } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
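+    // Transpose the accumulator into output order and write this thread's 8x1x4x4 result tile into the workgroup output.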
+    %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+    %empty = tensor.empty() : tensor<1x8x1x4x4xf32>
+    %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<1x8x1x4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id_acc] [1, 8, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x8x1x4x4xf32> into tensor<1x16x16x16x16xf32>
+    }
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x16x16x16x16xf32> into tensor<1x256x256xf32>
+  util.return %collapse : tensor<1x256x256xf32>
+}
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
index 64545dc..c6c83b4 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
@@ -1,8 +1,10 @@
 // RUN: iree-opt -allow-unregistered-dialect %s
 
-// This pattern matches an expanded matmul-like operation and annotates it
-// with ukernel descriptor and configuration attributes.
-pdl.pattern @annotate_expanded_matmul_like : benefit(1) {
+// F8 Patterns
+
+// This pattern matches a medium-sized expanded matmul-like operation and
+// annotates it with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_f8_medium_expanded : benefit(1) {
   %elemtypes = pdl.attribute = [f8E4M3FNUZ, f8E4M3FNUZ, f32]
   %imaps = pdl.attribute = [
     affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
@@ -19,7 +21,7 @@
   %out_init = pdl.operand : %out_type
 
   // Match the a matmul-like generic with above indexin maps.
-  %generic_op = pdl.operation "linalg.generic" (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
   pdl.apply_native_constraint "matchContraction"(
         %generic_op, %elemtypes, %imaps
         : !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -77,3 +79,568 @@
     pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
   }
 }
+
+// This pattern matches a large expanded f8 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes. This is preferred over the
+// medium-sized ukernel.
+pdl.pattern @annotate_matmul_like_f8_large_expanded : benefit(2) {
+  %elemtypes = pdl.attribute = [f8E4M3FNUZ, f8E4M3FNUZ, f32]
+  %imaps = pdl.attribute = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ]
+
+  %lhs_type = pdl.type
+  %rhs_type = pdl.type
+  %out_type = pdl.type
+
+  %lhs = pdl.operand : %lhs_type
+  %rhs = pdl.operand : %rhs_type
+  %out_init = pdl.operand : %out_type
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  pdl.apply_native_constraint "matchContraction"(
+        %generic_op, %elemtypes, %imaps
+        : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+  %attr_name = pdl.attribute = "iree_codegen.ukernel"
+  pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+  // M % 256 == 0, K % 128 == 0, N % 256 == 0
+  %empty = pdl.attribute = {}
+  %c0 = pdl.attribute = 0
+  %c1 = pdl.attribute = 1
+  %c2 = pdl.attribute = 2
+  %c128 = pdl.attribute = 128
+  %c256 = pdl.attribute = 256
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+  // N >= 1024, K >= 512
+  %c512 = pdl.attribute = 512
+  %c1024 = pdl.attribute = 1024
+
+  // TODO: Kernel specialization is needed to apply this strategy selectively at
+  // runtime. Additionally model exports don't specify lower bounds so it is
+  // impossible to use this strategy with this check.
+  // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c512, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.rewrite {
+    // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+    // This modifies the operation in-place.
+
+    %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_f8_expanded", tensor>
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %config_name = pdl.attribute = "compilation_info"
+    %config = pdl.attribute = #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{
+        workgroup = [1, 256, 256, 0]
+      }>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        // This strategy uses the maximum amount of possible shared memory on
+        // all gfx942 architectures so shared memory padding to reduce bank
+        // conflicts must be disabled. Also prefetching is done manually in the
+        // above and is disabled here as well.
+        {gpu_pipeline_options =
+          #iree_gpu.pipeline_options<
+            prefetch_shared_memory = false,
+            no_reduce_shared_memory_bank_conflicts = true>,
+        // This strategy requires 2 waves per SIMD.
+          llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+    >
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %builtin_attr = pdl.attribute = "rocm.builtin_name"
+    %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f8.mlir"
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+  }
+}
+
+// F16 Patterns
+
+// This pattern matches a large f16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_f16_large : benefit(1) {
+  %elemtypes = pdl.attribute = [f16, f16, f32]
+  %imaps = pdl.attribute = [
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1)>
+  ]
+
+  %lhs_type = pdl.type
+  %rhs_type = pdl.type
+  %out_type = pdl.type
+
+  %lhs = pdl.operand : %lhs_type
+  %rhs = pdl.operand : %rhs_type
+  %out_init = pdl.operand : %out_type
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  pdl.apply_native_constraint "matchContraction"(
+        %generic_op, %elemtypes, %imaps
+        : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+  %attr_name = pdl.attribute = "iree_codegen.ukernel"
+  pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+  // M % 256 == 0, K % 64 == 0, N % 256 == 0
+  %empty = pdl.attribute = {}
+  %c0 = pdl.attribute = 0
+  %c1 = pdl.attribute = 1
+  %c2 = pdl.attribute = 2
+  %c64 = pdl.attribute = 64
+  %c256 = pdl.attribute = 256
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+  // M, N >= 1024, K >= 256
+  %c1024 = pdl.attribute = 1024
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c1, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.rewrite {
+    // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+    // This modifies the operation in-place.
+
+    %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %config_name = pdl.attribute = "compilation_info"
+    %config = pdl.attribute = #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{workgroup = [256, 256, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        // This strategy uses the maximum amount of possible shared memory on
+        // all gfx942 architectures so shared memory padding to reduce bank
+        // conflicts must be disabled. Also prefetching is done manually in the
+        // above and is disabled here as well.
+        {gpu_pipeline_options =
+          #iree_gpu.pipeline_options<
+            prefetch_shared_memory = false,
+            no_reduce_shared_memory_bank_conflicts = true>,
+        // This strategy requires 2 waves per SIMD.
+          llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+    >
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %builtin_attr = pdl.attribute = "rocm.builtin_name"
+    %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f16.mlir"
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+  }
+}
+
+// This pattern matches a medium-sized expanded f16 matmul-like operation and
+// annotates it with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_f16_medium_expanded : benefit(1) {
+  %elemtypes = pdl.attribute = [f16, f16, f32]
+  %imaps = pdl.attribute = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ]
+
+  %lhs_type = pdl.type
+  %rhs_type = pdl.type
+  %out_type = pdl.type
+
+  %lhs = pdl.operand : %lhs_type
+  %rhs = pdl.operand : %rhs_type
+  %out_init = pdl.operand : %out_type
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  pdl.apply_native_constraint "matchContraction"(
+        %generic_op, %elemtypes, %imaps
+        : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+  %attr_name = pdl.attribute = "iree_codegen.ukernel"
+  pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+  // M % 128 == 0, K % 64 == 0, N % 256 == 0
+  %empty = pdl.attribute = {}
+  %c0 = pdl.attribute = 0
+  %c1 = pdl.attribute = 1
+  %c2 = pdl.attribute = 2
+  %c64 = pdl.attribute = 64
+  %c128 = pdl.attribute = 128
+  %c256 = pdl.attribute = 256
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+  // M, N >= 1024, K >= 256
+  %c1024 = pdl.attribute = 1024
+
+  // TODO: Kernel specialization is needed to apply this strategy selectively at
+  // runtime. Additionally model exports don't specify lower bounds so it is
+  // impossible to use this strategy with this check.
+  // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.rewrite {
+    // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+    // This modifies the operation in-place.
+
+    %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_medium_f16_expanded", tensor>
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %config_name = pdl.attribute = "compilation_info"
+    %config = pdl.attribute = #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 128, 256, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        // This strategy uses the maximum amount of possible shared memory on
+        // all gfx942 architectures so shared memory padding to reduce bank
+        // conflicts must be disabled. Also prefetching is done manually in the
+        // above and is disabled here as well.
+        {gpu_pipeline_options =
+          #iree_gpu.pipeline_options<
+            prefetch_shared_memory = false,
+            no_reduce_shared_memory_bank_conflicts = true>,
+        // This strategy requires 2 waves per SIMD.
+          llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+    >
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %builtin_attr = pdl.attribute = "rocm.builtin_name"
+    %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f16.mlir"
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+  }
+}
+
+// This pattern matches a large expanded f16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes. This is preferred over the
+// medium-sized ukernel.
+pdl.pattern @annotate_matmul_like_f16_large_expanded : benefit(2) {
+  %elemtypes = pdl.attribute = [f16, f16, f32]
+  %imaps = pdl.attribute = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ]
+
+  %lhs_type = pdl.type
+  %rhs_type = pdl.type
+  %out_type = pdl.type
+
+  %lhs = pdl.operand : %lhs_type
+  %rhs = pdl.operand : %rhs_type
+  %out_init = pdl.operand : %out_type
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  pdl.apply_native_constraint "matchContraction"(
+        %generic_op, %elemtypes, %imaps
+        : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+  %attr_name = pdl.attribute = "iree_codegen.ukernel"
+  pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+  // M % 256 == 0, K % 64 == 0, N % 256 == 0
+  %empty = pdl.attribute = {}
+  %c0 = pdl.attribute = 0
+  %c1 = pdl.attribute = 1
+  %c2 = pdl.attribute = 2
+  %c64 = pdl.attribute = 64
+  %c256 = pdl.attribute = 256
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+  // M, N >= 1024, K >= 256
+  %c1024 = pdl.attribute = 1024
+
+  // TODO: Kernel specialization is needed to apply this strategy selectively at
+  // runtime. Additionally model exports don't specify lower bounds so it is
+  // impossible to use this strategy with this check.
+  // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.rewrite {
+    // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+    // This modifies the operation in-place.
+
+    %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_f16_expanded", tensor>
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %config_name = pdl.attribute = "compilation_info"
+    %config = pdl.attribute = #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 256, 256, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        // This strategy uses the maximum amount of possible shared memory on
+        // all gfx942 architectures so shared memory padding to reduce bank
+        // conflicts must be disabled. Also prefetching is done manually in the
+        // above and is disabled here as well.
+        {gpu_pipeline_options =
+          #iree_gpu.pipeline_options<
+            prefetch_shared_memory = false,
+            no_reduce_shared_memory_bank_conflicts = true>,
+        // This strategy requires 2 waves per SIMD.
+          llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+    >
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %builtin_attr = pdl.attribute = "rocm.builtin_name"
+    %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f16.mlir"
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+  }
+}
+
+// BF16 Patterns
+
+// This pattern matches a large bf16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_bf16_large : benefit(1) {
+  %elemtypes = pdl.attribute = [bf16, bf16, f32]
+  %imaps = pdl.attribute = [
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d1, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1)>
+  ]
+
+  %lhs_type = pdl.type
+  %rhs_type = pdl.type
+  %out_type = pdl.type
+
+  %lhs = pdl.operand : %lhs_type
+  %rhs = pdl.operand : %rhs_type
+  %out_init = pdl.operand : %out_type
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  pdl.apply_native_constraint "matchContraction"(
+        %generic_op, %elemtypes, %imaps
+        : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+  %attr_name = pdl.attribute = "iree_codegen.ukernel"
+  pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+  // M % 256 == 0, K % 64 == 0, N % 256 == 0
+  %empty = pdl.attribute = {}
+  %c0 = pdl.attribute = 0
+  %c1 = pdl.attribute = 1
+  %c2 = pdl.attribute = 2
+  %c64 = pdl.attribute = 64
+  %c256 = pdl.attribute = 256
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+  // M, N >= 1024, K >= 256
+  %c1024 = pdl.attribute = 1024
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c1, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.rewrite {
+    // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+    // This modifies the operation in-place.
+
+    %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16", tensor>
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %config_name = pdl.attribute = "compilation_info"
+    %config = pdl.attribute = #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{workgroup = [256, 256, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        // This strategy uses the maximum amount of possible shared memory on
+        // all gfx942 architectures so shared memory padding to reduce bank
+        // conflicts must be disabled. Also prefetching is done manually in the
+        // above and is disabled here as well.
+        {gpu_pipeline_options =
+          #iree_gpu.pipeline_options<
+            prefetch_shared_memory = false,
+            no_reduce_shared_memory_bank_conflicts = true>,
+        // This strategy requires 2 waves per SIMD.
+          llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+    >
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %builtin_attr = pdl.attribute = "rocm.builtin_name"
+    %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_bf16.mlir"
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+  }
+}
+
+// This pattern matches an expanded bf16 matmul-like operation of medium size and annotates it
+// with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_bf16_medium_expanded : benefit(1) {
+  %elemtypes = pdl.attribute = [bf16, bf16, f32]
+  %imaps = pdl.attribute = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ]
+
+  %lhs_type = pdl.type
+  %rhs_type = pdl.type
+  %out_type = pdl.type
+
+  %lhs = pdl.operand : %lhs_type
+  %rhs = pdl.operand : %rhs_type
+  %out_init = pdl.operand : %out_type
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  pdl.apply_native_constraint "matchContraction"(
+        %generic_op, %elemtypes, %imaps
+        : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+  %attr_name = pdl.attribute = "iree_codegen.ukernel"
+  pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+  // M % 128 == 0, K % 64 == 0, N % 256 == 0
+  %empty = pdl.attribute = {}
+  %c0 = pdl.attribute = 0
+  %c1 = pdl.attribute = 1
+  %c2 = pdl.attribute = 2
+  %c64 = pdl.attribute = 64
+  %c128 = pdl.attribute = 128
+  %c256 = pdl.attribute = 256
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+  // M, N >= 1024, K >= 256
+  %c4 = pdl.attribute = 4
+  %c512 = pdl.attribute = 512
+  %c1024 = pdl.attribute = 1024
+
+  // TODO: Kernel specialization is needed to apply this strategy selectively at
+  // runtime. Additionally model exports don't specify lower bounds so it is
+  // impossible to use this strategy with this check.
+  // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.rewrite {
+    // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+    // This modifies the operation in-place.
+
+    %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_medium_bf16_expanded", tensor>
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %config_name = pdl.attribute = "compilation_info"
+    %config = pdl.attribute = #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 128, 256, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        // This strategy uses the maximum amount of possible shared memory on
+        // all gfx942 architectures so shared memory padding to reduce bank
+        // conflicts must be disabled. Also prefetching is done manually in the
+        // above and is disabled here as well.
+        {gpu_pipeline_options =
+          #iree_gpu.pipeline_options<
+            prefetch_shared_memory = false,
+            no_reduce_shared_memory_bank_conflicts = true>,
+        // This strategy requires 2 waves per SIMD.
+          llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+    >
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %builtin_attr = pdl.attribute = "rocm.builtin_name"
+    %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_bf16.mlir"
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+  }
+}
+
+// This pattern matches an expanded bf16 matmul-like operation of large size and annotates it
+// with ukernel descriptor and configuration attributes. This is preferred over the medium
+// strategy.
+pdl.pattern @annotate_matmul_like_bf16_large_expanded : benefit(2) {
+  %elemtypes = pdl.attribute = [bf16, bf16, f32]
+  %imaps = pdl.attribute = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ]
+
+  %lhs_type = pdl.type
+  %rhs_type = pdl.type
+  %out_type = pdl.type
+
+  %lhs = pdl.operand : %lhs_type
+  %rhs = pdl.operand : %rhs_type
+  %out_init = pdl.operand : %out_type
+
+  // Match a matmul-like generic with the above indexing maps.
+  %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+  pdl.apply_native_constraint "matchContraction"(
+        %generic_op, %elemtypes, %imaps
+        : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+  %attr_name = pdl.attribute = "iree_codegen.ukernel"
+  pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+  // M % 256 == 0, K % 64 == 0, N % 256 == 0
+  %empty = pdl.attribute = {}
+  %c0 = pdl.attribute = 0
+  %c1 = pdl.attribute = 1
+  %c2 = pdl.attribute = 2
+  %c64 = pdl.attribute = 64
+  %c256 = pdl.attribute = 256
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+  // M, N >= 1024, K >= 256
+  %c1024 = pdl.attribute = 1024
+  pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+  pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+  pdl.rewrite {
+    // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+    // This modifies the operation in-place.
+
+    %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %config_name = pdl.attribute = "compilation_info"
+    %config = pdl.attribute = #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 256, 256, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        // This strategy uses the maximum amount of possible shared memory on
+        // all gfx942 architectures so shared memory padding to reduce bank
+        // conflicts must be disabled. Also prefetching is done manually in the
+        // above and is disabled here as well.
+        {gpu_pipeline_options =
+          #iree_gpu.pipeline_options<
+            prefetch_shared_memory = false,
+            no_reduce_shared_memory_bank_conflicts = true>,
+        // This strategy requires 2 waves per SIMD.
+          llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+    >
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+    %builtin_attr = pdl.attribute = "rocm.builtin_name"
+    %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_bf16.mlir"
+    pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+  }
+}
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
index 4438517..8ac410f 100644
--- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h"
 
 #include "iree/compiler/Utils/EquivalenceUtils.h"
+#include "iree/compiler/Utils/ShapeUtils.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -33,32 +34,6 @@
 // MatchCastCompatibleDagFromRootOp
 //===----------------------------------------------------------------------===//
 
-static bool isCastableToTensorType(Type from, RankedTensorType to) {
-  auto tensorType = dyn_cast<RankedTensorType>(from);
-  if (!tensorType) {
-    return false;
-  }
-  if (tensorType.getRank() != to.getRank()) {
-    return false;
-  }
-  if (tensorType.getElementType() != to.getElementType()) {
-    return false;
-  }
-  for (auto [fromSize, toSize] :
-       llvm::zip_equal(tensorType.getShape(), to.getShape())) {
-    // If the target dimension is dynamic we can always cast to it.
-    if (ShapedType::isDynamic(toSize)) {
-      continue;
-    }
-    // Casting a dynamic dimension to a static one is never valid, and static
-    // sizes must always match.
-    if (toSize != fromSize) {
-      return false;
-    }
-  }
-  return true;
-}
-
 // Compares the regions between two operations in lockstep for equality.
 static DiagnosedSilenceableFailure
 compareOperationRegions(transform::TransformOpInterface transformOp,
diff --git a/compiler/src/iree/compiler/Utils/ShapeUtils.cpp b/compiler/src/iree/compiler/Utils/ShapeUtils.cpp
index 197a95d..e0d3c45 100644
--- a/compiler/src/iree/compiler/Utils/ShapeUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/ShapeUtils.cpp
@@ -38,4 +38,30 @@
   return numNonmatchingSSADims <= 1;
 }
 
+bool isCastableToTensorType(Type from, RankedTensorType to) {
+  auto tensorType = dyn_cast<RankedTensorType>(from);
+  if (!tensorType) {
+    return false;
+  }
+  if (tensorType.getRank() != to.getRank()) {
+    return false;
+  }
+  if (tensorType.getElementType() != to.getElementType()) {
+    return false;
+  }
+  for (auto [fromSize, toSize] :
+       llvm::zip_equal(tensorType.getShape(), to.getShape())) {
+    // If the target dimension is dynamic we can always cast to it.
+    if (ShapedType::isDynamic(toSize)) {
+      continue;
+    }
+    // Casting a dynamic dimension to a static one is never valid, and static
+    // sizes must always match.
+    if (toSize != fromSize) {
+      return false;
+    }
+  }
+  return true;
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Utils/ShapeUtils.h b/compiler/src/iree/compiler/Utils/ShapeUtils.h
index 08b7e12..2fc754b 100644
--- a/compiler/src/iree/compiler/Utils/ShapeUtils.h
+++ b/compiler/src/iree/compiler/Utils/ShapeUtils.h
@@ -8,6 +8,7 @@
 #define IREE_COMPILER_UTILS_SHAPEUTILS_H_
 
 #include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ValueRange.h"
 
 namespace mlir::iree_compiler {
@@ -21,6 +22,9 @@
 bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
                         ShapedType rhsType, ValueRange rhsDynamicDims);
 
+/// Helper to check whether 'from' is castable to the target ranked tensor type.
+bool isCastableToTensorType(Type from, RankedTensorType to);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_UTILS_SHAPEUTILS_H_
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index a5e1cda..464b587 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1707,6 +1707,66 @@
     "requires-gpu-cdna3"
 )
 
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_tensor_ukernel_f16f16f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=custom_mnk"
+    "--mnk=1024,1024,1024"
+    "--transpose_rhs"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-hip-enable-tensor-ukernels"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
+    e2e_matmul_tensor_ukernel_bf16bf16f32_large
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=custom_mnk"
+    "--mnk=1024,1024,1024"
+    "--transpose_rhs"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-hip-enable-tensor-ukernels"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
 endif()
 
 elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")