Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 1 | /* |
| 2 | * Copyright 2024 Google LLC |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | #ifndef TFLM_OPT_CONV_UTIL_H_ |
| 18 | #define TFLM_OPT_CONV_UTIL_H_ |
| 19 | |
| 20 | #include <cassert> |
| 21 | #include <memory> |
| 22 | |
| 23 | #include "crt/kelvin.h" |
| 24 | #include "tensorflow/lite/kernels/internal/common.h" |
| 25 | #include "tensorflow/lite/kernels/internal/runtime_shape.h" |
| 26 | #include "tensorflow/lite/kernels/internal/types.h" |
| 27 | #include "tflm/opt/util.h" |
| 28 | |
| 29 | namespace kelvin::opt { |
/* clang-format off */
// 16-entry index permutation, laid out as a 4x4 grid: every row advances in
// steps of 4, and the rows begin at 0, 2, 1 and 3 respectively.
constexpr int swizzle[16] = {
    0, 4,  8, 12,  // row starting at 0
    2, 6, 10, 14,  // row starting at 2
    1, 5,  9, 13,  // row starting at 1
    3, 7, 11, 15,  // row starting at 3
};
/* clang-format on */
| 38 | |
// Dimension indices shared by the convolution kernels.
// NOTE(review): presumably these index a TFLite RuntimeShape, with the filter
// laid out as [out_channel, height, width, in_channel] and activations as
// NHWC -- confirm against the callers.
constexpr int kFilterHeightIndex{1};
constexpr int kFilterWidthIndex{2};
constexpr int kFilterInputChannelIndex{3};
constexpr int kInputChannelIndex{3};
constexpr int kOutputChannelIndex{3};
| 44 | |
// Symbolic names for the registers used by the convolution kernels.
// NOTE(review): v0..v63 are presumably Kelvin vector register names -- confirm
// against crt/kelvin.h.

// Input activations.
#define INA0   v0

// Filter values: eight consecutive registers, v8 through v15.
#define FLTA0  v8
#define FLTA1  v9
#define FLTA2  v10
#define FLTA3  v11
#define FLTA4  v12
#define FLTA5  v13
#define FLTA6  v14
#define FLTA7  v15

// Accumulator; ACC and ACC0 both name v48, the register the output pipeline
// macros in this header operate on.
#define ACC    v48
#define ACC0   v48

// Output staging register.
#define OUT0   v56
| 57 | |
// Repacks a convolution filter into a blocked layout.
//
//   input  : [zo][ky][kx][zi]                           extents (N, H, W, M)
//   output : [zo / 8][ky][kx][zi / 4][zo % 8][zi % 4]
//
// N (template parameter) is the extent of the leading `zo` axis and M the
// extent of the trailing `zi` axis; H and W are the filter height and width.
// NOTE(review): zo/zi look like output/input channels (the filter's last
// dimension is the input channel elsewhere in this header), which is the
// opposite of the previous "N inputs, M outputs" wording -- confirm.
//
// `input` is read linearly (N*H*W*M bytes). The output stride of the zi block
// axis is M / 4 (rounded down), so the blocked layout is only dense and
// non-overlapping when M is a multiple of 4; with N a multiple of 8 as well,
// `output` receives exactly N*H*W*M bytes. Callers are assumed to size
// `output` accordingly -- TODO confirm.
//
// Offsets are computed explicitly rather than by aliasing the buffers as
// runtime-sized multi-dimensional arrays, which is a compiler extension and
// required casting away the const on `input`.
template <int N>
inline void Filter_N_H_W_M(const int8_t* input, int8_t* output, int H, int W,
                           int M) {
  assert(N >= 4 && M >= 4);
  const int zi_blocks = M / 4;  // extent of the zi / 4 axis in the output
  const int8_t* src = input;    // the loop nest visits the input in order
  for (int zo = 0; zo < N; ++zo) {
    const int zo_hi = zo >> 3;  // zo / 8
    const int zo_lo = zo & 7;   // zo % 8
    for (int ky = 0; ky < H; ++ky) {
      for (int kx = 0; kx < W; ++kx) {
        // Index of [zo_hi][ky][kx][0], counted in 8x4-byte tiles.
        const int tile_base = ((zo_hi * H + ky) * W + kx) * zi_blocks;
        for (int zi = 0; zi < M; ++zi) {
          const int zi_hi = zi >> 2;  // zi / 4
          const int zi_lo = zi & 3;   // zi % 4
          output[((tile_base + zi_hi) * 8 + zo_lo) * 4 + zi_lo] = *src++;
        }
      }
    }
  }
}
| 82 | |
// Swizzles values and duplicates each one 4 times for strip-mining.
//
// `input` holds N int32 values and `output` receives N * 4. Every complete
// group of 8 inputs (s0..s7) expands to 32 outputs made of four 8-element
// regions, each a two-value pattern repeated 4 times:
//
//   output[ 0.. 7] = {s0, s4} x4
//   output[ 8..15] = {s2, s6} x4
//   output[16..23] = {s1, s5} x4
//   output[24..31] = {s3, s7} x4
//
// Only N / 8 complete groups are written; if N is not a multiple of 8 the
// remaining output elements are left untouched. When `negate` is true, all
// N * 4 output elements are negated afterwards. `input` and `output` are
// assumed not to overlap.
//
// The buffers are indexed directly through the pointers rather than aliased
// as runtime-sized arrays (a compiler extension that also cast away const).
inline void Swizzle(const int32_t* input, int32_t* output, int N,
                    bool negate = false) {
  // Output region offset for source lanes 0..3 (lanes 4..7 share the region
  // of lane - 4).
  constexpr int kRegionOffset[4] = {0, 16, 8, 24};
  const int groups = N / 8;
  for (int g = 0; g < groups; ++g) {
    const int32_t* src = input + g * 8;
    int32_t* dst = output + g * 32;
    for (int lane = 0; lane < 4; ++lane) {
      int32_t* region = dst + kRegionOffset[lane];
      for (int rep = 0; rep < 4; ++rep) {
        region[2 * rep] = src[lane];
        region[2 * rep + 1] = src[lane + 4];
      }
    }
  }
  if (negate) {
    const int total = N * 4;
    for (int i = 0; i < total; ++i) {
      output[i] = -output[i];
    }
  }
}
| 110 | |
// Runs the strip-mined output pipeline (without bias addition) in place on
// registers. In order: multiply `result` by `mult` (vdmulh), shift by `shft`
// (vsha), add `output_offset`, then clamp to [output_min, output_max].
//
//   result         register operand holding the values; overwritten.
//   mult, shft     register operands passed to the multiply and shift.
//   output_offset  scalar operand added after the shift.
//   output_min     scalar lower clamp bound.
//   output_max     scalar upper clamp bound.
//
// The clamp uses the macro's own output_min / output_max arguments, so the
// expansion does not depend on any identifiers in the caller's scope.
#define INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(result, mult, shft, output_min, \
                                              output_max, output_offset)      \
  {                                                                           \
    vdmulh_w_rn_vv_m(result, result, mult);                                   \
    vsha_w_r_vv_m(result, result, shft);                                      \
    vadd_w_vx_m(result, result, output_offset);                               \
    vmax_w_vx_m(result, result, output_min);                                  \
    vmin_w_vx_m(result, result, output_max);                                  \
  }
