blob: 169cf9ca8df2374c5bf13d178e34922751fa3d30 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "iree/builtins/ukernel/query_tile_sizes_internal.h"
// Returns true when the operation encoded in |flags| is one of the matmul
// variants handled by this file (f32f32f32 or i8i8i32), false otherwise.
static bool iree_uk_query_tile_sizes_operation_is_matmul(
    iree_uk_uint32_t flags) {
  const iree_uk_uint32_t operation = iree_uk_query_tile_sizes_operation(flags);
  if (operation == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32) {
    return true;
  }
  if (operation == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32) {
    return true;
  }
  return false;
}
// Sanity-checks |params| in builds where IREE_UK_ENABLE_ASSERTS is defined:
//   - the operation encoded in the flags must be a matmul variant,
//   - the operand role must be LHS, RHS or RESULT,
//   - each of size0 and size1 must be non-negative or equal to the kDynamic
//     sentinel (IREE_UK_INT64_MIN).
// When asserts are disabled this function does nothing.
static void iree_uk_query_tile_sizes_2d_validate(
    const iree_uk_query_tile_sizes_2d_params_t* params) {
#ifdef IREE_UK_ENABLE_ASSERTS
  IREE_UK_ASSERT(iree_uk_query_tile_sizes_operation_is_matmul(params->flags));
  iree_uk_uint32_t role = iree_uk_query_tile_sizes_operand_role(params->flags);
  IREE_UK_ASSERT(role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS ||
                 role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS ||
                 role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT);
  const iree_uk_int64_t kDynamic = IREE_UK_INT64_MIN;
  // Both dimensions must be valid. Asserting them separately (rather than
  // OR-ing the two checks together) ensures a single bad size is caught and
  // identifies which dimension failed.
  IREE_UK_ASSERT(params->size0 >= 0 || params->size0 == kDynamic);
  IREE_UK_ASSERT(params->size1 >= 0 || params->size1 == kDynamic);
#else
  // Nothing to check in non-assert builds; silence unused-parameter warnings.
  (void)params;
#endif  // IREE_UK_ENABLE_ASSERTS
}
static iree_uk_matmul_tile_sizes_t iree_uk_query_matmul_tile_sizes_generic(
const iree_uk_query_tile_sizes_2d_params_t* params) {
// Dummy values, originally taken from what was used on ARM_64 +dotprod for
// i8i8i32. Not particularly meaningful outside of that case, just is what
// some tests have been written against.
(void)params;
return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 4, .N = 8};
}
// Experimental tile sizes for RVV
static iree_uk_matmul_tile_sizes_t iree_uk_query_matmul_tile_sizes_rvv(
const iree_uk_query_tile_sizes_2d_params_t* params) {
(void)params;
return (iree_uk_matmul_tile_sizes_t){.M = 16, .K = 16, .N = 16};
}
// Fills |out_params| with the 2D tile sizes for one matmul operand.
//
// The (M, K, N) tile is picked from the operation encoded in params->flags:
// the RVV sizes for i8i8i32, the generic sizes for everything else. The
// operand role then selects which two of the three dimensions are reported:
//   LHS    -> (M, K)
//   RHS    -> (N, K)
//   RESULT -> (M, N)
static void iree_uk_query_tile_sizes_2d_matmul(
    const iree_uk_query_tile_sizes_2d_params_t* params,
    iree_uk_query_tile_sizes_2d_out_params_t* out_params) {
  const iree_uk_uint32_t operation =
      iree_uk_query_tile_sizes_operation(params->flags);
  const iree_uk_matmul_tile_sizes_t tiles =
      (operation == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32)
          ? iree_uk_query_matmul_tile_sizes_rvv(params)
          : iree_uk_query_matmul_tile_sizes_generic(params);

  const iree_uk_uint32_t operand_role =
      iree_uk_query_tile_sizes_operand_role(params->flags);
  if (operand_role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS) {
    out_params->tile_size0 = tiles.M;
    out_params->tile_size1 = tiles.K;
    return;
  }
  if (operand_role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS) {
    out_params->tile_size0 = tiles.N;
    out_params->tile_size1 = tiles.K;
    return;
  }
  if (operand_role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT) {
    out_params->tile_size0 = tiles.M;
    out_params->tile_size1 = tiles.N;
    return;
  }
  // Any other role cannot occur here: roles are validated earlier.
  IREE_UK_ASSUME_UNREACHABLE;
}
// Exported entry point: validates |params| (a no-op unless asserts are
// enabled), then writes the tile sizes for the queried operand into
// |out_params|. Only matmul operations are handled. Always returns 0.
IREE_UK_EXPORT int iree_uk_query_tile_sizes_2d(
    const iree_uk_query_tile_sizes_2d_params_t* params,
    iree_uk_query_tile_sizes_2d_out_params_t* out_params) {
  iree_uk_query_tile_sizes_2d_validate(params);
  const bool is_matmul =
      iree_uk_query_tile_sizes_operation_is_matmul(params->flags);
  if (!is_matmul) {
    // Non-matmul operations cannot reach this point: validated earlier.
    IREE_UK_ASSUME_UNREACHABLE;
  } else {
    iree_uk_query_tile_sizes_2d_matmul(params, out_params);
  }
  return 0;
}