| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| ==============================================================================*/ | |
| #include "mli_api.h" // NOLINT | |
| namespace tflite { | |
| // Convolution specialized function. | |
| typedef mli_status (*conv_func_ptr)(const mli_tensor* /*in*/, | |
| const mli_tensor* /*weights*/, | |
| const mli_tensor* /*bias*/, | |
| const mli_conv2d_cfg* /*cfg*/, | |
| mli_tensor* /*out*/); | |
| #ifdef MLI_2_0 | |
| conv_func_ptr __attribute__((weak)) | |
| mli_krn_conv2d_hwcn(const mli_tensor* weights) { | |
| int filter_w = weights->shape[KRNL_W_DIM_HWCN]; | |
| int filter_h = weights->shape[KRNL_H_DIM_HWCN]; | |
| if (filter_w == 1 && filter_h == 1) { | |
| return mli_krn_conv2d_hwcn_sa8_sa8_sa32_k1x1; | |
| } else if (filter_w == 3 && filter_h == 3) { | |
| return mli_krn_conv2d_hwcn_sa8_sa8_sa32_k3x3; | |
| } else if (filter_w == 5 && filter_h == 5) { | |
| return mli_krn_conv2d_hwcn_sa8_sa8_sa32_k5x5; | |
| } else { | |
| return mli_krn_conv2d_hwcn_sa8_sa8_sa32; | |
| } | |
| } | |
| #else | |
| conv_func_ptr __attribute__((weak)) | |
| mli_krn_conv2d_hwcn(const mli_tensor* weights, const mli_conv2d_cfg* cfg) { | |
| return mli_krn_conv2d_nhwc_sa8_sa8_sa32; | |
| } | |
| #endif | |
| // Depthwise convolution specialized function. | |
| typedef mli_status (*depthwise_func_ptr)(const mli_tensor* /*in*/, | |
| const mli_tensor* /*weights*/, | |
| const mli_tensor* /*bias*/, | |
| const mli_conv2d_cfg* /*cfg*/, | |
| mli_tensor* /*out*/); | |
| #ifdef MLI_2_0 | |
| depthwise_func_ptr __attribute__((weak)) | |
| mli_krn_depthwise_conv2d(const mli_tensor* weights) { | |
| int filter_w = weights->shape[KRNL_DW_W_DIM_HW1N]; | |
| int filter_h = weights->shape[KRNL_DW_H_DIM_HW1N]; | |
| if (filter_w == 3 && filter_h == 3) { | |
| return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32_k3x3; | |
| } else if (filter_w == 5 && filter_h == 5) { | |
| return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32_k5x5; | |
| } else { | |
| return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32; | |
| } | |
| } | |
| #else | |
| depthwise_func_ptr __attribute__((weak)) | |
| mli_krn_depthwise_conv2d(const mli_tensor* weights, const mli_conv2d_cfg* cfg) { | |
| return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32; | |
| } | |
| #endif | |
#ifdef MLI_2_0
// Picks the group convolution kernel to call for the given weights tensor.
// Only defined when MLI_2_0 is set.
//
// The filter width and height are read from weights->shape at the
// KRNL_DW_W_DIM_HW1N / KRNL_DW_H_DIM_HW1N indices. Square 3x3 and 5x5 filters
// map to the matching `_k3x3` / `_k5x5` kernel; every other size (including
// all non-square filters) maps to the generic kernel.
//
// The definition is marked weak, so a non-weak definition of the same symbol
// elsewhere takes precedence at link time.
depthwise_func_ptr __attribute__((weak))
mli_krn_group_conv2d(const mli_tensor* weights) {
  const int kernel_w = weights->shape[KRNL_DW_W_DIM_HW1N];
  const int kernel_h = weights->shape[KRNL_DW_H_DIM_HW1N];

  // Only square filters have dedicated variants.
  if (kernel_w != kernel_h) {
    return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32;
  }

  switch (kernel_w) {
    case 3:
      return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32_k3x3;
    case 5:
      return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32_k5x5;
    default:
      return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32;
  }
}
#endif
| // Pooling specialized functions. | |
| typedef mli_status (*pooling_func_ptr)(const mli_tensor* /*in*/, | |
| const mli_pool_cfg* /*cfg*/, | |
| mli_tensor* /*out*/); | |
| #ifdef MLI_2_0 | |
| pooling_func_ptr __attribute__((weak)) | |
| mli_krn_avepool(const mli_pool_cfg* cfg) { | |
| int filter_w = cfg->kernel_width; | |
| int filter_h = cfg->kernel_height; | |
| if (filter_w == 2 && filter_h == 2) { | |
| return mli_krn_avepool_hwc_sa8_k2x2; | |
| } else if (filter_w == 3 && filter_h == 3) { | |
| return mli_krn_avepool_hwc_sa8_k3x3; | |
| } else { | |
| return mli_krn_avepool_hwc_sa8; | |
| } | |
| } | |
| #else | |
| pooling_func_ptr __attribute__((weak)) | |
| mli_krn_avepool(const mli_pool_cfg* cfg) { | |
| return mli_krn_avepool_hwc_sa8; | |
| } | |
| #endif | |
| #ifdef MLI_2_0 | |
| pooling_func_ptr __attribute__((weak)) | |
| mli_krn_maxpool(const mli_pool_cfg* cfg) { | |
| int filter_w = cfg->kernel_width; | |
| int filter_h = cfg->kernel_height; | |
| if (filter_w == 2 && filter_h == 2) { | |
| return mli_krn_maxpool_hwc_sa8_k2x2; | |
| } else if (filter_w == 3 && filter_h == 3) { | |
| return mli_krn_maxpool_hwc_sa8_k3x3; | |
| } else { | |
| return mli_krn_maxpool_hwc_sa8; | |
| } | |
| } | |
| #else | |
| pooling_func_ptr __attribute__((weak)) | |
| mli_krn_maxpool(const mli_pool_cfg* cfg) { | |
| return mli_krn_maxpool_hwc_sa8; | |
| } | |
| #endif | |
| } // namespace tflite |