blob: 6276fe73380b5d5a76bc37d664774a63fd42e06c [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mli_api.h" // NOLINT
namespace tflite {
// Convolution specialized function.
typedef mli_status (*conv_func_ptr)(const mli_tensor* /*in*/,
const mli_tensor* /*weights*/,
const mli_tensor* /*bias*/,
const mli_conv2d_cfg* /*cfg*/,
mli_tensor* /*out*/);
#ifdef MLI_2_0
// Selects the MLI 2.0 sa8 convolution kernel matching the filter size of
// `weights` (dimensions read via the KRNL_*_DIM_HWCN indices): dedicated
// kernels for 1x1, 3x3 and 5x5 filters, the generic kernel otherwise.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
// NOTE: this variant takes one argument, unlike the non-MLI_2_0 variant below.
conv_func_ptr __attribute__((weak))
mli_krn_conv2d_hwcn(const mli_tensor* weights) {
  const int filter_w = weights->shape[KRNL_W_DIM_HWCN];
  const int filter_h = weights->shape[KRNL_H_DIM_HWCN];
  if (filter_w == 1 && filter_h == 1) {
    return mli_krn_conv2d_hwcn_sa8_sa8_sa32_k1x1;
  } else if (filter_w == 3 && filter_h == 3) {
    return mli_krn_conv2d_hwcn_sa8_sa8_sa32_k3x3;
  } else if (filter_w == 5 && filter_h == 5) {
    return mli_krn_conv2d_hwcn_sa8_sa8_sa32_k5x5;
  } else {
    return mli_krn_conv2d_hwcn_sa8_sa8_sa32;
  }
}
#else
// Without MLI_2_0, no size specialization is done: always returns the generic
// sa8 convolution kernel (the NHWC-named variant, despite this function's
// name). Both parameters are unused; their names are commented out to avoid
// unused-parameter warnings while keeping the two-argument signature.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
conv_func_ptr __attribute__((weak))
mli_krn_conv2d_hwcn(const mli_tensor* /*weights*/,
                    const mli_conv2d_cfg* /*cfg*/) {
  return mli_krn_conv2d_nhwc_sa8_sa8_sa32;
}
#endif
// Depthwise convolution specialized function.
typedef mli_status (*depthwise_func_ptr)(const mli_tensor* /*in*/,
const mli_tensor* /*weights*/,
const mli_tensor* /*bias*/,
const mli_conv2d_cfg* /*cfg*/,
mli_tensor* /*out*/);
#ifdef MLI_2_0
// Selects the MLI 2.0 sa8 depthwise convolution kernel matching the filter
// size of `weights` (dimensions read via the KRNL_DW_*_DIM_HW1N indices):
// dedicated kernels for 3x3 and 5x5 filters, the generic kernel otherwise.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
// NOTE: this variant takes one argument, unlike the non-MLI_2_0 variant below.
depthwise_func_ptr __attribute__((weak))
mli_krn_depthwise_conv2d(const mli_tensor* weights) {
  const int filter_w = weights->shape[KRNL_DW_W_DIM_HW1N];
  const int filter_h = weights->shape[KRNL_DW_H_DIM_HW1N];
  if (filter_w == 3 && filter_h == 3) {
    return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32_k3x3;
  } else if (filter_w == 5 && filter_h == 5) {
    return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32_k5x5;
  } else {
    return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32;
  }
}
#else
// Without MLI_2_0, no size specialization is done: always returns the generic
// sa8 depthwise convolution kernel. Both parameters are unused; their names
// are commented out to avoid unused-parameter warnings while keeping the
// two-argument signature.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
depthwise_func_ptr __attribute__((weak))
mli_krn_depthwise_conv2d(const mli_tensor* /*weights*/,
                         const mli_conv2d_cfg* /*cfg*/) {
  return mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32;
}
#endif
#ifdef MLI_2_0
// Chooses the MLI 2.0 sa8 group-convolution kernel for the filter size stored
// in `weights`. Square 3x3 and 5x5 filters map to dedicated kernels; every
// other shape (including non-square filters) uses the generic kernel. The
// depthwise function-pointer type is reused because the signature matches.
// NOTE(review): filter dims are read with the depthwise KRNL_DW_*_DIM_HW1N
// indices; confirm these coincide with the group-conv weight layout.
depthwise_func_ptr __attribute__((weak))
mli_krn_group_conv2d(const mli_tensor* weights) {
  const int kernel_h = weights->shape[KRNL_DW_H_DIM_HW1N];
  const int kernel_w = weights->shape[KRNL_DW_W_DIM_HW1N];
  // Only square filters have specialized kernels.
  if (kernel_w != kernel_h) {
    return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32;
  }
  switch (kernel_w) {
    case 3:
      return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32_k3x3;
    case 5:
      return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32_k5x5;
    default:
      return mli_krn_group_conv2d_hwcn_sa8_sa8_sa32;
  }
}
#endif
// Pooling specialized functions.
typedef mli_status (*pooling_func_ptr)(const mli_tensor* /*in*/,
const mli_pool_cfg* /*cfg*/,
mli_tensor* /*out*/);
#ifdef MLI_2_0
// Selects the MLI 2.0 sa8 average-pooling kernel matching the pooling window
// (cfg->kernel_width x cfg->kernel_height): dedicated kernels for 2x2 and 3x3
// windows, the generic kernel otherwise.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
pooling_func_ptr __attribute__((weak))
mli_krn_avepool(const mli_pool_cfg* cfg) {
  const int filter_w = cfg->kernel_width;
  const int filter_h = cfg->kernel_height;
  if (filter_w == 2 && filter_h == 2) {
    return mli_krn_avepool_hwc_sa8_k2x2;
  } else if (filter_w == 3 && filter_h == 3) {
    return mli_krn_avepool_hwc_sa8_k3x3;
  } else {
    return mli_krn_avepool_hwc_sa8;
  }
}
#else
// Without MLI_2_0, no size specialization is done: always returns the generic
// sa8 average-pooling kernel. `cfg` is unused; its name is commented out to
// avoid unused-parameter warnings while keeping the signature.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
pooling_func_ptr __attribute__((weak))
mli_krn_avepool(const mli_pool_cfg* /*cfg*/) {
  return mli_krn_avepool_hwc_sa8;
}
#endif
#ifdef MLI_2_0
// Selects the MLI 2.0 sa8 max-pooling kernel matching the pooling window
// (cfg->kernel_width x cfg->kernel_height): dedicated kernels for 2x2 and 3x3
// windows, the generic kernel otherwise.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
pooling_func_ptr __attribute__((weak))
mli_krn_maxpool(const mli_pool_cfg* cfg) {
  const int filter_w = cfg->kernel_width;
  const int filter_h = cfg->kernel_height;
  if (filter_w == 2 && filter_h == 2) {
    return mli_krn_maxpool_hwc_sa8_k2x2;
  } else if (filter_w == 3 && filter_h == 3) {
    return mli_krn_maxpool_hwc_sa8_k3x3;
  } else {
    return mli_krn_maxpool_hwc_sa8;
  }
}
#else
// Without MLI_2_0, no size specialization is done: always returns the generic
// sa8 max-pooling kernel. `cfg` is unused; its name is commented out to avoid
// unused-parameter warnings while keeping the signature.
// Declared weak so a strong definition of this symbol elsewhere overrides it.
pooling_func_ptr __attribute__((weak))
mli_krn_maxpool(const mli_pool_cfg* /*cfg*/) {
  return mli_krn_maxpool_hwc_sa8;
}
#endif
} // namespace tflite