[ROCM] Port mlir ukernels to ukernel descriptor lowering flow (#21683)
This copies all ukernels from the tuning spec to the ukernel-descriptor-based
lowering and its PDL matching patterns. The ukernels are not removed from the
spec yet, because that requires updating every usage of
`--iree-codegen-enable-default-tuning-specs=true` to
`--iree-hip-enable-tensor-ukernels`, which imo is better done in a
separate PR.
The ukernels and matching patterns being copied in this PR:
- pingpong_large_f8_expanded
- pingpong_large_f16
- pingpong_medium_f16_expanded
- pingpong_large_f16_expanded
- pingpong_large_bf16
- pingpong_medium_bf16_expanded
- pingpong_large_bf16_expanded
Note that the mmt_2048x1280x5120_f16_f16_f32 matching and annotation is
not ported, as I believe it is unreachable: pingpong_large_f16 matches the
same case and takes precedence.
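
For illustration, here is a minimal sketch (not verbatim pattern output) of what
the ported matching patterns produce: a matched contraction-like `linalg.generic`
is annotated with an `iree_codegen.ukernel` descriptor, alongside a
`compilation_info` attribute that is elided below because the tests only check
for its presence. The descriptor then drives the descriptor-based lowering,
which pulls the matching `util.func private @pingpong_large_f16` body (built
around `iree_codegen.inner_tiled`) into the module, as the driver test checks.

```mlir
// Hedged sketch of an annotated large f16 matmul; attribute names follow the
// CHECK lines in apply_builtin_ukernel_pdl_patterns_driver.mlir. The
// compilation_info attribute (carrying lowering_config and translation_info)
// that is also attached is omitted since its contents are not shown in this diff.
%2 = linalg.generic {indexing_maps = [#map1, #map2, #map3],
                     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%arg0, %arg1 : tensor<1024x4096xf16>, tensor<1024x4096xf16>)
    outs(%1 : tensor<1024x1024xf32>)
    attrs = {iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>} {
^bb0(%in: f16, %in_1: f16, %out: f32):
  %3 = arith.extf %in : f16 to f32
  %4 = arith.extf %in_1 : f16 to f32
  %5 = arith.mulf %3, %4 : f32
  %6 = arith.addf %out, %5 : f32
  linalg.yield %6 : f32
} -> tensor<1024x1024xf32>
```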
---------
Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
Co-authored-by: Quinn Dawkins <quinn.dawkins@gmail.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
index 8a54cf6..f562c11 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Utils/ShapeUtils.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -61,7 +62,8 @@
return rewriter.notifyMatchFailure(rootOp,
"expected StringAttr for attr name.");
}
- rootOp->setAttr(strName.strref(), annotation);
+ rewriter.modifyOpInPlace(
+ rootOp, [&]() { rootOp->setAttr(strName.strref(), annotation); });
return success();
}
@@ -75,7 +77,6 @@
return rewriter.notifyMatchFailure(rootOp,
"not a contraction like linalg op");
}
-
if (linalgOp.getIndexingMaps() != indexingMaps) {
return rewriter.notifyMatchFailure(rootOp, "indexing maps mismatch");
}
@@ -102,6 +103,9 @@
if (!dim) {
return failure();
}
+ if (dimInt.getInt() >= shapedType.getRank()) {
+ return failure();
+ }
auto divisorInt = dyn_cast<IntegerAttr>(divisor);
if (!divisor) {
return failure();
@@ -140,6 +144,9 @@
if (!dimInt) {
return failure();
}
+ if (dimInt.getInt() >= shapedType.getRank()) {
+ return failure();
+ }
if (auto lowerBoundInt = dyn_cast<IntegerAttr>(lowerBound)) {
FailureOr<int64_t> constantLb =
ValueBoundsConstraintSet::computeConstantBound(
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
index 6959765..18c5013 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir
@@ -133,3 +133,77 @@
// CHECK-LABEL: @negative_matmul_f8_dynamic_lower_bound
// CHECK-NOT: compilation_info = #iree_codegen.compilation_info
// CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_f8_expanded", tensor>
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_f16(%arg0: tensor<256x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<256x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<256x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<256x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<256x1024xf32>) {
+ ^bb0(%in: f16, %in_4: f16, %out: f32):
+ %12 = arith.extf %in : f16 to f32
+ %13 = arith.extf %in_4 : f16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<256x1024xf32>
+ return %2 : tensor<256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f16
+// CHECK-NOT: compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_bf16(%arg0: tensor<256x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<256x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<256x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<256x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<256x1024xf32>) {
+ ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+ %12 = arith.extf %in : bf16 to f32
+ %13 = arith.extf %in_4 : bf16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<256x1024xf32>
+ return %2 : tensor<256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16
+// CHECK-NOT: compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+// The dynamic dimension is a multiple of 256, but doesn't have a lower bound of 256.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_bf16_dynamic_lower_bound(%arg0: index) -> tensor<1x256x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = util.assume.int %arg0<umin = 128, udiv = 256> : index
+ %1 = tensor.empty(%0) : tensor<1x256x?xbf16>
+ %2 = tensor.empty(%0) : tensor<1024x?xbf16>
+ %3 = tensor.empty() : tensor<1x256x1024xf32>
+ %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+ %5 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%1, %2 : tensor<1x256x?xbf16>, tensor<1024x?xbf16>) outs(%4 : tensor<1x256x1024xf32>) {
+ ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+ %6 = arith.extf %in : bf16 to f32
+ %7 = arith.extf %in_0 : bf16 to f32
+ %8 = arith.mulf %6, %7 : f32
+ %9 = arith.addf %out, %8 : f32
+ linalg.yield %9 : f32
+ } -> tensor<1x256x1024xf32>
+ return %5 : tensor<1x256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16_dynamic_lower_bound
+// CHECK-NOT: compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
index 7ca6748..e56b7ec 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
@@ -17,7 +17,7 @@
module attributes {
hal.executable.target = #executable_target_rocm_hsaco_fb
} {
- func.func @matmul_f8(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
+ func.func @matmul_f8_medium_expanded(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<1x128x1024xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
@@ -34,3 +34,297 @@
}
// CHECK-LABEL: util.func private @pingpong_medium_f8_expanded
// CHECK: iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none,
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+module attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+ func.func @matmul_f8_large_expanded(%arg0: tensor<1x256x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x256x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x256x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%1 : tensor<1x256x1024xf32>) {
+ ^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
+ %12 = arith.extf %in : f8E4M3FNUZ to f32
+ %13 = arith.extf %in_4 : f8E4M3FNUZ to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x256x1024xf32>
+ return %2 : tensor<1x256x1024xf32>
+ }
+}
+// CHECK-LABEL: @matmul_f8_large_expanded
+// CHECK: linalg.generic
+// CHECK-SAME: compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME: lowering_config =
+// CHECK-SAME: translation_info =
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f8_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_f8_expanded
+// CHECK: iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none,
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+module attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+ func.func @matmul_f16_large(%arg0: tensor<1024x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1024x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1024x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1024x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1024x1024xf32>) {
+ ^bb0(%in: f16, %in_4: f16, %out: f32):
+ %12 = arith.extf %in : f16 to f32
+ %13 = arith.extf %in_4 : f16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1024x1024xf32>
+ return %2 : tensor<1024x1024xf32>
+ }
+}
+// CHECK-LABEL: @matmul_f16_large
+// CHECK: linalg.generic
+// CHECK-SAME: compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME: lowering_config =
+// CHECK-SAME: translation_info =
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_f16
+// CHECK: iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none,
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+module attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+ func.func @matmul_f16_medium_expanded(%arg0: tensor<1x128x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1x128x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x128x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1x128x1024xf32>) {
+ ^bb0(%in: f16, %in_4: f16, %out: f32):
+ %12 = arith.extf %in : f16 to f32
+ %13 = arith.extf %in_4 : f16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x128x1024xf32>
+ return %2 : tensor<1x128x1024xf32>
+ }
+}
+// CHECK-LABEL: @matmul_f16_medium_expanded
+// CHECK: linalg.generic
+// CHECK-SAME: compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME: lowering_config =
+// CHECK-SAME: translation_info =
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_f16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_medium_f16_expanded
+// CHECK: iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none,
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+module attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+ func.func @matmul_f16_large_expanded(%arg0: tensor<1x256x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<1x256x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x256x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<1x256x1024xf32>) {
+ ^bb0(%in: f16, %in_4: f16, %out: f32):
+ %12 = arith.extf %in : f16 to f32
+ %13 = arith.extf %in_4 : f16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x256x1024xf32>
+ return %2 : tensor<1x256x1024xf32>
+ }
+}
+// CHECK-LABEL: @matmul_f16_large_expanded
+// CHECK: linalg.generic
+// CHECK-SAME: compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME: lowering_config =
+// CHECK-SAME: translation_info =
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_f16_expanded
+// CHECK: iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none,
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+module attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+ func.func @matmul_bf16_large(%arg0: tensor<1024x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1024x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1024x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1024x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1024x1024xf32>) {
+ ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+ %12 = arith.extf %in : bf16 to f32
+ %13 = arith.extf %in_4 : bf16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1024x1024xf32>
+ return %2 : tensor<1024x1024xf32>
+ }
+}
+// CHECK-LABEL: @matmul_bf16_large
+// CHECK: linalg.generic
+// CHECK-SAME: compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME: lowering_config =
+// CHECK-SAME: translation_info =
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_bf16
+// CHECK: iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none,
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+module attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+ func.func @matmul_bf16_expanded_large(%arg0: tensor<1x256x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1x256x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x256x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x256x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1x256x1024xf32>) {
+ ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+ %12 = arith.extf %in : bf16 to f32
+ %13 = arith.extf %in_4 : bf16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x256x1024xf32>
+ return %2 : tensor<1x256x1024xf32>
+ }
+}
+// CHECK-LABEL: @matmul_bf16_expanded_large
+// CHECK: linalg.generic
+// CHECK-SAME: compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME: lowering_config =
+// CHECK-SAME: translation_info =
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_large_bf16_expanded
+// CHECK: iree_codegen.inner_tiled
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none,
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+module attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+ func.func @matmul_bf16_expanded_medium(%arg0: tensor<1x128x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<1x128x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x128x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<1x128x1024xf32>) {
+ ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+ %12 = arith.extf %in : bf16 to f32
+ %13 = arith.extf %in_4 : bf16 to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x128x1024xf32>
+ return %2 : tensor<1x128x1024xf32>
+ }
+}
+// CHECK-LABEL: @matmul_bf16_expanded_medium
+// CHECK: linalg.generic
+// CHECK-SAME: compilation_info = #iree_codegen.compilation_info
+// CHECK-SAME: lowering_config =
+// CHECK-SAME: translation_info =
+// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_bf16_expanded", tensor>
+// CHECK-LABEL: util.func private @pingpong_medium_bf16_expanded
+// CHECK: iree_codegen.inner_tiled
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel
index c9749cd..a5c4e18 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel
@@ -55,6 +55,8 @@
iree_c_embed_data(
name = "iree_mlir_ukernels_amdgpu",
srcs = [
+ "iree_uk_amdgpu_matmul_bf16.mlir",
+ "iree_uk_amdgpu_matmul_f16.mlir",
"iree_uk_amdgpu_matmul_f8.mlir",
],
c_file_output = "iree_mlir_ukernels_amdgpu.c",
@@ -67,6 +69,8 @@
iree_lit_test_suite(
name = "verify_mlir_ukernels_amdgpu",
srcs = [
+ "iree_uk_amdgpu_matmul_bf16.mlir",
+ "iree_uk_amdgpu_matmul_f16.mlir",
"iree_uk_amdgpu_matmul_f8.mlir",
],
cfg = "//compiler:lit.cfg.py",
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt
index 0df959c..73815c9 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt
@@ -40,6 +40,8 @@
NAME
iree_mlir_ukernels_amdgpu
SRCS
+ "iree_uk_amdgpu_matmul_bf16.mlir"
+ "iree_uk_amdgpu_matmul_f16.mlir"
"iree_uk_amdgpu_matmul_f8.mlir"
C_FILE_OUTPUT
"iree_mlir_ukernels_amdgpu.c"
@@ -53,6 +55,8 @@
NAME
verify_mlir_ukernels_amdgpu
SRCS
+ "iree_uk_amdgpu_matmul_bf16.mlir"
+ "iree_uk_amdgpu_matmul_f16.mlir"
"iree_uk_amdgpu_matmul_f8.mlir"
TOOLS
iree-opt
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_bf16.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_bf16.mlir
new file mode 100644
index 0000000..9a138e2
--- /dev/null
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_bf16.mlir
@@ -0,0 +1,659 @@
+// RUN: iree-opt %s
+
+!bf16_in_ty = tensor<256x?xbf16>
+!bf16_exp_in_ty = tensor<1x256x?xbf16>
+!bf16_block_in = tensor<256x64xbf16>
+!bf16_exp_block_in = tensor<1x256x64xbf16>
+!bf16_flat_shared = memref<16384xbf16, #gpu.address_space<workgroup>>
+!bf16_shared = memref<256x64xbf16, #gpu.address_space<workgroup>>
+!bf16_shared_exp = memref<16x16x4x16xbf16, #gpu.address_space<workgroup>>
+
+!in_ty_bf16 = tensor<256x?xbf16>
+!exp_in_ty_bf16 = tensor<1x256x?xbf16>
+!block_in_bf16 = tensor<256x64xbf16>
+!exp_block_in_bf16 = tensor<1x256x64xbf16>
+!flat_shared_bf16 = memref<16384xbf16, #gpu.address_space<workgroup>>
+!shared_bf16 = memref<256x64xbf16, #gpu.address_space<workgroup>>
+!shared_exp_bf16 = memref<16x16x4x16xbf16, #gpu.address_space<workgroup>>
+
+!mexp_in_ty_bf16 = tensor<1x128x?xbf16>
+!mexp_block_in_bf16 = tensor<1x128x64xbf16>
+!mflat_shared_bf16 = memref<8192xbf16, #gpu.address_space<workgroup>>
+!mshared_bf16 = memref<128x64xbf16, #gpu.address_space<workgroup>>
+!mshared_exp_bf16 = memref<8x16x4x16xbf16, #gpu.address_space<workgroup>>
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+
+util.func private @pingpong_large_bf16(%lhs_base: !bf16_in_ty, %rhs_base: !bf16_in_ty, %unused_acc: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.0 : bf16
+ %lhs_shared_base = memref.alloc() : !bf16_flat_shared
+ %rhs_shared_base = memref.alloc() : !bf16_flat_shared
+
+ %dim = tensor.dim %lhs_base, %c1 : !bf16_in_ty
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
+
+ %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+
+ %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+
+ %lhs_init = tensor.extract_slice %lhs [0, 0] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+ %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+ %lhs_thread_local = tensor.extract_slice %lhs_init [%delin#0, %vec] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+ %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+
+ %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+
+ %0 = tensor.empty() : tensor<16x16x16x16xf32>
+ %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<16x16x16x16xf32> {
+ %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+ %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw>: index
+ %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw>: index
+ %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw>: index
+ %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+ %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+ // Inner 64 loads 8 threads x 8 elements.
+ %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw>: index
+ // Each subgroup loads 32 contiguous rows out of 256.
+ %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw>: index
+ // Base index is remaining outer 8 lanes + subgroup base.
+ %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw>: index
+ %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw>: index
+ %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw>: index
+ %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw>: index
+
+ %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+ %cmp0 = arith.cmpi slt, %id, %c256 : index
+ %cmp1 = arith.cmpi sge, %id, %c256 : index
+ scf.if %cmp0 {
+ rocdl.s.barrier
+ }
+ %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+ // Global loads of lhs.
+ %lhs_block = tensor.extract_slice %lhs [0, %i] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+ %lhs_thread_0 = tensor.extract_slice %lhs_block [%glb0, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %lhs_thread_1 = tensor.extract_slice %lhs_block [%glb1, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %lhs_thread_2 = tensor.extract_slice %lhs_block [%glb2, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %lhs_thread_3 = tensor.extract_slice %lhs_block [%glb3, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ // Global loads of rhs.
+ %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+ %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ scf.yield %dot3 : vector<8x4x1x4xf32>
+ }
+ scf.if %cmp1 {
+ rocdl.s.barrier
+ }
+
+ // Epilogue
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+ %empty = tensor.empty() : tensor<8x1x4x4xf32>
+ %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<8x1x4x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %out[%m_outer_id, %ids#3, %n_outer_id, %inner_id] [8, 1, 4, 4] [1, 1, 1, 1] : tensor<8x1x4x4xf32> into tensor<16x16x16x16xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ %collapse = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<16x16x16x16xf32> into tensor<256x256xf32>
+ util.return %collapse : tensor<256x256xf32>
+}
+
+// Expanded variants of BF16
+
+util.func private @pingpong_medium_bf16_expanded(%lhs_base: !mexp_in_ty_bf16, %rhs_base: !in_ty_bf16, %unused_acc: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.0 : bf16
+ %lhs_shared_base = memref.alloc() : !mflat_shared_bf16
+ %rhs_shared_base = memref.alloc() : !flat_shared_bf16
+
+ %dim = tensor.dim %rhs_base, %c1 : !in_ty_bf16
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty_bf16
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_bf16
+
+ %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !mflat_shared_bf16
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared_bf16
+
+ %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [128, 64] : !mflat_shared_bf16 into !mshared_bf16
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared_bf16 into !shared_bf16
+
+ %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 128, 64] [1, 1, 1] : !mexp_in_ty_bf16 to !mexp_block_in_bf16
+ %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty_bf16 to !block_in_bf16
+
+ scf.forall (%id) in (1024) {
+ %delin:2 = affine.delinearize_index %id into (128, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+ %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !mexp_block_in_bf16 to tensor<1x1x8xbf16>
+ %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+ vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !mshared_bf16
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw>: index
+ %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+ %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+
+ %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [8, 16, 4, 16] : !mshared_bf16 into !mshared_exp_bf16
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared_bf16 into !shared_exp_bf16
+
+ %0 = tensor.empty() : tensor<1x8x16x16x16xf32>
+ %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x8x16x16x16xf32> {
+ %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+ %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw>: index
+ %m_outer_id = arith.muli %ids#0, %c4 overflow<nsw, nuw>: index
+ %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw>: index
+ %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+ %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+ // Inner 64 loads 8 threads x 8 elements.
+ %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw>: index
+ // RHS indexing. Each subgroup loads 32 contiguous rows out of 256.
+ %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw>: index
+ // Base index is remaining outer 8 lanes + subgroup base.
+ %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw>: index
+ %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw>: index
+ %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw>: index
+ %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw>: index
+ // LHS indexing.
+ %bpo_lhs = arith.muli %wt#0, %c16 overflow<nsw, nuw>: index
+ %glb0_lhs = arith.addi %wt#1, %bpo_lhs overflow<nsw, nuw>: index
+ %glb1_lhs = arith.addi %glb0_lhs, %c8 overflow<nsw, nuw>: index
+
+ %2 = arith.constant dense<0.0> : vector<4x4x1x4xf32>
+
+ %cmp0 = arith.cmpi slt, %id, %c256 : index
+ %cmp1 = arith.cmpi sge, %id, %c256 : index
+ scf.if %cmp0 {
+ rocdl.s.barrier
+ }
+ %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<4x4x1x4xf32> {
+
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+ %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+ %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+ rocdl.sched.barrier 0
+
+ // Global loads of rhs.
+ %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty_bf16 to !block_in_bf16
+ %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+ %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+ %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+ %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in_bf16 to tensor<1x8xbf16>
+ %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+ rocdl.sched.barrier 0
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+ %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+ %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+ rocdl.sched.barrier 0
+
+ // Global loads of lhs.
+ %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 128, 64] [1, 1, 1] : !mexp_in_ty_bf16 to !mexp_block_in_bf16
+ %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in_bf16 to tensor<1x1x8xbf16>
+ %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+ %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in_bf16 to tensor<1x1x8xbf16>
+ %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%iter) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !shared_bf16
+
+ vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !mshared_bf16
+ vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !mshared_bf16
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ scf.yield %dot2 : vector<4x4x1x4xf32>
+ }
+ scf.if %cmp1 {
+ rocdl.s.barrier
+ }
+
+ // Epilogue
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+ %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+ %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%3) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp_bf16, vector<4x1x2x4xbf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_bf16, vector<4x1x2x4xbf16>
+ %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+ %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xbf16> to vector<4x2x1x4xbf16>
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<4x2x1x4xbf16>, vector<4x2x1x4xbf16> into vector<4x4x1x4xf32>
+
+ %tp = vector.transpose %dot2, [0, 2, 1, 3] : vector<4x4x1x4xf32> to vector<4x1x4x4xf32>
+ %empty = tensor.empty() : tensor<1x4x1x4x4xf32>
+ %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<4x1x4x4xf32>, tensor<1x4x1x4x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 4, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x4x1x4x4xf32> into tensor<1x8x16x16x16xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x8x16x16x16xf32> into tensor<1x128x256xf32>
+ util.return %collapse : tensor<1x128x256xf32>
+}
+
+util.func private @pingpong_large_bf16_expanded(%lhs_base: !bf16_exp_in_ty, %rhs_base: !bf16_in_ty, %unused_acc: tensor<1x256x256xf32>) -> tensor<1x256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.0 : bf16
+ %lhs_shared_base = memref.alloc() : !bf16_flat_shared
+ %rhs_shared_base = memref.alloc() : !bf16_flat_shared
+
+ %dim = tensor.dim %rhs_base, %c1 : !bf16_in_ty
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !bf16_exp_in_ty
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !bf16_in_ty
+
+ %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !bf16_flat_shared
+
+ %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !bf16_flat_shared into !bf16_shared
+
+ %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 256, 64] [1, 1, 1] : !bf16_exp_in_ty to !bf16_exp_block_in
+ %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+ %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+ vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+
+ %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !bf16_shared into !bf16_shared_exp
+
+ %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
+ %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
+ %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+ %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+ %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+ %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+ %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+ %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+ // Inner 64 loads 8 threads x 8 elements.
+ %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+ // Each subgroup loads 32 contiguous rows out of 256.
+ %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+ // Base index is remaining outer 8 lanes + subgroup base.
+ %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw>: index
+ %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw>: index
+ %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw>: index
+ %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw>: index
+
+ %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+ %cmp0 = arith.cmpi slt, %id, %c256 : index
+ %cmp1 = arith.cmpi sge, %id, %c256 : index
+ scf.if %cmp0 {
+ rocdl.s.barrier
+ }
+ %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+ // Global loads of lhs.
+ %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 256, 64] [1, 1, 1] : !bf16_exp_in_ty to !bf16_exp_block_in
+ %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+ %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+ %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+ %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+ %lhs_thread_2 = tensor.extract_slice %lhs_block [0, %glb2, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+ %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+ %lhs_thread_3 = tensor.extract_slice %lhs_block [0, %glb3, %gko] [1, 1, 8] [1, 1, 1] : !bf16_exp_block_in to tensor<1x1x8xbf16>
+ %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xbf16>, vector<1x8xbf16>
+
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ // Global loads of rhs.
+ %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !bf16_in_ty to !bf16_block_in
+ %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+ %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !bf16_block_in to tensor<1x8xbf16>
+ %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xbf16>, vector<1x8xbf16>
+
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xbf16>, !bf16_shared
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ scf.yield %dot3 : vector<8x4x1x4xf32>
+ }
+ scf.if %cmp1 {
+ rocdl.s.barrier
+ }
+
+      // Epilogue: compute on the final K tile left in shared memory by the last
+      // loop iteration (or by the prologue when the loop does not execute).
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<8x1x1x4xbf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !bf16_shared_exp, vector<4x1x1x4xbf16>
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16, col_major = true>
+ } : vector<8x1x1x4xbf16>, vector<4x1x1x4xbf16> into vector<8x4x1x4xf32>
+
+ %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+ %empty = tensor.empty() : tensor<1x8x1x4x4xf32>
+ %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<1x8x1x4x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 8, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x8x1x4x4xf32> into tensor<1x16x16x16x16xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x16x16x16x16xf32> into tensor<1x256x256xf32>
+ util.return %collapse : tensor<1x256x256xf32>
+}
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f16.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f16.mlir
new file mode 100644
index 0000000..480b849
--- /dev/null
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f16.mlir
@@ -0,0 +1,649 @@
+// RUN: iree-opt %s
+
+!in_ty = tensor<256x?xf16>
+!exp_in_ty = tensor<1x256x?xf16>
+!block_in = tensor<256x64xf16>
+!exp_block_in = tensor<1x256x64xf16>
+!flat_shared = memref<16384xf16, #gpu.address_space<workgroup>>
+!shared = memref<256x64xf16, #gpu.address_space<workgroup>>
+!shared_exp = memref<16x16x4x16xf16, #gpu.address_space<workgroup>>
+
+!mexp_in_ty = tensor<1x128x?xf16>
+!mexp_block_in = tensor<1x128x64xf16>
+!mflat_shared = memref<8192xf16, #gpu.address_space<workgroup>>
+!mshared = memref<128x64xf16, #gpu.address_space<workgroup>>
+!mshared_exp = memref<8x16x4x16xf16, #gpu.address_space<workgroup>>
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+
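+// Ping-pong matmul ukernel: produces a 256x256 f32 tile from f16 operands with a
+// dynamic K dimension; the %unused_acc operand is ignored.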
+util.func private @pingpong_large_f16(%lhs_base: !in_ty, %rhs_base: !in_ty, %unused_acc: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.0 : f16
+ %lhs_shared_base = memref.alloc() : !flat_shared
+ %rhs_shared_base = memref.alloc() : !flat_shared
+
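+  // Cast both operands to AMDGPU buffer resources, using the dynamic K extent as
+  // the cache swizzle stride.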
+ %dim = tensor.dim %lhs_base, %c1 : !in_ty
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !in_ty
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
+
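+  // Hint a rotate-rows swizzle (rows of 64, rotated by 4) on the shared
+  // allocations to reduce LDS bank conflicts.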
+ %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+
+ %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+
+ %lhs_init = tensor.extract_slice %lhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+ %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+
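+  // Prologue: cooperatively stage the first 256x64 K tile of the LHS (and, below,
+  // the RHS) into shared memory.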
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %lhs_thread_local = tensor.extract_slice %lhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+
+ %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+
+ %0 = tensor.empty() : tensor<16x16x16x16xf32>
+ %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<16x16x16x16xf32> {
+ %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+ %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+ %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+ %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+ %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+ %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+ // Inner 64 loads 8 threads x 8 elements.
+ %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+ // Each subgroup loads 32 contiguous rows out of 256.
+ %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+ // Base index is remaining outer 8 lanes + subgroup base.
+ %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+ %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+ %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+ %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+
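+    // Zero-initialized accumulator tile held in registers across the K loop.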
+ %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+ %cmp0 = arith.cmpi slt, %id, %c256 : index
+ %cmp1 = arith.cmpi sge, %id, %c256 : index
+ scf.if %cmp0 {
+ rocdl.s.barrier
+ }
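+    // Main K loop: compute on the tile already resident in shared memory while
+    // issuing global loads for tile %i; the loop starts at 64 because the
+    // prologue staged tile 0.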
+ %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+ // Global loads of lhs.
+ %lhs_block = tensor.extract_slice %lhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+ %lhs_thread_0 = tensor.extract_slice %lhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %lhs_thread_1 = tensor.extract_slice %lhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %lhs_thread_2 = tensor.extract_slice %lhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %lhs_thread_3 = tensor.extract_slice %lhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ // Global loads of rhs.
+ %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+ %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ scf.yield %dot3 : vector<8x4x1x4xf32>
+ }
+ scf.if %cmp1 {
+ rocdl.s.barrier
+ }
+
+ // Epilogue
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+ %empty = tensor.empty() : tensor<8x1x4x4xf32>
+ %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<8x1x4x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %out[%m_outer_id, %ids#3, %n_outer_id, %inner_id] [8, 1, 4, 4] [1, 1, 1, 1] : tensor<8x1x4x4xf32> into tensor<16x16x16x16xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ %collapse = tensor.collapse_shape %1 [[0, 1], [2, 3]] : tensor<16x16x16x16xf32> into tensor<256x256xf32>
+ util.return %collapse : tensor<256x256xf32>
+}
+
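+// Medium variant: a 128-row M tile with a unit batch dimension on the LHS. The
+// LHS shared buffer is half sized and each loop iteration issues two inner_tiled
+// contractions over double-width K slices.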
+util.func private @pingpong_medium_f16_expanded(%lhs_base: !mexp_in_ty, %rhs_base: !in_ty, %unused_acc: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.0 : f16
+ %lhs_shared_base = memref.alloc() : !mflat_shared
+ %rhs_shared_base = memref.alloc() : !flat_shared
+
+ %dim = tensor.dim %rhs_base, %c1 : !in_ty
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !mexp_in_ty
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
+
+ %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !mflat_shared
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+
+ %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [128, 64] : !mflat_shared into !mshared
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+
+ %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 128, 64] [1, 1, 1] : !mexp_in_ty to !mexp_block_in
+ %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+
+ scf.forall (%id) in (1024) {
+ %delin:2 = affine.delinearize_index %id into (128, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !mexp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+ vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !mshared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+
+ %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [8, 16, 4, 16] : !mshared into !mshared_exp
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+
+ %0 = tensor.empty() : tensor<1x8x16x16x16xf32>
+ %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x8x16x16x16xf32> {
+ %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+ %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+ %m_outer_id = arith.muli %ids#0, %c4 overflow<nsw, nuw> : index
+ %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+ %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+ %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+ // Inner 64 loads 8 threads x 8 elements.
+ %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+ // RHS indexing. Each subgroup loads 32 contiguous rows out of 256.
+ %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+ // Base index is remaining outer 8 lanes + subgroup base.
+ %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+ %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+ %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+ %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+ // LHS indexing.
+ %bpo_lhs = arith.muli %wt#0, %c16 overflow<nsw, nuw> : index
+ %glb0_lhs = arith.addi %wt#1, %bpo_lhs overflow<nsw, nuw> : index
+ %glb1_lhs = arith.addi %glb0_lhs, %c8 overflow<nsw, nuw> : index
+
+ %2 = arith.constant dense<0.0> : vector<4x4x1x4xf32>
+
+ %cmp0 = arith.cmpi slt, %id, %c256 : index
+ %cmp1 = arith.cmpi sge, %id, %c256 : index
+ scf.if %cmp0 {
+ rocdl.s.barrier
+ }
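+    // Main K loop: two inner_tiled contractions per iteration (vs. four in the
+    // large variant), each consuming two k_outer slices via the transposed
+    // 4x2x1x4 operand vectors.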
+ %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<4x4x1x4xf32> {
+
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+ %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+ %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+ rocdl.sched.barrier 0
+
+ // Global loads of rhs.
+ %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+ %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+ rocdl.sched.barrier 0
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+ %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+ %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+ rocdl.sched.barrier 0
+
+ // Global loads of lhs.
+ %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 128, 64] [1, 1, 1] : !mexp_in_ty to !mexp_block_in
+ %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+ %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1_lhs, %gko] [1, 1, 8] [1, 1, 1] : !mexp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%iter) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+ vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !mshared
+ vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1_lhs, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !mshared
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ scf.yield %dot2 : vector<4x4x1x4xf32>
+ }
+ scf.if %cmp1 {
+ rocdl.s.barrier
+ }
+
+ // Epilogue
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+ %lhs_vec_0_t = vector.transpose %lhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+ %rhs_vec_0_t = vector.transpose %rhs_vec_0, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0_t, %rhs_vec_0_t) outs(%3) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !mshared_exp, vector<4x1x2x4xf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x2x4xf16>
+ %lhs_vec_2_t = vector.transpose %lhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+ %rhs_vec_2_t = vector.transpose %rhs_vec_2, [0, 2, 1, 3] : vector<4x1x2x4xf16> to vector<4x2x1x4xf16>
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2_t, %rhs_vec_2_t) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<4x2x1x4xf16>, vector<4x2x1x4xf16> into vector<4x4x1x4xf32>
+
+ %tp = vector.transpose %dot2, [0, 2, 1, 3] : vector<4x4x1x4xf32> to vector<4x1x4x4xf32>
+ %empty = tensor.empty() : tensor<1x4x1x4x4xf32>
+ %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<4x1x4x4xf32>, tensor<1x4x1x4x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 4, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x4x1x4x4xf32> into tensor<1x8x16x16x16xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x8x16x16x16xf32> into tensor<1x128x256xf32>
+ util.return %collapse : tensor<1x128x256xf32>
+}
+
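+// Same schedule as @pingpong_large_f16, but the LHS and the result carry a unit
+// batch dimension.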
+util.func private @pingpong_large_f16_expanded(%lhs_base: !exp_in_ty, %rhs_base: !in_ty, %unused_acc: tensor<1x256x256xf32>) -> tensor<1x256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.0 : f16
+ %lhs_shared_base = memref.alloc() : !flat_shared
+ %rhs_shared_base = memref.alloc() : !flat_shared
+
+ %dim = tensor.dim %rhs_base, %c1 : !in_ty
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !exp_in_ty
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty
+
+ %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<64, 4>] : !flat_shared
+
+ %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 64] : !flat_shared into !shared
+
+ %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 256, 64] [1, 1, 1] : !exp_in_ty to !exp_block_in
+ %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 64] [1, 1] : !in_ty to !block_in
+
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+ vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
+ %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+
+ %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 16] : !shared into !shared_exp
+
+ %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
+ %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
+ %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+ %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+ %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+ %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+ %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+ %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+ // Inner 64 loads 8 threads x 8 elements.
+ %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
+ // Each subgroup loads 32 contiguous rows out of 256.
+ %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+ // Base index is remaining outer 8 lanes + subgroup base.
+ %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+ %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+ %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+ %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+
+ %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+ %cmp0 = arith.cmpi slt, %id, %c256 : index
+ %cmp1 = arith.cmpi sge, %id, %c256 : index
+ scf.if %cmp0 {
+ rocdl.s.barrier
+ }
+ %3 = scf.for %i = %c64 to %dim step %c64 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+ // Global loads of lhs.
+ %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 256, 64] [1, 1, 1] : !exp_in_ty to !exp_block_in
+ %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+ %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+ %lhs_thread_2 = tensor.extract_slice %lhs_block [0, %glb2, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+ %lhs_thread_3 = tensor.extract_slice %lhs_block [0, %glb3, %gko] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
+ %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
+
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ // Global loads of rhs.
+ %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 64] [1, 1] : !in_ty to !block_in
+ %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+ %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
+ %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
+
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ scf.yield %dot3 : vector<8x4x1x4xf32>
+ }
+ scf.if %cmp1 {
+ rocdl.s.barrier
+ }
+
+ // Epilogue
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<8x1x1x4xf16>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp, vector<4x1x1x4xf16>
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
+ } : vector<8x1x1x4xf16>, vector<4x1x1x4xf16> into vector<8x4x1x4xf32>
+
+ %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+ %empty = tensor.empty() : tensor<1x8x1x4x4xf32>
+ %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<1x8x1x4x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id] [1, 8, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x8x1x4x4xf32> into tensor<1x16x16x16x16xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x16x16x16x16xf32> into tensor<1x256x256xf32>
+ util.return %collapse : tensor<1x256x256xf32>
+}
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir
index b5dadcd..d808b97 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_matmul_f8.mlir
@@ -209,3 +209,226 @@
%collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x8x16x16x16xf32> into tensor<1x128x256xf32>
util.return %collapse : tensor<1x128x256xf32>
}
+
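+// f8E4M3FNUZ analogue of @pingpong_large_f16_expanded: 128-wide K tiles and
+// 16-element thread loads feeding the MFMA_F32_16x16x32_F8E4M3FNUZ intrinsic.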
+util.func private @pingpong_large_f8_expanded(%lhs_base: !exp_in_ty_f8, %rhs_base: !in_ty_f8, %unused_acc: tensor<1x256x256xf32>) -> tensor<1x256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.0 : f8E4M3FNUZ
+ %lhs_shared_base = memref.alloc() : !flat_shared_f8
+ %rhs_shared_base = memref.alloc() : !flat_shared_f8
+
+ %dim = tensor.dim %rhs_base, %c1 : !in_ty_f8
+ %lhs = iree_gpu.buffer_resource_cast %lhs_base cacheSwizzleStride(%dim) : !exp_in_ty_f8
+ %rhs = iree_gpu.buffer_resource_cast %rhs_base cacheSwizzleStride(%dim) : !in_ty_f8
+
+ %lhs_shared_swizzle = iree_codegen.swizzle_hint %lhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8
+ %rhs_shared_swizzle = iree_codegen.swizzle_hint %rhs_shared_base[#iree_codegen.rotate_rows<128, 8>] : !flat_shared_f8
+
+ %lhs_shared = memref.expand_shape %lhs_shared_swizzle [[0, 1]] output_shape [256, 128] : !flat_shared_f8 into !shared_f8
+ %rhs_shared = memref.expand_shape %rhs_shared_swizzle [[0, 1]] output_shape [256, 128] : !flat_shared_f8 into !shared_f8
+
+ %lhs_init = tensor.extract_slice %lhs [0, 0, 0] [1, 256, 128] [1, 1, 1] : !exp_in_ty_f8 to !exp_block_in_f8
+ %rhs_init = tensor.extract_slice %rhs [0, 0] [256, 128] [1, 1] : !in_ty_f8 to !block_in_f8
+
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
+ %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+ %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ scf.forall (%id) in (2048) {
+ %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
+ %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
+ %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+ %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+
+ %lhs_shared_expand = memref.expand_shape %lhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 32] : !shared_f8 into !shared_exp_f8
+ %rhs_shared_expand = memref.expand_shape %rhs_shared [[0, 1], [2, 3]] output_shape [16, 16, 4, 32] : !shared_f8 into !shared_exp_f8
+
+ %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
+ %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
+ %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
+ %inner_id = arith.muli %ids#2, %c8 overflow<nsw, nuw> : index
+ %inner_id_acc = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+ %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+ %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
+ %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
+ %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
+
+ // Inner 64 loads 8 threads x 16 elements.
+ %gko = arith.muli %wt#2, %c16 overflow<nsw, nuw> : index
+ // Each subgroup loads 32 contiguous rows out of 256.
+ %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
+ // Base index is remaining outer 8 lanes + subgroup base.
+ %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+ %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+ %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+ %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
+
+ %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>
+
+ %cmp0 = arith.cmpi slt, %id, %c256 : index
+ %cmp1 = arith.cmpi sge, %id, %c256 : index
+ scf.if %cmp0 {
+ rocdl.s.barrier
+ }
+ %3 = scf.for %i = %c128 to %dim step %c128 iter_args(%iter = %2) -> vector<8x4x1x4xf32> {
+
+ // Global loads of lhs.
+ %lhs_block = tensor.extract_slice %lhs [0, 0, %i] [1, 256, 128] [1, 1, 1] : !exp_in_ty_f8 to !exp_block_in_f8
+ %lhs_thread_0 = tensor.extract_slice %lhs_block [0, %glb0, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+ %lhs_vec_local_0 = vector.transfer_read %lhs_thread_0 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ %lhs_thread_1 = tensor.extract_slice %lhs_block [0, %glb1, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+ %lhs_vec_local_1 = vector.transfer_read %lhs_thread_1 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ %lhs_thread_2 = tensor.extract_slice %lhs_block [0, %glb2, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+ %lhs_vec_local_2 = vector.transfer_read %lhs_thread_2 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ %lhs_thread_3 = tensor.extract_slice %lhs_block [0, %glb3, %gko] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
+ %lhs_vec_local_3 = vector.transfer_read %lhs_thread_3 [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%iter) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ // Global loads of rhs.
+ %rhs_block = tensor.extract_slice %rhs [0, %i] [256, 128] [1, 1] : !in_ty_f8 to !block_in_f8
+ %rhs_thread_0 = tensor.extract_slice %rhs_block [%glb0, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+ %rhs_vec_local_0 = vector.transfer_read %rhs_thread_0 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ %rhs_thread_1 = tensor.extract_slice %rhs_block [%glb1, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+ %rhs_vec_local_1 = vector.transfer_read %rhs_thread_1 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ %rhs_thread_2 = tensor.extract_slice %rhs_block [%glb2, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+ %rhs_vec_local_2 = vector.transfer_read %rhs_thread_2 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+ %rhs_thread_3 = tensor.extract_slice %rhs_block [%glb3, %gko] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
+ %rhs_vec_local_3 = vector.transfer_read %rhs_thread_3 [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
+
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ vector.transfer_write %lhs_vec_local_0, %lhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %lhs_vec_local_1, %lhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %lhs_vec_local_2, %lhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %lhs_vec_local_3, %lhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+
+ vector.transfer_write %rhs_vec_local_0, %rhs_shared [%glb0, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %rhs_vec_local_1, %rhs_shared [%glb1, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %rhs_vec_local_2, %rhs_shared [%glb2, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+ vector.transfer_write %rhs_vec_local_3, %rhs_shared [%glb3, %gko] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
+
+ gpu.barrier
+ rocdl.sched.barrier 0
+ rocdl.s.setprio 1 { iree_gpu.swap_mfma = 1 }
+
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+ rocdl.s.setprio 0
+ gpu.barrier
+ rocdl.sched.barrier 0
+
+ scf.yield %dot3 : vector<8x4x1x4xf32>
+ }
+ scf.if %cmp1 {
+ rocdl.s.barrier
+ }
+
+ // Epilogue
+ %lhs_vec_0 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_0 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c0, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+ %dot0 = iree_codegen.inner_tiled ins(%lhs_vec_0, %rhs_vec_0) outs(%3) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+ %lhs_vec_1 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_1 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c1, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+ %dot1 = iree_codegen.inner_tiled ins(%lhs_vec_1, %rhs_vec_1) outs(%dot0) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+ %lhs_vec_2 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_2 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c2, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+ %dot2 = iree_codegen.inner_tiled ins(%lhs_vec_2, %rhs_vec_2) outs(%dot1) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+ %lhs_vec_3 = vector.transfer_read %lhs_shared_expand[%m_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<8x1x1x8xf8E4M3FNUZ>
+ %rhs_vec_3 = vector.transfer_read %rhs_shared_expand[%n_outer_id, %ids#3, %c3, %inner_id], %cst {in_bounds = [true, true, true, true]} : !shared_exp_f8, vector<4x1x1x8xf8E4M3FNUZ>
+ %dot3 = iree_codegen.inner_tiled ins(%lhs_vec_3, %rhs_vec_3) outs(%dot2) {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ, col_major = true>
+ } : vector<8x1x1x8xf8E4M3FNUZ>, vector<4x1x1x8xf8E4M3FNUZ> into vector<8x4x1x4xf32>
+
+ %tp = vector.transpose %dot3, [0, 2, 1, 3] : vector<8x4x1x4xf32> to vector<8x1x4x4xf32>
+ %empty = tensor.empty() : tensor<1x8x1x4x4xf32>
+ %4 = vector.transfer_write %tp, %empty[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x1x4x4xf32>, tensor<1x8x1x4x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %out[0, %m_outer_id, %ids#3, %n_outer_id, %inner_id_acc] [1, 8, 1, 4, 4] [1, 1, 1, 1, 1] : tensor<1x8x1x4x4xf32> into tensor<1x16x16x16x16xf32>
+ }
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ %collapse = tensor.collapse_shape %1 [[0], [1, 2], [3, 4]] : tensor<1x16x16x16x16xf32> into tensor<1x256x256xf32>
+ util.return %collapse : tensor<1x256x256xf32>
+}
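A minimal sketch (not part of the patch) of the per-thread global-load indexing used by pingpong_large_f8_expanded above; the function name is illustrative and the overflow flags are omitted for brevity:

func.func @pingpong_global_load_indexing_sketch(%id: index) -> (index, index) {
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  // The 512 workgroup threads decompose into 8 subgroups x 8 row lanes x 8
  // column lanes, matching the (8, 8, 8) delinearization in the ukernel.
  %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index
  // Column offset: 8 column lanes x 16 contiguous elements span the 128-wide
  // K slice of one block.
  %gko = arith.muli %wt#2, %c16 : index
  // Row base: 8 subgroups x 32 rows span the 256-row tile.
  %bpo = arith.muli %wt#0, %c32 : index
  // First of the four rows handled by this thread; the ukernel adds 8, 16,
  // and 24 to cover the rest of the subgroup's 32-row slice.
  %glb0 = arith.addi %wt#1, %bpo : index
  return %glb0, %gko : index, index
}

Each thread therefore loads four rows of 16 f8 elements per operand per K step, so the 512 threads together cover the full 256x128 block.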
diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
index 64545dc..c6c83b4 100644
--- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
@@ -1,8 +1,10 @@
// RUN: iree-opt -allow-unregistered-dialect %s
-// This pattern matches an expanded matmul-like operation and annotates it
-// with ukernel descriptor and configuration attributes.
-pdl.pattern @annotate_expanded_matmul_like : benefit(1) {
+// F8 Patterns
+
+// This pattern matches a medium-sized expanded f8 matmul-like operation and
+// annotates it with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_f8_medium_expanded : benefit(1) {
%elemtypes = pdl.attribute = [f8E4M3FNUZ, f8E4M3FNUZ, f32]
%imaps = pdl.attribute = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
@@ -19,7 +21,7 @@
%out_init = pdl.operand : %out_type
// Match the a matmul-like generic with above indexin maps.
- %generic_op = pdl.operation "linalg.generic" (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
pdl.apply_native_constraint "matchContraction"(
%generic_op, %elemtypes, %imaps
: !pdl.operation, !pdl.attribute, !pdl.attribute)
@@ -77,3 +79,568 @@
pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
}
}
+
+// This pattern matches a large expanded f8 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes. This is preferred over the
+// medium-sized ukernel.
+pdl.pattern @annotate_matmul_like_f8_large_expanded : benefit(2) {
+ %elemtypes = pdl.attribute = [f8E4M3FNUZ, f8E4M3FNUZ, f32]
+ %imaps = pdl.attribute = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %out_type = pdl.type
+
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %out_init = pdl.operand : %out_type
+
+ // Match a matmul-like generic with the above indexing maps.
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ pdl.apply_native_constraint "matchContraction"(
+ %generic_op, %elemtypes, %imaps
+ : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %attr_name = pdl.attribute = "iree_codegen.ukernel"
+ pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+ // M % 256 == 0, K % 128 == 0, N % 256 == 0
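+ // With the expanded indexing maps above, d0 is the batch dimension, d1 = M,
+ // d2 = N, and d3 = K, so the checks below apply to lhs dims (1, 2) = (M, K)
+ // and rhs dims (0, 1) = (N, K).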
+ %empty = pdl.attribute = {}
+ %c0 = pdl.attribute = 0
+ %c1 = pdl.attribute = 1
+ %c2 = pdl.attribute = 2
+ %c128 = pdl.attribute = 128
+ %c256 = pdl.attribute = 256
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+ // N >= 1024, K >= 512
+ %c512 = pdl.attribute = 512
+ %c1024 = pdl.attribute = 1024
+
+ // TODO: Kernel specialization is needed to apply this strategy selectively at
+ // runtime. Additionally model exports don't specify lower bounds so it is
+ // impossible to use this strategy with this check.
+ // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c512, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.rewrite {
+ // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+ // This modifies the operation in-place.
+
+ %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_f8_expanded", tensor>
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %config_name = pdl.attribute = "compilation_info"
+ %config = pdl.attribute = #iree_codegen.compilation_info<
+ lowering_config = #iree_gpu.lowering_config<{
+ workgroup = [1, 256, 256, 0]
+ }>,
+ translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+ workgroup_size = [512, 1, 1] subgroup_size = 64,
+ // This strategy uses the maximum possible amount of shared memory on
+ // all gfx942 architectures, so shared memory padding to reduce bank
+ // conflicts must be disabled. Prefetching is done manually inside the
+ // ukernel and is disabled here as well.
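+ // For reference, the f8 ukernel stages two 256x128 f8 tiles in LDS, i.e.
+ // 2 * 256 * 128 * 1 byte = 64 KiB, the full per-workgroup LDS budget on
+ // gfx942.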
+ {gpu_pipeline_options =
+ #iree_gpu.pipeline_options<
+ prefetch_shared_memory = false,
+ no_reduce_shared_memory_bank_conflicts = true>,
+ // This strategy requires 2 waves per SIMD.
+ llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+ >
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %builtin_attr = pdl.attribute = "rocm.builtin_name"
+ %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f8.mlir"
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+ }
+}
+
+// F16 Patterns
+
+// This pattern matches a large f16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_f16_large : benefit(1) {
+ %elemtypes = pdl.attribute = [f16, f16, f32]
+ %imaps = pdl.attribute = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %out_type = pdl.type
+
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %out_init = pdl.operand : %out_type
+
+ // Match a matmul-like generic with the above indexing maps.
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ pdl.apply_native_constraint "matchContraction"(
+ %generic_op, %elemtypes, %imaps
+ : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %attr_name = pdl.attribute = "iree_codegen.ukernel"
+ pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+ // M % 256 == 0, K % 64 == 0, N % 256 == 0
+ %empty = pdl.attribute = {}
+ %c0 = pdl.attribute = 0
+ %c1 = pdl.attribute = 1
+ %c2 = pdl.attribute = 2
+ %c64 = pdl.attribute = 64
+ %c256 = pdl.attribute = 256
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+ // M, N >= 1024, K >= 256
+ %c1024 = pdl.attribute = 1024
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c1, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.rewrite {
+ // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+ // This modifies the operation in-place.
+
+ %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %config_name = pdl.attribute = "compilation_info"
+ %config = pdl.attribute = #iree_codegen.compilation_info<
+ lowering_config = #iree_gpu.lowering_config<{workgroup = [256, 256, 0]}>,
+ translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+ workgroup_size = [512, 1, 1] subgroup_size = 64,
+ // This strategy uses the maximum possible amount of shared memory on
+ // all gfx942 architectures, so shared memory padding to reduce bank
+ // conflicts must be disabled. Prefetching is done manually inside the
+ // ukernel and is disabled here as well.
+ {gpu_pipeline_options =
+ #iree_gpu.pipeline_options<
+ prefetch_shared_memory = false,
+ no_reduce_shared_memory_bank_conflicts = true>,
+ // This strategy requires 2 waves per SIMD.
+ llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+ >
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %builtin_attr = pdl.attribute = "rocm.builtin_name"
+ %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f16.mlir"
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+ }
+}
+
+// This pattern matches a medium-sized expanded f16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_f16_medium_expanded : benefit(1) {
+ %elemtypes = pdl.attribute = [f16, f16, f32]
+ %imaps = pdl.attribute = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %out_type = pdl.type
+
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %out_init = pdl.operand : %out_type
+
+ // Match a matmul-like generic with the above indexing maps.
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ pdl.apply_native_constraint "matchContraction"(
+ %generic_op, %elemtypes, %imaps
+ : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %attr_name = pdl.attribute = "iree_codegen.ukernel"
+ pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+ // M % 128 == 0, K % 64 == 0, N % 256 == 0
+ %empty = pdl.attribute = {}
+ %c0 = pdl.attribute = 0
+ %c1 = pdl.attribute = 1
+ %c2 = pdl.attribute = 2
+ %c64 = pdl.attribute = 64
+ %c128 = pdl.attribute = 128
+ %c256 = pdl.attribute = 256
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+ // N >= 1024, K >= 256
+ %c1024 = pdl.attribute = 1024
+
+ // TODO: Kernel specialization is needed to apply this strategy selectively at
+ // runtime. Additionally model exports don't specify lower bounds so it is
+ // impossible to use this strategy with this check.
+ // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.rewrite {
+ // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+ // This modifies the operation in-place.
+
+ %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_medium_f16_expanded", tensor>
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %config_name = pdl.attribute = "compilation_info"
+ %config = pdl.attribute = #iree_codegen.compilation_info<
+ lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 128, 256, 0]}>,
+ translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+ workgroup_size = [512, 1, 1] subgroup_size = 64,
+ // This strategy uses the maximum possible amount of shared memory on
+ // all gfx942 architectures, so shared memory padding to reduce bank
+ // conflicts must be disabled. Prefetching is done manually inside the
+ // ukernel and is disabled here as well.
+ {gpu_pipeline_options =
+ #iree_gpu.pipeline_options<
+ prefetch_shared_memory = false,
+ no_reduce_shared_memory_bank_conflicts = true>,
+ // This strategy requires 2 waves per SIMD.
+ llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+ >
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %builtin_attr = pdl.attribute = "rocm.builtin_name"
+ %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f16.mlir"
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+ }
+}
+
+// This pattern matches a large expanded f16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes. This is preferred over the
+// medium-sized ukernel.
+pdl.pattern @annotate_matmul_like_f16_large_expanded : benefit(2) {
+ %elemtypes = pdl.attribute = [f16, f16, f32]
+ %imaps = pdl.attribute = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %out_type = pdl.type
+
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %out_init = pdl.operand : %out_type
+
+ // Match a matmul-like generic with the above indexing maps.
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ pdl.apply_native_constraint "matchContraction"(
+ %generic_op, %elemtypes, %imaps
+ : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %attr_name = pdl.attribute = "iree_codegen.ukernel"
+ pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+ // M % 256 == 0, K % 64 == 0, N % 256 == 0
+ %empty = pdl.attribute = {}
+ %c0 = pdl.attribute = 0
+ %c1 = pdl.attribute = 1
+ %c2 = pdl.attribute = 2
+ %c64 = pdl.attribute = 64
+ %c256 = pdl.attribute = 256
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+ // N >= 1024, K >= 256
+ %c1024 = pdl.attribute = 1024
+
+ // TODO: Kernel specialization is needed to apply this strategy selectively at
+ // runtime. Additionally model exports don't specify lower bounds so it is
+ // impossible to use this strategy with this check.
+ // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.rewrite {
+ // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+ // This modifies the operation in-place.
+
+ %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_f16_expanded", tensor>
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %config_name = pdl.attribute = "compilation_info"
+ %config = pdl.attribute = #iree_codegen.compilation_info<
+ lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 256, 256, 0]}>,
+ translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+ workgroup_size = [512, 1, 1] subgroup_size = 64,
+ // This strategy uses the maximum possible amount of shared memory on
+ // all gfx942 architectures, so shared memory padding to reduce bank
+ // conflicts must be disabled. Prefetching is done manually inside the
+ // ukernel and is disabled here as well.
+ {gpu_pipeline_options =
+ #iree_gpu.pipeline_options<
+ prefetch_shared_memory = false,
+ no_reduce_shared_memory_bank_conflicts = true>,
+ // This strategy requires 2 waves per SIMD.
+ llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+ >
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %builtin_attr = pdl.attribute = "rocm.builtin_name"
+ %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_f16.mlir"
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+ }
+}
+
+// BF16 Patterns
+
+// This pattern matches a large bf16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_bf16_large : benefit(1) {
+ %elemtypes = pdl.attribute = [bf16, bf16, f32]
+ %imaps = pdl.attribute = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %out_type = pdl.type
+
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %out_init = pdl.operand : %out_type
+
+ // Match a matmul-like generic with the above indexing maps.
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ pdl.apply_native_constraint "matchContraction"(
+ %generic_op, %elemtypes, %imaps
+ : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %attr_name = pdl.attribute = "iree_codegen.ukernel"
+ pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+ // M % 256 == 0, K % 64 == 0, N % 256 == 0
+ %empty = pdl.attribute = {}
+ %c0 = pdl.attribute = 0
+ %c1 = pdl.attribute = 1
+ %c2 = pdl.attribute = 2
+ %c64 = pdl.attribute = 64
+ %c256 = pdl.attribute = 256
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+ // M, N >= 1024, K >= 256
+ %c1024 = pdl.attribute = 1024
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c1, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.rewrite {
+ // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+ // This modifies the operation in-place.
+
+ %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16", tensor>
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %config_name = pdl.attribute = "compilation_info"
+ %config = pdl.attribute = #iree_codegen.compilation_info<
+ lowering_config = #iree_gpu.lowering_config<{workgroup = [256, 256, 0]}>,
+ translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+ workgroup_size = [512, 1, 1] subgroup_size = 64,
+ // This strategy uses the maximum possible amount of shared memory on
+ // all gfx942 architectures, so shared memory padding to reduce bank
+ // conflicts must be disabled. Prefetching is done manually inside the
+ // ukernel and is disabled here as well.
+ {gpu_pipeline_options =
+ #iree_gpu.pipeline_options<
+ prefetch_shared_memory = false,
+ no_reduce_shared_memory_bank_conflicts = true>,
+ // This strategy requires 2 waves per SIMD.
+ llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+ >
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %builtin_attr = pdl.attribute = "rocm.builtin_name"
+ %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_bf16.mlir"
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+ }
+}
+
+// This pattern matches a medium-sized expanded bf16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes.
+pdl.pattern @annotate_matmul_like_bf16_medium_expanded : benefit(1) {
+ %elemtypes = pdl.attribute = [bf16, bf16, f32]
+ %imaps = pdl.attribute = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %out_type = pdl.type
+
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %out_init = pdl.operand : %out_type
+
+ // Match a matmul-like generic with the above indexing maps.
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ pdl.apply_native_constraint "matchContraction"(
+ %generic_op, %elemtypes, %imaps
+ : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %attr_name = pdl.attribute = "iree_codegen.ukernel"
+ pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+ // M % 128 == 0, K % 64 == 0, N % 256 == 0
+ %empty = pdl.attribute = {}
+ %c0 = pdl.attribute = 0
+ %c1 = pdl.attribute = 1
+ %c2 = pdl.attribute = 2
+ %c64 = pdl.attribute = 64
+ %c128 = pdl.attribute = 128
+ %c256 = pdl.attribute = 256
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c128 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+ // N >= 1024, K >= 256
+ %c4 = pdl.attribute = 4
+ %c512 = pdl.attribute = 512
+ %c1024 = pdl.attribute = 1024
+
+ // TODO: Kernel specialization is needed to apply this strategy selectively at
+ // runtime. Additionally model exports don't specify lower bounds so it is
+ // impossible to use this strategy with this check.
+ // pdl.apply_native_constraint "dimIsBound"(%lhs, %c0, %c4, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.rewrite {
+ // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+ // This modifies the operation in-place.
+
+ %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_medium_bf16_expanded", tensor>
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %config_name = pdl.attribute = "compilation_info"
+ %config = pdl.attribute = #iree_codegen.compilation_info<
+ lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 128, 256, 0]}>,
+ translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+ workgroup_size = [512, 1, 1] subgroup_size = 64,
+ // This strategy uses the maximum possible amount of shared memory on
+ // all gfx942 architectures, so shared memory padding to reduce bank
+ // conflicts must be disabled. Prefetching is done manually inside the
+ // ukernel and is disabled here as well.
+ {gpu_pipeline_options =
+ #iree_gpu.pipeline_options<
+ prefetch_shared_memory = false,
+ no_reduce_shared_memory_bank_conflicts = true>,
+ // This strategy requires 2 waves per SIMD.
+ llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+ >
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %builtin_attr = pdl.attribute = "rocm.builtin_name"
+ %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_bf16.mlir"
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+ }
+}
+
+// This pattern matches a large expanded bf16 matmul-like operation and annotates it
+// with ukernel descriptor and configuration attributes. This is preferred over the
+// medium-sized ukernel.
+pdl.pattern @annotate_matmul_like_bf16_large_expanded : benefit(2) {
+ %elemtypes = pdl.attribute = [bf16, bf16, f32]
+ %imaps = pdl.attribute = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+
+ %lhs_type = pdl.type
+ %rhs_type = pdl.type
+ %out_type = pdl.type
+
+ %lhs = pdl.operand : %lhs_type
+ %rhs = pdl.operand : %rhs_type
+ %out_init = pdl.operand : %out_type
+
+ // Match a matmul-like generic with the above indexing maps.
+ %generic_op = pdl.operation (%lhs, %rhs, %out_init : !pdl.value, !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ pdl.apply_native_constraint "matchContraction"(
+ %generic_op, %elemtypes, %imaps
+ : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %attr_name = pdl.attribute = "iree_codegen.ukernel"
+ pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}
+
+ // M % 256 == 0, K % 64 == 0, N % 256 == 0
+ %empty = pdl.attribute = {}
+ %c0 = pdl.attribute = 0
+ %c1 = pdl.attribute = 1
+ %c2 = pdl.attribute = 2
+ %c64 = pdl.attribute = 64
+ %c256 = pdl.attribute = 256
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c1, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%lhs, %c2, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c0, %c256 : !pdl.value, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsMultipleOf"(%rhs, %c1, %c64 : !pdl.value, !pdl.attribute, !pdl.attribute)
+
+ // N >= 1024, K >= 256
+ %c1024 = pdl.attribute = 1024
+ pdl.apply_native_constraint "dimIsBound"(%lhs, %c2, %c256, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+ pdl.apply_native_constraint "dimIsBound"(%rhs, %c0, %c1024, %empty : !pdl.value, !pdl.attribute, !pdl.attribute, !pdl.attribute)
+
+ pdl.rewrite {
+ // Call the C++ "annotateOperation" utility to add the attributes to the matched linalg.generic op.
+ // This modifies the operation in-place.
+
+ %annotation = pdl.attribute = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %attr_name, %annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %config_name = pdl.attribute = "compilation_info"
+ %config = pdl.attribute = #iree_codegen.compilation_info<
+ lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 256, 256, 0]}>,
+ translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
+ workgroup_size = [512, 1, 1] subgroup_size = 64,
+ // This strategy uses the maximum possible amount of shared memory on
+ // all gfx942 architectures, so shared memory padding to reduce bank
+ // conflicts must be disabled. Prefetching is done manually inside the
+ // ukernel and is disabled here as well.
+ {gpu_pipeline_options =
+ #iree_gpu.pipeline_options<
+ prefetch_shared_memory = false,
+ no_reduce_shared_memory_bank_conflicts = true>,
+ // This strategy requires 2 waves per SIMD.
+ llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}>
+ >
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %config_name, %config : !pdl.operation, !pdl.attribute, !pdl.attribute)
+
+ %builtin_attr = pdl.attribute = "rocm.builtin_name"
+ %builtin_annotation = pdl.attribute = "iree_uk_amdgpu_matmul_bf16.mlir"
+ pdl.apply_native_rewrite "annotateOperation"(%generic_op, %builtin_attr, %builtin_annotation : !pdl.operation, !pdl.attribute, !pdl.attribute)
+ }
+}
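As a concrete illustration (not part of the patch), a transposed-RHS f16 matmul-like generic with M = N = K = 1024, the shape exercised by the new e2e_matmul_tensor_ukernel_f16f16f32_large test below, satisfies the @annotate_matmul_like_f16_large constraints. Assuming the pattern fires during the ApplyBuiltinPDLPatterns pass, the op is annotated roughly as sketched in the comments; the exact printed form of compilation_info is elided:

#map_lhs = affine_map<(d0, d1, d2) -> (d0, d2)>
#map_rhs = affine_map<(d0, d1, d2) -> (d1, d2)>
#map_out = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @matmul_transpose_b_candidate(%lhs: tensor<1024x1024xf16>, %rhs: tensor<1024x1024xf16>, %init: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
  // After the pass this op carries:
  //   iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>
  //   compilation_info = #iree_codegen.compilation_info<...>
  //   rocm.builtin_name = "iree_uk_amdgpu_matmul_f16.mlir"
  %0 = linalg.generic {indexing_maps = [#map_lhs, #map_rhs, #map_out],
                       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs : tensor<1024x1024xf16>, tensor<1024x1024xf16>)
      outs(%init : tensor<1024x1024xf32>) {
  ^bb0(%a: f16, %b: f16, %acc: f32):
    %ext_a = arith.extf %a : f16 to f32
    %ext_b = arith.extf %b : f16 to f32
    %mul = arith.mulf %ext_a, %ext_b : f32
    %add = arith.addf %acc, %mul : f32
    linalg.yield %add : f32
  } -> tensor<1024x1024xf32>
  return %0 : tensor<1024x1024xf32>
}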
diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
index 4438517..8ac410f 100644
--- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h"
#include "iree/compiler/Utils/EquivalenceUtils.h"
+#include "iree/compiler/Utils/ShapeUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -33,32 +34,6 @@
// MatchCastCompatibleDagFromRootOp
//===----------------------------------------------------------------------===//
-static bool isCastableToTensorType(Type from, RankedTensorType to) {
- auto tensorType = dyn_cast<RankedTensorType>(from);
- if (!tensorType) {
- return false;
- }
- if (tensorType.getRank() != to.getRank()) {
- return false;
- }
- if (tensorType.getElementType() != to.getElementType()) {
- return false;
- }
- for (auto [fromSize, toSize] :
- llvm::zip_equal(tensorType.getShape(), to.getShape())) {
- // If the target dimension is dynamic we can always cast to it.
- if (ShapedType::isDynamic(toSize)) {
- continue;
- }
- // Casting a dynamic dimension to a static one is never valid, and static
- // sizes must always match.
- if (toSize != fromSize) {
- return false;
- }
- }
- return true;
-}
-
// Compares the regions between two operations in lockstep for equality.
static DiagnosedSilenceableFailure
compareOperationRegions(transform::TransformOpInterface transformOp,
diff --git a/compiler/src/iree/compiler/Utils/ShapeUtils.cpp b/compiler/src/iree/compiler/Utils/ShapeUtils.cpp
index 197a95d..e0d3c45 100644
--- a/compiler/src/iree/compiler/Utils/ShapeUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/ShapeUtils.cpp
@@ -38,4 +38,30 @@
return numNonmatchingSSADims <= 1;
}
+bool isCastableToTensorType(Type from, RankedTensorType to) {
+ auto tensorType = dyn_cast<RankedTensorType>(from);
+ if (!tensorType) {
+ return false;
+ }
+ if (tensorType.getRank() != to.getRank()) {
+ return false;
+ }
+ if (tensorType.getElementType() != to.getElementType()) {
+ return false;
+ }
+ for (auto [fromSize, toSize] :
+ llvm::zip_equal(tensorType.getShape(), to.getShape())) {
+ // If the target dimension is dynamic we can always cast to it.
+ if (ShapedType::isDynamic(toSize)) {
+ continue;
+ }
+ // Casting a dynamic dimension to a static one is never valid, and static
+ // sizes must always match.
+ if (toSize != fromSize) {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace mlir::iree_compiler
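A minimal sketch (not part of the patch) of the shape rules the relocated isCastableToTensorType helper implements, expressed with MLIR tensor types; the function name is illustrative only:

func.func @castable_shapes_sketch(%arg0: tensor<8x4xf32>) -> tensor<?x4xf32> {
  // Accepted by the helper: every target dim is either dynamic or equal to the
  // corresponding static source dim (tensor<8x4xf32> -> tensor<?x4xf32>).
  %0 = tensor.cast %arg0 : tensor<8x4xf32> to tensor<?x4xf32>
  // Rejected by the helper: tensor<?x4xf32> -> tensor<8x4xf32> (dynamic source
  // dim to static target dim), tensor<8x4xf32> -> tensor<16x4xf32> (static
  // size mismatch), and any rank or element type mismatch.
  return %0 : tensor<?x4xf32>
}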
diff --git a/compiler/src/iree/compiler/Utils/ShapeUtils.h b/compiler/src/iree/compiler/Utils/ShapeUtils.h
index 08b7e12..2fc754b 100644
--- a/compiler/src/iree/compiler/Utils/ShapeUtils.h
+++ b/compiler/src/iree/compiler/Utils/ShapeUtils.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_UTILS_SHAPEUTILS_H_
#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
namespace mlir::iree_compiler {
@@ -21,6 +22,9 @@
bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
ShapedType rhsType, ValueRange rhsDynamicDims);
+/// Helper to check whether 'from' is castable to the target ranked tensor type.
+bool isCastableToTensorType(Type from, RankedTensorType to);
+
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_UTILS_SHAPEUTILS_H_
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index a5e1cda..464b587 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1707,6 +1707,66 @@
"requires-gpu-cdna3"
)
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_tensor_ukernel_f16f16f32_large
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f32"
+ "--shapes=custom_mnk"
+ "--mnk=1024,1024,1024"
+ "--transpose_rhs"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ "--iree-hip-enable-tensor-ukernels"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_tensor_ukernel_bf16bf16f32_large
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=f32"
+ "--shapes=custom_mnk"
+ "--mnk=1024,1024,1024"
+ "--transpose_rhs"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ "--iree-hip-enable-tensor-ukernels"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
endif()
elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11")