// blob: 2b84822084ba20964ef6bf0a8c76724247891387 (gitiles page residue; kept as a comment so the file parses)
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// The configuration used for executable compilation.
// This specifies the device configurations that support this custom kernel.
// The `wgp` (workgroup-processor) spec below declares the device capabilities
// the kernel relies on; notably the subgroup shuffle/arithmetic features used
// by the subgroup-reduction shader referenced in the dispatch below.
// NOTE(review): arch is intentionally empty (generic Vulkan target) — the
// features string pins SPIR-V 1.3 with the Shader capability.
#spirv_target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
  iree.gpu.target = #iree_gpu.target<
    arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
      compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
      dot = none, mma = [], subgroup_size_choices = [64, 64],
      max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
      max_workgroup_memory_bytes = 16384,
      max_workgroup_counts = [65535, 65535, 65535]>
  >
}>
// Transform-dialect module: the interpreter entry point is @__transform_main,
// which matches linalg argmax ops (@match_argmax) and replaces them with calls
// to the hand-written SPIR-V kernel via @cast_and_call_argmax.
module attributes {transform.with_named_sequence} {

  // ABI-bridge wrapper that is imported into the target module on a match.
  // It adapts a 1x? f32 tensor to the external kernel's interface: one i32
  // push constant (the reduction extent) plus the input buffer, producing a
  // single i64 result index.
  func.func private @argmax_1d_f32_entry_point(%arg0: tensor<1x?xf32>) -> tensor<1xi64> {
    %c1 = arith.constant 1 : index
    // Dynamic extent of the reduction dimension (dim 1); used both as the
    // dispatch workload and as the kernel's push constant.
    %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
    // Note: This is not safe if the dim size exceeds INT32_MAX. To pass a 64
    // bit value it must be broken down into two 32-bit values for the high and
    // low bits.
    %dim_i32 = arith.index_cast %dim : index to i32

    // Inline external dispatch that conforms to the ABI that the kernel
    // requires. This is the primary reason for the surrounding function as
    // details like tensor shape and push constants need to line up after
    // splicing in the custom dispatch. This allows the kernel author to manage
    // such details by hand without needing the rewrite patterns to worry about
    // things like order of push constants.
    %4 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0) : (i32, tensor<1x?xf32>{%dim}) -> tensor<1xi64>
      // Workgroup-count region: the kernel reduces within a single workgroup,
      // so the grid is always 1x1x1 regardless of the workload.
      count(%device: !hal.device, %workload: index) -> (index, index, index) {
        %c1_0 = arith.constant 1 : index
        hal.return %c1_0, %c1_0, %c1_0 : index, index, index
      }
      // Pipeline layout: one push constant (the i32 size above) and two
      // storage buffers — read-only input, writable output.
      layout(#hal.pipeline.layout<constants = 1, bindings = [
        #hal.pipeline.binding<storage_buffer, ReadOnly>,
        #hal.pipeline.binding<storage_buffer>
      ]>)
      // Pre-compiled SPIR-V object providing the "main" entry point for the
      // #spirv_target device configuration.
      objects({
        #spirv_target ordinal(0) = [
          #hal.executable.object<{
            path = "samples/custom_dispatch/vulkan/shaders/one_workgroup_argmax_subgroup_f32.spv"
          }>
        ]
      })
    return %4 : tensor<1xi64>
  }

  // Custom matcher for argmax operations equivalent to the custom kernel. This
  // matcher will be run one-by-one on all operations contained within the
  // target function. On success, it will return the handle to the matched
  // argmax operation.
  transform.named_sequence @match_argmax(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
    // Fail fast on non-linalg generics.
    transform.match.operation_name %generic ["linalg.generic"] : !transform.any_op
    %matched = transform.match.structured failures(propagate) %generic : (!transform.any_op) -> (!transform.any_op) {
    ^bb1(%argmax: !transform.any_op):
      // Verify that the rank (i.e. number of loops) of the linalg op is 2,
      // with one parallel iterator and one reduction iterator.
      // TODO: Add optionality for the parallel dimensions.
      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
      %rank = transform.match.structured.rank %argmax : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %rank, %c2 : !transform.param<i64>
      transform.match.structured.dim %argmax[0] {parallel} : !transform.any_op
      transform.match.structured.dim %argmax[-1] {reduction} : !transform.any_op

      // Verify a single input (target vector to compute the argmax of) and two
      // outputs, one for the maximum value and one for the index.
      %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
      %n_inputs = transform.match.structured.num_inputs %argmax : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param<i64>
      %n_outputs = transform.match.structured.num_inits %argmax : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %n_outputs, %c2 : !transform.param<i64>
      transform.match.structured.yield %argmax : !transform.any_op
    }

    // Verify the operand shapes of the linalg op. For example, in the below,
    // dim 0 must be statically 1, and dim 1 must be statically divisible by 64
    // (the kernel's subgroup-sized reduction granularity).
    %in0 = transform.get_operand %matched[0] : (!transform.any_op) -> !transform.any_value
    transform.iree.match.cast_compatible_type %in0 = tensor<1x?xf32> : !transform.any_value
    transform.iree.match.dim_is_multiple_of %in0[1], 64 : !transform.any_value
    // Operands [1] and [2] are the two inits: max value and argmax index.
    %out0 = transform.get_operand %matched[1] : (!transform.any_op) -> !transform.any_value
    transform.iree.match.cast_compatible_type %out0 = tensor<1xf32> : !transform.any_value
    %out1 = transform.get_operand %matched[2] : (!transform.any_op) -> !transform.any_value
    transform.iree.match.cast_compatible_type %out1 = tensor<1xi64> : !transform.any_value

    // Verify the region of the argmax op. This does a structural comparison of
    // region(s) of the payload operation against the single operation contained
    // within the body of this operation. This does no verification of other
    // input types/attributes. This is because typically for kernel matching,
    // the most important part to get exactly right is the inner loop. Otherwise
    // small variations to shape information and iterator counts and such are
    // better suited for more general matchers.
    transform.iree.match.regions %matched : !transform.any_op {
      ^bb0(%target: tensor<1x?xf32>, %empty_max: tensor<1xf32>, %empty_idx: tensor<1xi64>):
      // Reference payload: running max (maximumf) plus index select on a
      // strict-greater compare — i.e. argmax keeping the first maximum.
      %5:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                              affine_map<(d0, d1) -> (d0)>,
                                              affine_map<(d0, d1) -> (d0)>],
                             iterator_types = ["parallel", "reduction"]}
        ins(%target : tensor<1x?xf32>)
        outs(%empty_max, %empty_idx : tensor<1xf32>, tensor<1xi64>) {
      ^bb0(%in: f32, %out: f32, %out_0: i64):
        %6 = linalg.index 1 : index
        %7 = arith.index_cast %6 : index to i64
        %8 = arith.maximumf %in, %out : f32
        %9 = arith.cmpf ogt, %in, %out : f32
        %10 = arith.select %9, %7, %out_0 : i64
        linalg.yield %8, %10 : f32, i64
      } -> (tensor<1xf32>, tensor<1xi64>)
    }
    transform.yield %generic : !transform.any_op
  }

  // Rewrite callback for `transform.foreach_match`. The input signature for
  // this sequence must match exactly with the outputs of the matcher. In this
  // case we just take the argmax as an input, import the entry point for the
  // custom kernel authored above, and replace the users of the argmax with a
  // call to the function.
  transform.named_sequence @cast_and_call_argmax(%argmax: !transform.any_op {transform.readonly}) {
    %module = transform.util.get_nearest_symbol_table %argmax : (!transform.any_op) -> !transform.any_op
    %func = transform.util.import_symbol @argmax_1d_f32_entry_point into %module : (!transform.any_op) -> !transform.any_op
    // Only result [1] (the i64 index) is replaced; the max-value result [0]
    // of the original argmax is not produced by the kernel.
    %ins = transform.get_operand %argmax[0] : (!transform.any_op) -> !transform.any_value
    %outs = transform.get_result %argmax[1] : (!transform.any_op) -> !transform.any_value
    transform.func.cast_and_call %func(%ins) -> %outs before %argmax {
      // This specifies how to resolve type mismatches between the arguments
      // of the function and the inputs to the argmax. In this example, the
      // only casts this will generate are same-rank tensor casts that drop
      // static information.
      transform.type_conversion.tensor.cast_shape_dynamic_dims
    } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op
    transform.yield
  }

  // Entry point for the transform interpreter, nested on the full module. This
  // is because the rewrites needed for importing the custom kernel needs to
  // add a new symbol to the module's symbol table.
  transform.named_sequence @__transform_main(%module: !transform.any_op) {
    // Gather the set of functions within the module.
    %funcs = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
    // For each function in the module, run the matcher on all contained
    // operations.
    transform.foreach %funcs : !transform.any_op {
    ^bb1(%func: !transform.any_op):
      transform.foreach_match in %func
          // <matcher name> -> <rewriter name>
          // Multiple matcher-action pairs can be specified comma separated,
          // here we are only doing a single kind of match and replace.
          //
          // Note that the operations within the module are walked in
          // post-order, meaning actions must be very careful in their
          // replacements not to modify successors of operations. Nested
          // regions and DAG roots will be visited last so it is safest to
          // do matching + replacement on the root of the DAG rather than
          // trying to look ahead. The other option is to avoid dce/cse until
          // after the walk is complete.
          @match_argmax -> @cast_and_call_argmax
        : (!transform.any_op) -> (!transform.any_op)
    }
    // Cleanup now dead instances of argmax.
    transform.apply_dce to %module : !transform.any_op
    transform.yield
  }
}