Add support for custom tile size in ukernel mmt4d

Also set experimental tile sizes for RVV.

Change-Id: I2242d22c037a74dc84a093dc700f93996ea8de4a
diff --git a/vmvx_ukernel/CMakeLists.txt b/vmvx_ukernel/CMakeLists.txt index 2dd48ec..89f13a8 100644 --- a/vmvx_ukernel/CMakeLists.txt +++ b/vmvx_ukernel/CMakeLists.txt
@@ -12,10 +12,10 @@ "elementwise.c" "elementwise_impl.c.inc" "mmt4d_tile.c" + "query_tile_sizes.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack_tile.c" - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/query_tile_sizes.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/unpack.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/unpack_tile.c" DEPS
diff --git a/vmvx_ukernel/query_tile_sizes.c b/vmvx_ukernel/query_tile_sizes.c new file mode 100644 index 0000000..4f22c54 --- /dev/null +++ b/vmvx_ukernel/query_tile_sizes.c
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "iree/builtins/ukernel/query_tile_sizes.h"
+
+// Returns true when the operation encoded in |flags| is one of the two matmul
+// operations this file handles (f32f32f32 or i8i8i32).
+static bool iree_uk_query_tile_sizes_operation_is_matmul(
+    iree_uk_uint32_t flags) {
+  iree_uk_uint32_t op = iree_uk_query_tile_sizes_operation(flags);
+  return op == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32 ||
+         op == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
+}
+
+// Sanity-checks |params|. The checks are only compiled when
+// IREE_UK_ENABLE_ASSERTS is defined.
+static void iree_uk_query_tile_sizes_2d_validate(
+    const iree_uk_query_tile_sizes_2d_params_t* params) {
+#ifdef IREE_UK_ENABLE_ASSERTS
+  IREE_UK_ASSERT(iree_uk_query_tile_sizes_operation_is_matmul(params->flags));
+  iree_uk_uint32_t role = iree_uk_query_tile_sizes_operand_role(params->flags);
+  IREE_UK_ASSERT(role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS ||
+                 role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS ||
+                 role ==
+                     IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS_TRANSPOSE ||
+                 role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT);
+  // Each dimension must be a non-negative size or the kDynamic sentinel. The
+  // two dimensions are asserted separately so that a valid size0 cannot mask
+  // an invalid size1 (or vice versa).
+  const iree_uk_int64_t kDynamic = IREE_UK_INT64_MIN;
+  IREE_UK_ASSERT(params->size0 >= 0 || params->size0 == kDynamic);
+  IREE_UK_ASSERT(params->size1 >= 0 || params->size1 == kDynamic);
+#endif  // IREE_UK_ENABLE_ASSERTS
+}
+
+// Default matmul tile sizes, used when no more specific choice applies.
+static iree_uk_matmul_tile_sizes_t iree_uk_query_matmul_tile_sizes_generic(
+    const iree_uk_query_tile_sizes_2d_params_t* params) {
+  // Dummy values, originally taken from what was used on ARM_64 +dotprod for
+  // i8i8i32. Not particularly meaningful outside of that case, just is what
+  // some tests have been written against.
+  (void)params;
+  return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 4, .N = 8};
+}
+
+// Experimental tile sizes for RVV.
+// NOTE(review): the caller picks these by operation type (i8i8i32) rather than
+// by a target-architecture check; confirm that is intended.
+static iree_uk_matmul_tile_sizes_t iree_uk_query_matmul_tile_sizes_rvv(
+    const iree_uk_query_tile_sizes_2d_params_t* params) {
+  (void)params;
+  return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 16, .N = 8};
+}
+
+// Picks the (M, K, N) matmul tile sizes for the operation in |params->flags|
+// and writes the pair matching the operand role to |out_params|:
+//   LHS -> (M, K), RHS -> (K, N), RHS_TRANSPOSE -> (N, K), RESULT -> (M, N).
+static void iree_uk_query_tile_sizes_2d_matmul(
+    const iree_uk_query_tile_sizes_2d_params_t* params,
+    iree_uk_query_tile_sizes_2d_out_params_t* out_params) {
+  iree_uk_matmul_tile_sizes_t matmul_tile_sizes;
+  if (iree_uk_query_tile_sizes_operation(params->flags) ==
+      IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32) {
+    matmul_tile_sizes = iree_uk_query_matmul_tile_sizes_rvv(params);
+  } else {
+    matmul_tile_sizes = iree_uk_query_matmul_tile_sizes_generic(params);
+  }
+  iree_uk_uint32_t role = iree_uk_query_tile_sizes_operand_role(params->flags);
+  if (role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS) {
+    out_params->tile_size0 = matmul_tile_sizes.M;
+    out_params->tile_size1 = matmul_tile_sizes.K;
+  } else if (role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS) {
+    out_params->tile_size0 = matmul_tile_sizes.K;
+    out_params->tile_size1 = matmul_tile_sizes.N;
+  } else if (role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS_TRANSPOSE) {
+    out_params->tile_size0 = matmul_tile_sizes.N;
+    out_params->tile_size1 = matmul_tile_sizes.K;
+  } else if (role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT) {
+    out_params->tile_size0 = matmul_tile_sizes.M;
+    out_params->tile_size1 = matmul_tile_sizes.N;
+  } else {
+    // Can't happen, validated earlier.
+    IREE_UK_ASSUME_UNREACHABLE;
+  }
+}
+
+// Entry point: validates |params|, then writes the 2-D tile sizes for the
+// queried matmul operand to |out_params|.
+IREE_UK_EXPORT void iree_uk_query_tile_sizes_2d(
+    const iree_uk_query_tile_sizes_2d_params_t* params,
+    iree_uk_query_tile_sizes_2d_out_params_t* out_params) {
+  iree_uk_query_tile_sizes_2d_validate(params);
+
+  if (iree_uk_query_tile_sizes_operation_is_matmul(params->flags)) {
+    iree_uk_query_tile_sizes_2d_matmul(params, out_params);
+  } else {
+    // Can't happen, validated earlier.
+    IREE_UK_ASSUME_UNREACHABLE;
+  }
+}