| 122 | |
// Run output pipeline on int32 accumulators in [v48-v55] and store results
// in v48 and v52. Clobbers [v48-v55] as well as bias_reg, mult_reg and
// shift_reg.
//
// Steps, applied to both the v48 and v52 groups:
//   1. vcget(v48), then load bias_reg / mult_reg / shift_reg from the `bias`,
//      `mult` and `shft` pointers.
//   2. Add the bias.
//   3. Multiply by mult_reg (vdmulh) and shift by shift_reg (vsha).
//   4. Add `output_offset`.
//   5. Clamp to [output_min, output_max].
//   6. vsraqs_b_vx with a shift of 0 (presumably the saturating narrow to
//      int8 -- confirm against the Kelvin ISA reference).
//
// The clamp has to come after the offset add: the bounds are expressed in the
// quantized output domain, not in raw accumulator units, and both bounds apply
// to both register groups.
#define INT32_TO_INT8_OUTPUT_PIPELINE(bias, mult, shft, output_min,        \
                                      output_max, output_offset, bias_reg, \
                                      mult_reg, shift_reg)                 \
  {                                                                        \
    vcget(v48);                                                            \
    vld_w_x_m(bias_reg, bias);                                             \
    vld_w_x_m(mult_reg, mult);                                             \
    vld_w_x_m(shift_reg, shft);                                            \
    vadd_w_vv_m(v48, v48, bias_reg);                                       \
    vadd_w_vv_m(v52, v52, bias_reg);                                       \
    vdmulh_w_r_vv_m(v48, v48, mult_reg);                                   \
    vdmulh_w_r_vv_m(v52, v52, mult_reg);                                   \
    vsha_w_r_vv_m(v48, v48, shift_reg);                                    \
    vsha_w_r_vv_m(v52, v52, shift_reg);                                    \
    vadd_w_vx_m(v48, v48, output_offset);                                  \
    vadd_w_vx_m(v52, v52, output_offset);                                  \
    vmax_w_vx_m(v48, v48, output_min);                                     \
    vmax_w_vx_m(v52, v52, output_min);                                     \
    vmin_w_vx_m(v48, v48, output_max);                                     \
    vmin_w_vx_m(v52, v52, output_max);                                     \
    vsraqs_b_vx(v48, v48, 0);                                              \
    vsraqs_b_vx(v52, v52, 0);                                              \
  }
| 146 | } // namespace kelvin::opt |
| 147 | |
| 148 | #endif // TFLM_OPT_CONV_UTIL_H_ |