Extend TD strategy to support batched matmul (#14292)
Extend the transform dialect strategy for matmul to also support batched
matmul. The batched variant is guarded by a separate flag and is *disabled
by default*.
The extension maps the batch dimension to blocks/threads along the Z
axis. Good default values will come separately after end-to-end
experimentation.
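
For illustration, a minimal sketch of the kind of dispatch the new strategy
targets (shapes and flags are taken from the new lit tests; concrete tile
sizes remain placeholders until tuned):

  %fill = linalg.fill ins(%cst : f32)
            outs(%empty : tensor<128x80x320xf32>) -> tensor<128x80x320xf32>
  %bmm = linalg.batch_matmul
           ins(%lhs, %rhs : tensor<128x80x32xf32>, tensor<128x32x320xf32>)
           outs(%fill : tensor<128x80x320xf32>) -> tensor<128x80x320xf32>

Compiling with

  --iree-codegen-llvmgpu-enable-transform-dialect-jit=1 \
  --iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy

routes such a fill + batch_matmul pair through the new FMA-based schedule,
which tiles the batch dimension (128 here) and maps it to #gpu.block<z> and
#gpu.thread<z>, with M and N mapped to the Y and X dimensions as in the
existing matmul strategy.
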
Co-authored-by: Nicolas Vasilache <ntv@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 692e632..e75c01c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -19,6 +19,7 @@
srcs = enforce_glob(
[
"affinemin_canonicalization.mlir",
+ "batch_matmuls.mlir",
"bubble_up_ordinal_ops.mlir",
"bufferize_copy_only_dispatches.mlir",
"canonicalize_interface_load_store.mlir",
@@ -64,6 +65,7 @@
],
include = ["*.mlir"],
exclude = [
+ "batch_matmul_match_spec.mlir",
"convolution_match_spec.mlir",
"reductions_codegen_spec.mlir",
"reductions_match_spec.mlir",
@@ -73,6 +75,7 @@
# transform dialect spec files are MLIR files that specify a transformation,
# they need to be included as data.
data = [
+ "batch_matmul_match_spec.mlir",
"convolution_match_spec.mlir",
"reductions_codegen_spec.mlir",
"reductions_match_spec.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 87e8c56..b76b7ae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"affinemin_canonicalization.mlir"
+ "batch_matmuls.mlir"
"bubble_up_ordinal_ops.mlir"
"bufferize_copy_only_dispatches.mlir"
"canonicalize_interface_load_store.mlir"
@@ -61,6 +62,7 @@
FileCheck
iree-opt
DATA
+ batch_matmul_match_spec.mlir
convolution_match_spec.mlir
reductions_codegen_spec.mlir
reductions_match_spec.mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/batch_matmul_match_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmul_match_spec.mlir
new file mode 100644
index 0000000..302aabd
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmul_match_spec.mlir
@@ -0,0 +1,9 @@
+// RUN: iree-opt %s
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ transform.iree.register_match_callbacks
+ %0:2 = transform.iree.match_callback failures(propagate) "batch_matmul"(%arg0) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.emit_remark "fill" at %0#0 : !transform.any_op
+ transform.iree.emit_remark "batch matmul" at %0#1 : !transform.any_op
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/batch_matmuls.mlir b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmuls.mlir
new file mode 100644
index 0000000..bfeb9a6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmuls.mlir
@@ -0,0 +1,71 @@
+// RUN: iree-opt %s --iree-transform-dialect-interpreter='transform-file-name=%p/batch_matmul_match_spec.mlir' --split-input-file --verify-diagnostics
+
+!lhs = tensor<128x80x32xf32>
+!rhs = tensor<128x32x320xf32>
+!res = tensor<128x80x320xf32>
+
+func.func @batch_matmul(%arg0: !lhs, %arg1: !rhs, %arg2: !res) -> !res {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : !res
+ // expected-remark @below {{fill}}
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : !res) -> !res
+ // expected-remark @below {{batch matmul}}
+ %2 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+ } ins(%arg0, %arg1 : !lhs, !rhs) outs(%1 : !res) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %3 = arith.mulf %arg3, %arg4 : f32
+ %4 = arith.addf %arg5, %3 : f32
+ linalg.yield %4 : f32
+ } -> !res
+ return %2 : !res
+}
+
+// -----
+
+!lhs = tensor<128x80x32xf32>
+!rhs = tensor<128x32x320xf32>
+!res = tensor<128x80x320xf32>
+
+func.func @batch_matmul(%arg0: !lhs, %arg1: !rhs, %arg2: !res) -> !res {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : !res
+ // expected-remark @below {{fill}}
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : !res) -> !res
+ // expected-remark @below {{batch matmul}}
+ %2 = linalg.batch_matmul ins(%arg0, %arg1 : !lhs, !rhs) outs(%1 : !res) -> !res
+ return %2 : !res
+}
+
+// -----
+
+!lhs = tensor<80x128x32xf32>
+!rhs = tensor<128x32x320xf32>
+!res = tensor<80x320x128xf32>
+
+func.func @batch_matmul(%arg0: !lhs, %arg1: !rhs, %arg2: !res) -> !res {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : !res
+ // expected-remark @below {{fill}}
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : !res) -> !res
+ // expected-remark @below {{batch matmul}}
+ %2 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+ } ins(%arg0, %arg1 : !lhs, !rhs) outs(%1 : !res) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %3 = arith.mulf %arg3, %arg4 : f32
+ %4 = arith.addf %arg5, %3 : f32
+ linalg.yield %4 : f32
+ } -> !res
+ return %2 : !res
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index ec16caa..91f70d3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -35,6 +35,7 @@
"reduction_pipeline.mlir",
"rocdl_pipeline_test.mlir",
"set_transform_strategy.mlir",
+ "set_transform_strategy_batch_matmul.mlir",
"set_transform_strategy_pad.mlir",
"illegal_configuration.mlir",
"layout_analysis_and_distribution.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 17c7ef6..731adcc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -37,6 +37,7 @@
"reduction_pipeline_transform.mlir"
"rocdl_pipeline_test.mlir"
"set_transform_strategy.mlir"
+ "set_transform_strategy_batch_matmul.mlir"
"set_transform_strategy_pad.mlir"
"tensor_pad.mlir"
"tensorcore_vectorization.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir
new file mode 100644
index 0000000..38fe1ba
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir
@@ -0,0 +1,231 @@
+// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" \
+// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=1 --iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy |\
+// RUN: FileCheck %s --check-prefixes=CHECK,DEFAULT
+
+// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" \
+// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=1 --iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy \
+// RUN: -td-matmul-strategy-blk-sizes=128,64,32,2 \
+// RUN: -td-matmul-strategy-reduc-size=8 \
+// RUN: -td-matmul-strategy-num-threads=32,4,1 \
+// RUN: -td-matmul-strategy-num-warps=1,4,1 \
+// RUN: -td-matmul-strategy-use-async-copies=true \
+// RUN: -td-matmul-strategy-pipeline-depth=3 \
+// RUN: -td-matmul-strategy-use-mma-sync=false \
+// RUN: -td-matmul-strategy-use-fma=true \
+// RUN: | FileCheck %s --check-prefixes=CHECK,OPTIONS
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
+#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}>
+module attributes {hal.device.targets = [#device_target_cuda]} {
+ hal.executable private @batch_matmul_dispatch_0 {
+ hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
+ hal.executable.export public @batch_matmul_dispatch_0_generic_128x80x320x32_f32 ordinal(0) layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @batch_matmul_dispatch_0_generic_128x80x320x32_f32() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x80x32xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x32x320xf32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<128x80x320xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [128, 80, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<128x80x32xf32>> -> tensor<128x80x32xf32>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [128, 32, 320], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<128x32x320xf32>> -> tensor<128x32x320xf32>
+ %5 = tensor.empty() : tensor<128x80x320xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32>
+ %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<128x80x32xf32>, tensor<128x32x320xf32>) outs(%6 : tensor<128x80x320xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %8 = arith.mulf %in, %in_0 : f32
+ %9 = arith.addf %out, %8 : f32
+ linalg.yield %9 : f32
+ } -> tensor<128x80x320xf32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [128, 80, 320], strides = [1, 1, 1] : tensor<128x80x320xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x80x320xf32>>
+ return
+ }
+ }
+ }
+ }
+ func.func @batch_matmul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+ %c1310720 = arith.constant 1310720 : index
+ %c5242880 = arith.constant 5242880 : index
+ %c13107200 = arith.constant 13107200 : index
+ %c0 = arith.constant 0 : index
+ %c320 = arith.constant 320 : index
+ %c553648160_i32 = arith.constant 553648160 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %c128 = arith.constant 128 : index
+ %c80 = arith.constant 80 : index
+ %c32 = arith.constant 32 : index
+ hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input 0") shape([%c128, %c80, %c32]) type(%c553648160_i32) encoding(%c1_i32)
+ %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<128x80x32xf32> in !stream.resource<external>{%c1310720}
+ hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("input 1") shape([%c128, %c32, %c320]) type(%c553648160_i32) encoding(%c1_i32)
+ %1 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<128x32x320xf32> in !stream.resource<external>{%c5242880}
+ %2 = stream.resource.alloc uninitialized : !stream.resource<external>{%c13107200}
+ %3 = stream.cmd.execute with(%0 as %arg3: !stream.resource<external>{%c1310720}, %1 as %arg4: !stream.resource<external>{%c5242880}, %2 as %arg5: !stream.resource<external>{%c13107200}) {
+ stream.cmd.dispatch @batch_matmul_dispatch_0::@cuda_nvptx_fb::@batch_matmul_dispatch_0_generic_128x80x320x32_f32 {
+ ro %arg3[%c0 for %c1310720] : !stream.resource<external>{%c1310720},
+ ro %arg4[%c0 for %c5242880] : !stream.resource<external>{%c5242880},
+ wo %arg5[%c0 for %c13107200] : !stream.resource<external>{%c13107200}
+ } attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]}
+ } => !stream.timepoint
+ %4 = stream.timepoint.await %3 => %2 : !stream.resource<external>{%c13107200}
+ %5 = stream.tensor.export %4 : tensor<128x80x320xf32> in !stream.resource<external>{%c13107200} -> !hal.buffer_view
+ return %5 : !hal.buffer_view
+ }
+}
+
+
+// CHECK: transform.sequence failures(propagate) {
+// CHECK: transform.iree.register_match_callbacks
+// CHECK: %[[MATCH:.+]]:2 = transform.iree.match_callback failures(propagate) "batch_matmul"
+// CHECK: %[[FORALL:.+]], %[[TILED:.+]] = transform.structured.tile_to_forall_op %[[MATCH]]#1
+// DEFAULT: num_threads [] tile_sizes [64, 64, 1](mapping = [#gpu.block<z>, #gpu.block<y>, #gpu.block<x>])
+// OPTIONS: num_threads [] tile_sizes [128, 64, 32](mapping = [#gpu.block<z>, #gpu.block<y>, #gpu.block<x>])
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: %[[FUSED:.+]], %[[CONTAINING:.+]] = transform.structured.fuse_into_containing_op %[[MATCH]]#0 into %[[FORALL]]
+// CHECK: transform.iree.populate_workgroup_count_region_using_num_threads_slice %[[FORALL]]
+// CHECK: %[[TILED_LINALG:.+]], %[[LOOPS:.+]] = transform.structured.tile %tiled_op
+// DEFAULT: [0, 0, 0, 16]
+// OPTIONS: [0, 0, 0, 8]
+// CHECK: %[[PADDED:.+]] = transform.structured.pad %tiled_linalg_op
+// CHECK: pack_paddings = [1, 1, 1, 1], pad_to_multiple_of = [1, 1, 1, 1], padding_dimensions = [0, 1, 2, 3]
+// CHECK: padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
+// CHECK: %[[V3:.+]] = get_producer_of_operand %[[PADDED]][2]
+// CHECK: transform.structured.hoist_pad %{{.*}} by 1 loops
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: %[[FILL:.+]] = transform.structured.match ops{["linalg.fill"]}
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.structured.match ops{["tensor.parallel_insert_slice"]}
+// CHECK: transform.structured.insert_slice_to_copy
+// CHECK: %[[LHS:.+]] = get_producer_of_operand %[[PADDED]][0]
+// CHECK: %[[RHS:.+]] = get_producer_of_operand %[[PADDED]][1]
+// CHECK: %[[RHS_DPS:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RHS]]
+
+// CHECK: transform.structured.tile_to_forall_op %[[LHS]]
+// DEFAULT: num_threads [1, 32, 4] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// OPTIONS: num_threads [1, 64, 2] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.structured.match ops{["scf.if"]}
+// CHECK: transform.scf.take_assumed_branch %{{.*}} take_else_branch
+
+// CHECK: transform.structured.tile_to_forall_op %[[RHS_DPS]]
+// DEFAULT: num_threads [8, 16, 1] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// OPTIONS: num_threads [2, 8, 8] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+
+// CHECK: transform.structured.tile_to_forall_op
+// DEFAULT: num_threads [2, 64, 1] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// OPTIONS: num_threads [1, 16, 8] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+
+// CHECK: transform.structured.tile_to_forall_op
+// DEFAULT: num_threads [1, 2, 64] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// OPTIONS: num_threads [1, 4, 32] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+
+// CHECK: %forall_op_8, %tiled_op_9 = transform.structured.tile_to_forall_op %[[FILL]]
+// DEFAULT: num_threads [1, 2, 64] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// OPTIONS: num_threads [1, 4, 32] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+
+// CHECK: transform.structured.masked_vectorize
+// DEFAULT: vector_sizes [64, 2, 4]
+// OPTIONS: vector_sizes [128, 1, 4]
+// CHECK: transform.structured.masked_vectorize
+// DEFAULT: vector_sizes [32, 1, 1]
+// OPTIONS: vector_sizes [128, 4, 4]
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.vector.lower_masked_transfers
+// CHECK: transform.structured.vectorize
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.canonicalization
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.iree.eliminate_empty_tensors
+
+// CHECK: transform.iree.bufferize {target_gpu}
+// CHECK: transform.iree.erase_hal_descriptor_type_from_memref
+// CHECK: transform.iree.apply_buffer_optimizations
+// CHECK: transform.iree.forall_to_workgroup
+// CHECK: transform.iree.map_nested_forall_to_gpu_threads
+// DEFAULT: workgroup_dims = [64, 2, 1] warp_dims = [2, 2, 1]
+// OPTIONS: workgroup_dims = [32, 4, 1] warp_dims = [1, 4, 1]
+// CHECK: transform.iree.eliminate_gpu_barriers
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.iree.hoist_static_alloc
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.memref.fold_memref_alias_ops
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.memref.extract_address_computations
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.linalg.tiling_canonicalization
+// CHECK: transform.apply_patterns.iree.fold_fill_into_pad
+// CHECK: transform.apply_patterns.scf.for_loop_canonicalization
+// CHECK: transform.apply_patterns.canonicalization
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.iree.synchronize_loop
+// CHECK: transform.structured.hoist_redundant_vector_transfers
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.iree.apply_buffer_optimizations
+// CHECK: %30 = transform.iree.eliminate_gpu_barriers
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.memref.fold_memref_alias_ops
+
+// CHECK: transform.memref.multibuffer
+// DEFAULT: factor = 2
+// OPTIONS: factor = 3
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.iree.create_async_groups
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
+// CHECK: transform.iree.pipeline_shared_memory_copies
+// DEFAULT: depth = 2
+// OPTIONS: depth = 3
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.vector.lower_masks
+// CHECK: apply_patterns
+// CHECK: transform.apply_patterns.vector.materialize_masks
+// CHECK: apply_patterns
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
index 364e1ec..c3a957b 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -104,7 +105,8 @@
MLIRContext *ctx, int totalNumThreads, int64_t alignment,
ArrayRef<int64_t> copySizes, bool favorPredication,
int64_t elementalBitWidth) {
- assert(copySizes.size() == 2 && "only 2-D copy supported for now");
+ assert(!copySizes.empty() && copySizes.size() <= 3 &&
+ "only 1,2,3-D copies are supported for now");
FailureOr<CopyMapping> maybeCopyMapping =
CopyMapping::numThreadsForCopy(totalNumThreads, alignment, copySizes,
favorPredication, elementalBitWidth);
@@ -117,18 +119,24 @@
elementalBitWidth);
}
assert(succeeded(maybeCopyMapping) && "failed to compute copy mapping");
- assert(maybeCopyMapping->numThreads.size() == 2 &&
- "compute copy mapping expected size-2");
- int64_t numThreadsY = maybeCopyMapping->numThreads[0];
- int64_t numThreadsX = maybeCopyMapping->numThreads[1];
- int64_t sizeY = copySizes[0];
- int64_t sizeX = copySizes[1];
- MappingInfo res{
- /*numThreads=*/{numThreadsY, numThreadsX},
- /*tilecopySizes=*/
- {mlir::ceilDiv(sizeY, numThreadsY), mlir::ceilDiv(sizeX, numThreadsX)},
- /*threadMapping=*/{linearIdY(ctx), linearIdX(ctx)},
- /*vectorSize=*/maybeCopyMapping->vectorSize};
+ assert(maybeCopyMapping->numThreads.size() == copySizes.size() &&
+ "compute copy mapping expected same number of threads and copy sizes");
+
+ SmallVector<int64_t> tileSizes = llvm::to_vector(llvm::map_range(
+ llvm::zip(copySizes, maybeCopyMapping->numThreads), [](auto &&pair) {
+ int64_t size, numThreads;
+ std::tie(size, numThreads) = pair;
+ return mlir::ceilDiv(size, numThreads);
+ }));
+ SmallVector<Attribute> allThreadMappings{linearIdZ(ctx), linearIdY(ctx),
+ linearIdX(ctx)};
+ auto threadMapping =
+ llvm::to_vector(ArrayRef(allThreadMappings).take_back(tileSizes.size()));
+
+ MappingInfo res{/*numThreads=*/maybeCopyMapping->numThreads,
+ /*tilecopySizes=*/tileSizes,
+ /*threadMapping=*/threadMapping,
+ /*vectorSize=*/maybeCopyMapping->vectorSize};
LLVM_DEBUG(res.print(DBGS()); llvm::dbgs() << "\n");
return res;
}
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
index 59956ba..f92f8f2 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -36,6 +37,7 @@
using iree_compiler::buildTileFuseDistToForallWithNumThreads;
using iree_compiler::buildTileFuseDistToForallWithTileSizes;
using iree_compiler::TileToForallAndFuseAndDistributeResult;
+using iree_compiler::gpu::BatchMatmulStrategy;
using iree_compiler::gpu::buildBufferize;
using iree_compiler::gpu::buildConvertToAsyncCopies;
using iree_compiler::gpu::buildConvertToTensorCoreOp;
@@ -97,6 +99,46 @@
return success();
}
+LogicalResult BatchMatmulStrategy::validate(const GPUModel &gpuModel) const {
+ if (failed(MatmulStrategy::validate(gpuModel))) {
+ return failure();
+ }
+
+ if (batch() < blockTileBatch()) {
+ return emitError(UnknownLoc::get(ctx))
+           << "batch(" << batch() << ") < blockTileBatch(" << blockTileBatch()
+           << "); this is at risk of not vectorizing and is NYI";
+ }
+
+ // Only single outermost batch dimension is currently supported.
+  // Only a single outermost batch dimension is currently supported.
+ LDBG("--Couldn't find single outermost batch dimension\n");
+ return failure();
+ }
+
+ if (blockTileSizes.size() < 3) {
+ LDBG("--Not enough block tile sizes\n");
+ return failure();
+ }
+
+ if (numWarps.size() < 3) {
+ LDBG("--Not enough num warps\n");
+ return failure();
+ }
+
+ if (numThreads.size() < 3) {
+ LDBG("--Not enough num threads\n");
+ return failure();
+ }
+
+ if (!useFma) {
+ LDBG("--Only FMA is supported for batch matmul atm\n");
+ return failure();
+ }
+
+ return success();
+}
+
static std::tuple<Value, Value, Value, Value>
buildMatmulStrategyBlockDistribution(ImplicitLocOpBuilder &b, Value variantH,
const MatmulStrategy &strategy) {
@@ -130,22 +172,26 @@
tileResult.tiledOpH, Value(), tileResult.forallH);
}
-void iree_compiler::gpu::buildMatmulTensorCoreStrategy(
- ImplicitLocOpBuilder &b, Value variantH, const MatmulStrategy &strategy) {
- LLVM_DEBUG(strategy.print(DBGS()));
+/// Builds the common part of the schedule for matmuls and batched matmuls.
+static void
+buildCommonMatmulLikeThreadSchedule(ImplicitLocOpBuilder &b, Value variantH,
+ Value fillH, Value matmulH,
+ const MatmulStrategy &strategy) {
+ using mlir::iree_compiler::buildLowerVectorMasksAndCleanup;
+ using mlir::iree_compiler::buildTileFuseToScfFor;
+ using namespace mlir::iree_compiler::gpu;
- // Step 1. Apply block-level part of the strategy, keeps everything fused.
- auto [fillH, matmulH, maybeTiledTrailingHBlock, forall] =
- buildMatmulStrategyBlockDistribution(b, variantH, strategy);
- // Tile reduction loop.
- SmallVector<int64_t> tileSizes{0, 0, strategy.reductionTileSize};
+ // Tile the reduction loop (last in the list).
+ SmallVector<int64_t> tileSizes(strategy.captures.matmulOpSizes.size() - 1, 0);
+ tileSizes.push_back(strategy.reductionTileSize);
+
// Avoid canonicalizing before the pad to avoid folding away the extract_slice
// on the output needed to hoist the output pad.
auto tileReductionResult = buildTileFuseToScfFor(
b, variantH, matmulH, {}, getAsOpFoldResult(b.getI64ArrayAttr(tileSizes)),
/*canonicalize=*/false);
- // Step 2. Pad the matmul op.
+ // Step 2. Pad the (batch) matmul op.
auto paddedMatmulOpH =
buildPad(b, tileReductionResult.tiledOpH,
strategy.getZeroPadAttrFromElementalTypes(b).getValue(),
@@ -219,3 +265,53 @@
// Step 13. Late lowerings and cleanups.
buildLowerVectorMasksAndCleanup(b, funcH);
}
+
+void iree_compiler::gpu::buildMatmulTensorCoreStrategy(
+ ImplicitLocOpBuilder &b, Value variantH, const MatmulStrategy &strategy) {
+ LLVM_DEBUG(strategy.print(DBGS()));
+
+ // Step 1. Apply block-level part of the strategy, keeps everything fused.
+ auto [fillH, matmulH, maybeTiledTrailingHBlock, forall] =
+ buildMatmulStrategyBlockDistribution(b, variantH, strategy);
+ buildCommonMatmulLikeThreadSchedule(b, variantH, fillH, matmulH, strategy);
+}
+
+/// Builds the transform dialect operations distributing a batch matmul across
+/// blocks according to the given strategy.
+static std::tuple<Value, Value, Value>
+buildBatchMatmulStrategyBlockDistribution(ImplicitLocOpBuilder &b,
+ Value variantH,
+ const BatchMatmulStrategy &strategy) {
+ b.create<RegisterMatchCallbacksOp>();
+ auto [fillH, bmmH] = unpackRegisteredMatchCallback<2>(
+ b, "batch_matmul", transform::FailurePropagationMode::Propagate,
+ variantH);
+
+ MappingInfo blockMapping = strategy.getBlockMapping();
+ TileToForallAndFuseAndDistributeResult tileResult =
+ buildTileFuseDistToForallWithTileSizes(
+ /*builder=*/b,
+ /*variantH=*/variantH,
+ /*rootH=*/bmmH,
+ /*opsToFuseH=*/fillH,
+ /*tileSizes=*/
+ getAsOpFoldResult(b.getI64ArrayAttr(blockMapping.tileSizes)),
+ /*threadDimMapping=*/
+ b.getArrayAttr(blockMapping.threadMapping));
+
+ // Handle the workgroup count region.
+ b.create<IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp>(
+ tileResult.forallH);
+ return std::make_tuple(tileResult.resultingFusedOpsHandles.front(),
+ tileResult.tiledOpH, tileResult.forallH);
+}
+
+void iree_compiler::gpu::buildBatchMatmulStrategy(
+ ImplicitLocOpBuilder &b, Value variantH,
+ const BatchMatmulStrategy &strategy) {
+ LLVM_DEBUG(strategy.print(DBGS()));
+
+ auto [fillH, matmulH, forallH] =
+ buildBatchMatmulStrategyBlockDistribution(b, variantH, strategy);
+ buildCommonMatmulLikeThreadSchedule(b, variantH, fillH, matmulH, strategy);
+}
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
index 8f2bb09..2021b18 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
@@ -172,8 +172,133 @@
/*vectorSize=*/std::nullopt};
}
- void print(llvm::raw_ostream &os) const;
- LLVM_DUMP_METHOD void dump() const;
+ void print(llvm::raw_ostream &os) const override;
+ LLVM_DUMP_METHOD void dump() const override;
+};
+
+/// An extension of the matmul strategy to batched matrix multiplications.
+class BatchMatmulStrategy : public MatmulStrategy {
+public:
+ /// Construct the default strategy, pulling options from the command-line
+ /// arguments if provided and using the defaults otherwise.
+ BatchMatmulStrategy(MLIRContext *context, const GPUModel &gpuModel,
+ const transform_ext::MatchedMatmulCaptures &captures)
+ : MatmulStrategy(context, captures, gpuModel) {
+ initDefaultValues(gpuModel);
+ }
+
+ /// Initialize the default values of the strategy.
+ void initDefaultValues(const GPUModel &gpuModel) override {
+ // First, initialize as if this was a simple matmul.
+ MatmulStrategy::initDefaultValues(gpuModel);
+
+ // Make sure we pad along all dimensions.
+ paddingDimensions = {0, 1, 2, 3};
+ packingDimensions = {1, 1, 1, 1};
+ }
+
+ /// Check that the strategy is valid for the captures and the model.
+ LogicalResult validate(const GPUModel &gpuModel) const override;
+
+ /// Named accessors to shapes.
+ int64_t batch() const { return captures.matmulOpSizes[0]; }
+ int64_t m() const override { return captures.matmulOpSizes[1]; }
+ int64_t n() const override { return captures.matmulOpSizes[2]; }
+ int64_t k() const override { return captures.matmulOpSizes[3]; }
+
+ /// Named accessors to block tile sizes associated with shapes.
+ int64_t blockTileBatch() const { return blockTileSizes[0]; }
+ int64_t blockTileM() const override { return blockTileSizes[1]; }
+ int64_t blockTileN() const override { return blockTileSizes[2]; }
+
+ /// Number of threads to use.
+ int64_t numThreadsX() const { return numThreads[0]; }
+ int64_t numThreadsY() const { return numThreads[1]; }
+ int64_t numThreadsZ() const { return numThreads[2]; }
+
+ /// Number of warps to use.
+ int64_t numWarpsX() const override { return numWarps[0]; }
+ int64_t numWarpsY() const override { return numWarps[1]; }
+ int64_t numWarpsZ() const { return numWarps[2]; }
+
+ MappingInfo getBlockMapping() const override {
+ return MappingInfo{
+ /*numThreads=*/
+ {},
+ /*tileSizes=*/{blockTileBatch(), blockTileM(), blockTileN()},
+ /*threadMapping=*/{blockZ(ctx), blockY(ctx), blockX(ctx)},
+ /*vectorSize=*/std::nullopt};
+ }
+
+ // LHS copy is batch x M x K.
+ MappingInfo lhsCopyMapping() const override {
+ // TODO: generalize to transpositions, here and below.
+ return CopyMapping::getMappingInfo(
+ ctx, totalNumThreads(), k(),
+ {blockTileBatch(), blockTileM(), reductionTileSize},
+ /*favorPredication=*/false,
+ captures.lhsElementType.getIntOrFloatBitWidth());
+ }
+
+ // RHS copy is batch x K x N.
+ MappingInfo rhsCopyMapping() const override {
+ return CopyMapping::getMappingInfo(
+ ctx, totalNumThreads(), n(),
+ {blockTileBatch(), reductionTileSize, blockTileN()},
+ /*favorPredication=*/false,
+ captures.rhsElementType.getIntOrFloatBitWidth());
+ }
+
+ // RES copy is batch x M x N.
+ MappingInfo resCopyMapping() const override {
+ return CopyMapping::getMappingInfo(
+ ctx, totalNumThreads(), n(),
+ {blockTileBatch(), blockTileM(), blockTileN()},
+ /*favorPredication=*/false,
+ captures.outputElementType.getIntOrFloatBitWidth());
+ }
+
+ /// Validates the mapping for one of the lhs, rhs or res copies.
+ LogicalResult validateCopyMapping(const MappingInfo &mapping,
+ StringRef name) const {
+ int64_t threadsUsed =
+ std::accumulate(mapping.numThreads.begin(), mapping.numThreads.end(), 1,
+ std::multiplies<int64_t>());
+ if (totalNumThreads() < threadsUsed) {
+ InFlightDiagnostic diag = emitError(UnknownLoc::get(ctx))
+ << "too many threads used for transferring "
+ << name;
+
+ std::string str;
+ llvm::raw_string_ostream os(str);
+ llvm::interleave(mapping.numThreads, os, " * ");
+ os << " >= " << totalNumThreads();
+ diag.attachNote() << os.str();
+ return diag;
+ }
+
+ return success();
+ }
+
+ /// Check that the mapping computed for a copy is valid.
+ LogicalResult validateLhsCopyMapping() const override {
+ return validateCopyMapping(lhsCopyMapping(), "lhs");
+ }
+ LogicalResult validateRhsCopyMapping() const override {
+ return validateCopyMapping(rhsCopyMapping(), "rhs");
+ }
+ LogicalResult validateResCopyMapping() const override {
+ return validateCopyMapping(resCopyMapping(), "result");
+ }
+
+  // Compute is of size batch x M x N.
+ MappingInfo computeMapping() const override {
+ assert(useFma && "only fma is currently supported");
+ return MappingInfo{{numThreadsZ(), numThreadsY(), numThreadsX()},
+ {},
+ {threadZ(ctx), threadY(ctx), threadX(ctx)},
+ std::nullopt};
+ }
};
} // namespace gpu
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
index 4b2dfff..2b403b6 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
@@ -63,9 +63,15 @@
llvm::cl::opt<bool> clGPUEnableTransformDialectPadStrategy(
"iree-codegen-llvmgpu-enable-transform-dialect-pad-strategy",
llvm::cl::desc("activate the pad strategy"), llvm::cl::init(false));
+llvm::cl::opt<bool> clGPUEnableTransformDialectBatchMatmulStrategy(
+ "iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy",
+ llvm::cl::desc("activate the batch matmul strategy, additional "
+ "configuration flags are shared with matmul"),
+ llvm::cl::init(false));
// TODO: significantly better namespacing.
using iree_compiler::gpu::AbstractGemmLikeStrategy;
+using iree_compiler::gpu::BatchMatmulStrategy;
using iree_compiler::gpu::GPUModel;
using iree_compiler::gpu::kCudaMaxVectorLoadBitWidth;
using iree_compiler::gpu::MatmulStrategy;
@@ -337,6 +343,91 @@
return strategy;
}
+/// Update the strategy to make sure it can be consumed by the codegen. In
+/// particular, make sure that tile sizes are smaller than the problem sizes to
+/// actually trigger tiling and mapping to blocks and threads.
+static void failSafeOverrides(BatchMatmulStrategy &strategy,
+ const GPUModel &gpuModel) {
+ // Configure the strategy as if for a matmul.
+ failSafeOverrides(static_cast<MatmulStrategy &>(strategy), gpuModel);
+
+  // Failsafe for blockTileBatch: avoid a tile size larger than the batch
+  // size, which would result in no tiling at all.
+ int64_t blockTileBatch = selectLargestFailsafeValueIfNeeded(
+ /*value=*/strategy.blockTileBatch(),
+ /*limit=*/strategy.batch(),
+ /*thresholds=*/{2, 4, 8, 16, 32, 64, 128},
+ /*failSafeValues=*/{1, 2, 4, 8, 16, 32, 64});
+
+ // Override the matmul configuration to be suitable for batch matmul.
+ // Specifically, prepend the tile size for the batch dimension and force FMA.
+ strategy.blockTileSizes.insert(strategy.blockTileSizes.begin(),
+ blockTileBatch);
+
+ strategy.useMmaSync = false;
+ strategy.useWmma = false;
+ strategy.useFma = true;
+}
+
+/// Produce a strategy for the batch matmul characterized by the given capture
+/// list (shapes and types).
+static BatchMatmulStrategy getBatchMatmulConfig(MLIRContext *context,
+ MatchedMatmulCaptures &captures,
+ const GPUModel &gpuModel) {
+ // Command-line arguments trump everything.
+ BatchMatmulStrategy strategy(context, gpuModel, captures);
+ if (strategy.cliOptionsSpecified)
+ return strategy;
+
+ // TODO: fixed strategies and decision tree/heuristic.
+
+ failSafeOverrides(strategy, gpuModel);
+ return strategy;
+}
+
+/// Match the supported batch matmuls and set the transform dialect strategy for
+/// them.
+static LogicalResult matchAndSetBatchMatmulStrategy(func::FuncOp entryPoint,
+ linalg::LinalgOp op,
+ const GPUModel &gpuModel) {
+ if (!clGPUEnableTransformDialectBatchMatmulStrategy) {
+ LDBG("--Batch matmul strategy flag turned off\n");
+ return failure();
+ }
+
+ StructuredOpMatcher *fill;
+ StructuredOpMatcher *bmm;
+ transform_ext::MatchedMatmulCaptures captures;
+ transform_ext::MatcherContext matcherContext;
+ transform_ext::makeBatchMatmulMatcher(matcherContext, bmm, fill, captures,
+ /*mustMatchEntireFunc=*/true);
+ if (!matchPattern(op, *bmm)) {
+ LDBG("--Batch matmul strategy failed to match\n");
+ return failure();
+ }
+
+ if (captures.contractionDims.batch.size() != 1 ||
+ captures.contractionDims.m.size() != 1 ||
+ captures.contractionDims.n.size() != 1 ||
+ captures.contractionDims.k.size() != 1 || captures.batches()[0] != 0 ||
+ captures.m() != 1 || captures.n() != 2 || captures.k() != 3) {
+ LDBG("--Only support batch matmul with b, m, n, k iterator order atm\n");
+ return failure();
+ }
+
+ BatchMatmulStrategy strategy =
+ getBatchMatmulConfig(entryPoint->getContext(), captures, gpuModel);
+ if (failed(strategy.validate(gpuModel))) {
+ LDBG("--Batch matmul strategy failed to validate\n");
+ return failure();
+ }
+
+ iree_compiler::createTransformRegion(entryPoint, [&](ImplicitLocOpBuilder &b,
+ Value variantH) {
+ return iree_compiler::gpu::buildBatchMatmulStrategy(b, variantH, strategy);
+ });
+ return success();
+}
+
static LogicalResult matchAndSetMatmulStrategy(func::FuncOp entryPoint,
linalg::LinalgOp op,
const GPUModel &gpuModel) {
@@ -551,6 +642,11 @@
LDBG("Activate matmul\n");
return success();
}
+ if (succeeded(
+ matchAndSetBatchMatmulStrategy(entryPoint, linalgOp, gpuModel))) {
+ LDBG("Activate batch matmul\n");
+ return success();
+ }
// TODO: Add more transform dialect strategy for other kind of dispatch
// regions.
LDBG("No suitable strategy found\n");
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h
index cc75416..852d5f1 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h
@@ -17,6 +17,7 @@
namespace gpu {
/// Forward declarations of all supported strategies.
+struct BatchMatmulStrategy;
struct MatmulStrategy;
class PadStrategy;
class SmallReductionStrategy;
@@ -65,6 +66,14 @@
const MatmulStrategy &strategy);
//===--------------------------------------------------------------------===//
+// Batch matmul strategies.
+//===--------------------------------------------------------------------===//
+/// Entry point to build the transform IR corresponding to an FMA-based strategy
+/// for linalg.fill + linalg.batch_matmul.
+void buildBatchMatmulStrategy(ImplicitLocOpBuilder &b, Value variantH,
+ const BatchMatmulStrategy &strategy);
+
+//===--------------------------------------------------------------------===//
// Pad strategies.
//===--------------------------------------------------------------------===//
/// Entry point to build the transform IR corresponding to a simple pad
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 8f98e30..c5d48c3 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -72,6 +72,17 @@
using Base::Base;
};
+/// Captures the indexing maps of the target operation.
+struct CaptureIndexingMaps : public CaptureStaticValue<SmallVector<AffineMap>> {
+ using Base::Base;
+};
+
+/// Captures the contraction dimensions of the target operation.
+struct CaptureContractionDims
+ : public CaptureStaticValue<mlir::linalg::ContractionDimensions> {
+ using Base::Base;
+};
+
/// Captures the convolution dimensions of the target operation.
struct CaptureConvDims
: public CaptureStaticValue<mlir::linalg::detail::ConvolutionDimensions> {
@@ -627,6 +638,7 @@
/// constant value.
StructuredOpMatcher &rank(NumGreaterEqualTo minRank);
StructuredOpMatcher &rank(NumLowerEqualTo maxRank);
+ StructuredOpMatcher &rank(NumEqualsTo exactRank);
/// Adds a predicate checking that the given iteration space dimension is
/// static/dynamic. The dimension index may be negative, in which case
@@ -667,6 +679,8 @@
StructuredOpMatcher &rank(CaptureRank capture);
StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture);
StructuredOpMatcher &dim(AllDims tag, CaptureDims captures);
+ StructuredOpMatcher &indexingMaps(CaptureIndexingMaps indexingMaps);
+ StructuredOpMatcher &contractionDims(CaptureContractionDims contractionDims);
StructuredOpMatcher &convolutionDims(CaptureConvDims convDims);
//===-------------------------------------------------------------------===//
@@ -863,6 +877,16 @@
/// using block arguments in order.
StructuredOpMatcher &passThroughOp();
+ /// Check if the body of the linalg op implements a contraction of the kind
+ /// result <ReductionOpTy>= input1 <ElemOpTy> input2
+ template <typename ElemOpTy, typename ReductionOpTy>
+ StructuredOpMatcher &hasContractionBody() {
+ return hasContractionBody(
+ [](Operation *op) { return isa<ElemOpTy>(op); },
+ [](Operation *op) { return isa<ReductionOpTy>(op); },
+ ElemOpTy::getOperationName(), ReductionOpTy::getOperationName());
+ }
+
private:
/// Non-template implementations of nested predicate builders for inputs,
/// outputs and results. Should not be called directly.
@@ -881,6 +905,13 @@
// Common util for constant matcher.
StructuredOpMatcher &input(int64_t position,
std::function<bool(llvm::APFloat)> floatValueFn);
+
+ /// Non-template implementation of hasContractionBody. Takes callbacks for
+ /// checking operation kinds and names for error reporting.
+ StructuredOpMatcher &
+ hasContractionBody(function_ref<bool(Operation *)> isaElemOpTy,
+ function_ref<bool(Operation *)> isaReductionOpTy,
+ StringRef elemOpName, StringRef reductionOpName);
};
/// Creates a matcher of an arbitrary structured op.
@@ -1009,8 +1040,36 @@
};
struct MatchedMatmulCaptures {
+ linalg::ContractionDimensions contractionDims = {};
Type lhsElementType, rhsElementType, outputElementType;
SmallVector<int64_t> matmulOpSizes = {};
+ SmallVector<AffineMap> indexingMaps;
+
+ /// Helper functions.
+ int64_t rank() const { return matmulOpSizes.size(); }
+ /// Return all batches.
+ ArrayRef<unsigned> batches() const { return contractionDims.batch; }
+ /// Return the most minor candidate dimension for `m`.
+ int64_t m() const { return contractionDims.m.back(); }
+ /// Return the most minor candidate dimension for `n`.
+ int64_t n() const { return contractionDims.n.back(); }
+ /// Return the most minor candidate dimension for `k`.
+ int64_t k() const { return contractionDims.k.back(); }
+ /// AffineMap for indexing into the LHS.
+ AffineMap lhsIndexing() const {
+ assert(indexingMaps.size() == 3 && "expected 3 indexing maps");
+ return indexingMaps[0];
+ }
+ /// AffineMap for indexing into the RHS.
+ AffineMap rhsIndexing() const {
+ assert(indexingMaps.size() == 3 && "expected 3 indexing maps");
+ return indexingMaps[1];
+ }
+ /// AffineMap for indexing into the RES.
+ AffineMap resIndexing() const {
+ assert(indexingMaps.size() == 3 && "expected 3 indexing maps");
+ return indexingMaps[2];
+ }
};
/// Creates a group of matchers for:
@@ -1046,6 +1105,19 @@
MatchedMatmulCaptures &captures,
bool mustMatchEntireFunc);
+/// Create a group of matchers for a batch matmul with a fill:
+///
+/// batch_matmul(*, *, fill())
+///
+/// and capture various useful quantities. If `mustMatchEntireFunc` is set, the
+/// matcher additionally checks if all tileable operations in the function are
+/// captured.
+void makeBatchMatmulMatcher(transform_ext::MatcherContext &matcherContext,
+ transform_ext::StructuredOpMatcher *&bmmCapture,
+ transform_ext::StructuredOpMatcher *&fillCapture,
+ transform_ext::MatchedMatmulCaptures &captures,
+ bool mustMatchEntireFunc);
+
/// Create a group of matchers for a different code sequence of operations
/// matching exactly a softmax operation.
///
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 7e2963b..ed3dfd6 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -856,6 +856,58 @@
return emitSilenceableFailure(loc) << "failed to match";
}
+/// Match callback for linalg.batch_matmul and its linalg.generic equivalent fed
+/// by a linalg.fill.
+///
+/// Input handles:
+///
+/// - the container op, must be associated with one operation.
+///
+/// Output handles:
+///
+/// - the fill op initializing the output;
+/// - the main compute op.
+static DiagnosedSilenceableFailure
+batchMatmulCallback(transform_ext::MatchCallbackResult &res, Location loc,
+ const mlir::transform::TransformState &state,
+ ValueRange handles) {
+ if (handles.size() != 1 ||
+ !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) {
+ return emitSilenceableFailure(loc)
+ << "expected one handle to one operation";
+ }
+
+ transform_ext::StructuredOpMatcher *pattern, *fill;
+ transform_ext::MatchedMatmulCaptures ignore;
+ transform_ext::MatcherContext matcherContext;
+ transform_ext::makeBatchMatmulMatcher(matcherContext, pattern, fill, ignore,
+ /*mustMatchEntireFunc*/ true);
+
+ // TODO: need a mechanism for this to go around the entire IR,
+ // potentially with list matches for each group.
+ Operation *root = *state.getPayloadOps(handles[0]).begin();
+
+ WalkResult walkResult = root->walk([&](Operation *op) {
+ pattern->resetCapture();
+ if (!matchPattern(op, *pattern))
+ return WalkResult::advance();
+
+ // TODO: notify properly
+ LLVM_DEBUG({
+ DBGS() << "fill:" << fill->getCaptured() << "\n";
+ DBGS() << "pattern: " << pattern->getCaptured() << "\n";
+ });
+
+ res.addPayloadGroup({fill->getCaptured()});
+ res.addPayloadGroup({pattern->getCaptured()});
+ return WalkResult::interrupt();
+ });
+
+ if (walkResult.wasInterrupted())
+ return DiagnosedSilenceableFailure::success();
+ return emitSilenceableFailure(loc) << "failed to match batch matmul";
+}
+
/// Match callback for a tensor.pad. Matches *the first* occurrence of such pad
/// within an op associated with the given handle.
///
@@ -922,6 +974,7 @@
testShapedValueMatcherCallback);
registry.registerCallback("convolution", convolutionCallback);
registry.registerCallback("matmul", matmulCallback);
+ registry.registerCallback("batch_matmul", batchMatmulCallback);
registry.registerCallback("pad", wrapAsEntireFuncMatch(padCallback));
registry.registerCallback("reduction",
wrapAsEntireFuncMatch(reductionCallback));
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 59f085e..0d2293d 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -424,6 +424,14 @@
});
}
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ LLVM_DEBUG(DBGS() << "rank == " << exactRank.value);
+ return linalgOp.getNumLoops() == exactRank.value;
+ });
+}
+
StringRef stringifyShapeKind(transform_ext::ShapeKind kind) {
switch (kind) {
case transform_ext::ShapeKind::Static:
@@ -576,16 +584,40 @@
}
transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::indexingMaps(
+ CaptureIndexingMaps indexingMaps) {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ LLVM_DEBUG(DBGS() << "capture indexing maps");
+ indexingMaps.value = linalgOp.getIndexingMapsArray();
+ return true;
+ });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::contractionDims(
+ CaptureContractionDims contractionDims) {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ LLVM_DEBUG(DBGS() << "capture contraction dimensions");
+ StringRef convMessage = linalg::detail::getMatchContractionMessage(
+ mlir::linalg::detail::isContractionInterfaceImpl(
+ linalgOp, &contractionDims.value));
+ if (convMessage.empty())
+ return true;
+ LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")");
+ return false;
+ });
+}
+
+transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) {
return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
- LLVM_DEBUG(DBGS() << "capture convolution dimensions\n");
+ LLVM_DEBUG(DBGS() << "capture convolution dimensions");
StringRef convMessage = linalg::detail::getMatchConvolutionMessage(
mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp,
&convDims.value));
if (convMessage.empty())
return true;
- LLVM_DEBUG(DBGS() << "capture convolution dimensions failed: "
- << convMessage << "\n");
+ LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")");
return false;
});
}
@@ -1149,6 +1181,79 @@
});
}
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::hasContractionBody(
+ function_ref<bool(Operation *)> isaElemOpTy,
+ function_ref<bool(Operation *)> isaReductionOpTy, StringRef elemOpName,
+ StringRef reductionOpName) {
+ return addPredicate([=](linalg::LinalgOp linalgOp) {
+ LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/"
+ << reductionOpName << " contraction (");
+ auto scopeExitPrinter = llvm::make_scope_exit(
+ [] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); });
+
+ Block *body = linalgOp.getBlock();
+ if (!llvm::hasNItems(*body, 3)) {
+ LLVM_DEBUG(llvm::dbgs() << "three-operation body");
+ return false;
+ }
+ if (body->getNumArguments() != 3) {
+ LLVM_DEBUG(llvm::dbgs() << "three-argument block");
+ return false;
+ }
+
+ Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin());
+ Operation *reductionOp = elemOp->getNextNode();
+ Operation *yieldOp = reductionOp->getNextNode();
+ if (!isaElemOpTy(elemOp)) {
+ LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName);
+ return false;
+ }
+ if (!isaReductionOpTy(reductionOp)) {
+ LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName);
+ return false;
+ }
+ if (yieldOp->getNumOperands() != 1) {
+ LLVM_DEBUG(llvm::dbgs() << "one value yielded");
+ return false;
+ }
+ if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) {
+ LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op");
+ return false;
+ }
+ if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) {
+ LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result");
+ return false;
+ }
+ if (reductionOp->getNumOperands() != 2 ||
+ reductionOp->getNumResults() != 1) {
+ LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result");
+ return false;
+ }
+
+ SmallVector<Value, 2> expectedReductionOperands = {body->getArgument(2),
+ elemOp->getResult(0)};
+ if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) &&
+ !llvm::equal(llvm::reverse(expectedReductionOperands),
+ reductionOp->getOperands())) {
+ LLVM_DEBUG(llvm::dbgs() << "operands of the second op");
+ return false;
+ }
+
+ ValueRange expectedElemOperands = body->getArguments().take_front(2);
+ if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) &&
+ !llvm::equal(llvm::reverse(expectedElemOperands),
+ elemOp->getOperands())) {
+ LLVM_DEBUG(llvm::dbgs() << "operands of the first op");
+ return false;
+ }
+
+ scopeExitPrinter.release();
+ LLVM_DEBUG(llvm::dbgs() << "success)");
+ return true;
+ });
+}
+
void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor(
StringRef name) {
LLVM_DEBUG(DBGS() << "op is a " << name << "'");
@@ -1393,6 +1498,34 @@
trailingCapture = &trailing;
}
+void transform_ext::makeBatchMatmulMatcher(
+ transform_ext::MatcherContext &matcherContext,
+ transform_ext::StructuredOpMatcher *&bmmCapture,
+ transform_ext::StructuredOpMatcher *&fillCapture,
+ transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) {
+ auto &bmm =
+ transform_ext::m_StructuredOp<linalg::BatchMatmulOp, linalg::GenericOp>(
+ matcherContext)
+ .hasContractionBody<arith::MulFOp, arith::AddFOp>()
+ .rank(NumEqualsTo(4))
+ .dim(AllDims(), CaptureDims(captures.matmulOpSizes))
+ .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+ .dim(-1, utils::IteratorType::reduction)
+ .contractionDims(CaptureContractionDims(captures.contractionDims))
+ .input(NumEqualsTo(2))
+ .input(0, CaptureElementType(captures.lhsElementType))
+ .input(1, CaptureElementType(captures.rhsElementType))
+ .output(0, CaptureElementType(captures.outputElementType));
+ bmmCapture = &bmm;
+
+ auto &fill = transform_ext::m_StructuredOp<linalg::FillOp>(matcherContext);
+ bmm = bmm.output(0, fill);
+ fillCapture = &fill;
+
+ if (mustMatchEntireFunc)
+ bmm = bmm.allTilableOpsCaptured<func::FuncOp>();
+}
+
/// Match sum(%src, broadcast(%reduction))
static void
matchSubBroadcast(transform_ext::MatcherContext &matcherContext,