CDNA1/2 data tiling (#19100)
CDNA1/2 machines are going to be in use for a while, and adding data-tiling
support for an architecture is just a matter of populating the three optional
fields in `TargetWgpDetails` (`maxLoadInstructionBits`, `simdsPerWgp`, and
`vgprSpaceBits`). The corresponding `gpu_materialize_encoding_gfx***.mlir`
tests add coverage of the MFMA intrinsics on these architectures, which is
useful beyond data tiling.
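
For reference, a minimal sketch of the shape of that change. The field names
match the inline `/*...=*/` comments in the `KnownTargets.cpp` hunks below,
but the struct definition here is illustrative only (the real
`TargetWgpDetails` carries many more fields, and its exact types may differ):

```cpp
#include <cstdint>
#include <optional>

// Illustrative sketch, not the upstream definition: the real TargetWgpDetails
// in compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
// has many more fields; these are just the three this PR populates.
struct TargetWgpDetailsSketch {
  // Widest single load instruction, in bits; 128 corresponds to the
  // 4-dword (dwordx4) loads available on CDNA1/2.
  std::optional<int32_t> maxLoadInstructionBits;
  // CDNA1/2 compute units contain 4 SIMDs.
  std::optional<int32_t> simdsPerWgp;
  // VGPR file size per lane: 256 registers x 32 bits each = 8192 bits.
  // (Presumably a budget for how much accumulator unrolling the data-tiled
  // MMA layouts can afford; compare the unroll_* factors in the CHECK
  // lines of the new tests.)
  std::optional<int32_t> vgprSpaceBits;
};

// The values this change uses for both the CDNA1 and CDNA2 entries:
constexpr TargetWgpDetailsSketch kCdnaWgpFieldsSketch = {
    /*maxLoadInstructionBits=*/128,
    /*simdsPerWgp=*/4,
    /*vgprSpaceBits=*/256 * 32,
};
```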
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 608ff2e..12431f0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -30,6 +30,8 @@
"gpu_infer_memory_space.mlir",
"gpu_lower_to_ukernels.mlir",
"gpu_combine_value_barriers.mlir",
+ "gpu_materialize_encoding_gfx908.mlir",
+ "gpu_materialize_encoding_gfx90a.mlir",
"gpu_materialize_encoding_gfx942.mlir",
"gpu_materialize_encoding_gfx1100.mlir",
"gpu_nested_layout_contract_amdgpu.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index ba0a751..ae65754 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -27,6 +27,8 @@
"gpu_infer_memory_space.mlir"
"gpu_lower_to_ukernels.mlir"
"gpu_materialize_encoding_gfx1100.mlir"
+ "gpu_materialize_encoding_gfx908.mlir"
+ "gpu_materialize_encoding_gfx90a.mlir"
"gpu_materialize_encoding_gfx942.mlir"
"gpu_nested_layout_contract_amdgpu.mlir"
"gpu_nested_layout_vector_distribution.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx908.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx908.mlir
new file mode 100644
index 0000000..f2ad507
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx908.mlir
@@ -0,0 +1,60 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \
+// RUN: --iree-gpu-test-target=gfx908 \
+// RUN: --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#pipeline_layout_3 = #hal.pipeline.layout<constants = 3, bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @matmul_lowering_MFMA_i32_16x16x16_i8() {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
+ %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
+ %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xi8, #encoding_lhs>>{%M, %K}
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xi8, #encoding_rhs>>{%K, %N}
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xi32, #encoding_result>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xi8, #encoding_lhs>>{%M, %K}
+ -> tensor<?x?xi8, #encoding_lhs>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xi8, #encoding_rhs>>{%K, %N}
+ -> tensor<?x?xi8, #encoding_rhs>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xi32, #encoding_result>>{%M, %N}
+ -> tensor<?x?xi32, #encoding_result>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xi8, #encoding_lhs>,
+ tensor<?x?xi8, #encoding_rhs>)
+ outs(%5 : tensor<?x?xi32, #encoding_result>)
+ -> tensor<?x?xi32, #encoding_result>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xi32, #encoding_result>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xi32, #encoding_result>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: func.func @matmul_lowering_MFMA_i32_16x16x16_i8
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
+// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
+// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x4x4xi8>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4x4xi8>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x2x4x16x4xi32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x16_I8, unroll_m = 4, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
+// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx90a.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx90a.mlir
new file mode 100644
index 0000000..1a90859
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx90a.mlir
@@ -0,0 +1,119 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" \
+// RUN: --iree-gpu-test-target=gfx90a \
+// RUN: --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#pipeline_layout_3 = #hal.pipeline.layout<constants = 3, bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @matmul_lowering_MFMA_f32_16x16x8_bf16() {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
+ %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
+ %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #encoding_lhs>>{%M, %K}
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #encoding_rhs>>{%K, %N}
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #encoding_lhs>>{%M, %K}
+ -> tensor<?x?xbf16, #encoding_lhs>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xbf16, #encoding_rhs>>{%K, %N}
+ -> tensor<?x?xbf16, #encoding_rhs>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
+ -> tensor<?x?xf32, #encoding_result>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xbf16, #encoding_lhs>,
+ tensor<?x?xbf16, #encoding_rhs>)
+ outs(%5 : tensor<?x?xf32, #encoding_result>)
+ -> tensor<?x?xf32, #encoding_result>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xf32, #encoding_result>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: func.func @matmul_lowering_MFMA_f32_16x16x8_bf16
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
+// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
+// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x4x2xbf16>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4x2xbf16>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x2x4x16x4xf32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x8_BF16, unroll_m = 4, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
+// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f64, f64, f64], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f64, f64, f64], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f64, f64, f64], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
+#pipeline_layout_3 = #hal.pipeline.layout<constants = 3, bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @matmul_lowering_MFMA_f64_16x16x4_f64() {
+ %c0 = arith.constant 0 : index
+ %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
+ %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
+ %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf64, #encoding_lhs>>{%M, %K}
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf64, #encoding_rhs>>{%K, %N}
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf64, #encoding_result>>{%M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf64, #encoding_lhs>>{%M, %K}
+ -> tensor<?x?xf64, #encoding_lhs>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?xf64, #encoding_rhs>>{%K, %N}
+ -> tensor<?x?xf64, #encoding_rhs>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?xf64, #encoding_result>>{%M, %N}
+ -> tensor<?x?xf64, #encoding_result>
+ %6 = linalg.matmul
+ ins(%3, %4 : tensor<?x?xf64, #encoding_lhs>,
+ tensor<?x?xf64, #encoding_rhs>)
+ outs(%5 : tensor<?x?xf64, #encoding_result>)
+ -> tensor<?x?xf64, #encoding_result>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+ : tensor<?x?xf64, #encoding_result>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?xf64, #encoding_result>>{%M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: func.func @matmul_lowering_MFMA_f64_16x16x4_f64
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
+// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
+// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x2xf64>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x2xf64>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x4x4x16xf64>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F64_16x16x4_F64, unroll_m = 4, unroll_n_to_subgroups = 4, unroll_k = 2>
+// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index e198e21..1ce6f12 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -189,7 +189,10 @@
{1024, 1024, 1024},
1024,
64 * 1024,
- {0x7fffffff, 0x7fffffff, 0x7fffffff}};
+ {0x7fffffff, 0x7fffffff, 0x7fffffff},
+ /*maxLoadInstructionBits=*/128,
+ /*simdsPerWgp=*/4,
+ /*vgprSpaceBits=*/256 * 32};
return &cdna2Wgp;
}
@@ -209,7 +212,10 @@
{1024, 1024, 1024},
1024,
64 * 1024,
- {0x7fffffff, 0x7fffffff, 0x7fffffff}};
+ {0x7fffffff, 0x7fffffff, 0x7fffffff},
+ /*maxLoadInstructionBits=*/128,
+ /*simdsPerWgp=*/4,
+ /*vgprSpaceBits=*/256 * 32};
return &cdna1Wgp;
}