| /* |
| * Copyright 2023 Google LLC |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "iree/builtins/ukernel/query_tile_sizes_internal.h" |
| |
| static bool iree_uk_query_tile_sizes_operation_is_matmul( |
| iree_uk_uint32_t flags) { |
| iree_uk_uint32_t op = iree_uk_query_tile_sizes_operation(flags); |
| return op == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32 || |
| op == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32; |
| } |
| |
| static void iree_uk_query_tile_sizes_2d_validate( |
| const iree_uk_query_tile_sizes_2d_params_t* params) { |
| #ifdef IREE_UK_ENABLE_ASSERTS |
| IREE_UK_ASSERT(iree_uk_query_tile_sizes_operation_is_matmul(params->flags)); |
| iree_uk_uint32_t role = iree_uk_query_tile_sizes_operand_role(params->flags); |
| IREE_UK_ASSERT(role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS || |
| role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS || |
| role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT); |
| const iree_uk_int64_t kDynamic = IREE_UK_INT64_MIN; |
| IREE_UK_ASSERT((params->size0 >= 0 || params->size0 == kDynamic) || |
| (params->size1 >= 0 || params->size1 == kDynamic)); |
| #endif // IREE_UK_ENABLE_ASSERTS |
| } |
| |
| static iree_uk_matmul_tile_sizes_t iree_uk_query_matmul_tile_sizes_generic( |
| const iree_uk_query_tile_sizes_2d_params_t* params) { |
| // Dummy values, originally taken from what was used on ARM_64 +dotprod for |
| // i8i8i32. Not particularly meaningful outside of that case, just is what |
| // some tests have been written against. |
| (void)params; |
| return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 4, .N = 8}; |
| } |
| |
| // Experimental tile sizes for RVV |
| static iree_uk_matmul_tile_sizes_t iree_uk_query_matmul_tile_sizes_rvv( |
| const iree_uk_query_tile_sizes_2d_params_t* params) { |
| (void)params; |
| return (iree_uk_matmul_tile_sizes_t){.M = 16, .K = 16, .N = 16}; |
| } |
| |
| static void iree_uk_query_tile_sizes_2d_matmul( |
| const iree_uk_query_tile_sizes_2d_params_t* params, |
| iree_uk_query_tile_sizes_2d_out_params_t* out_params) { |
| iree_uk_matmul_tile_sizes_t matmul_tile_sizes; |
| if (iree_uk_query_tile_sizes_operation(params->flags) == |
| IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32) { |
| matmul_tile_sizes = iree_uk_query_matmul_tile_sizes_rvv(params); |
| } else { |
| matmul_tile_sizes = iree_uk_query_matmul_tile_sizes_generic(params); |
| } |
| iree_uk_uint32_t role = iree_uk_query_tile_sizes_operand_role(params->flags); |
| if (role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS) { |
| out_params->tile_size0 = matmul_tile_sizes.M; |
| out_params->tile_size1 = matmul_tile_sizes.K; |
| } else if (role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS) { |
| out_params->tile_size0 = matmul_tile_sizes.N; |
| out_params->tile_size1 = matmul_tile_sizes.K; |
| } else if (role == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT) { |
| out_params->tile_size0 = matmul_tile_sizes.M; |
| out_params->tile_size1 = matmul_tile_sizes.N; |
| } else { |
| // Can't happen, validated earlier. |
| IREE_UK_ASSUME_UNREACHABLE; |
| } |
| } |
| |
| IREE_UK_EXPORT int iree_uk_query_tile_sizes_2d( |
| const iree_uk_query_tile_sizes_2d_params_t* params, |
| iree_uk_query_tile_sizes_2d_out_params_t* out_params) { |
| iree_uk_query_tile_sizes_2d_validate(params); |
| |
| if (iree_uk_query_tile_sizes_operation_is_matmul(params->flags)) { |
| iree_uk_query_tile_sizes_2d_matmul(params, out_params); |
| } else { |
| // Can't happen, validated earlier. |
| IREE_UK_ASSUME_UNREACHABLE; |
| } |
| return 0; |
| } |