/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "iree/builtins/ukernel/common.h"

#include <math.h>
#include <riscv_vector.h>
//===----------------------------------------------------------------------===//
// Helpers for defining generic implementations of elementwise functions.
// The entrypoint is dispatched based on an opcode, since that affords the
// best code-size tradeoff options.
//===----------------------------------------------------------------------===//

// Opcodes for generic functions operating on 32-bit operands and result.
// Since the outer dispatcher only differentiates based on width, all other
// type specificity is carried by the opcode.
// Binary opcodes are named "X32B" and unary opcodes "X32U".
// The initial list is sorted, and extensions are encouraged to keep it
// sorted, but opcode values must remain numerically stable, so the list is
// not expected to stay sorted over time.
typedef enum {
  IREE_UK_X32B_ADDF = 0,
  IREE_UK_X32B_ADDI = 1,
  IREE_UK_X32B_ANDI = 2,
  IREE_UK_X32B_DIVF = 3,
  IREE_UK_X32B_DIVSI = 4,
  IREE_UK_X32B_DIVUI = 5,
  IREE_UK_X32B_MULF = 6,
  IREE_UK_X32B_MULI = 7,
  IREE_UK_X32B_ORI = 8,
  IREE_UK_X32B_SHLI = 9,
  IREE_UK_X32B_SHRSI = 10,
  IREE_UK_X32B_SHRUI = 11,
  IREE_UK_X32B_SUBF = 12,
  IREE_UK_X32B_SUBI = 13,
  IREE_UK_X32B_XORI = 14,
} iree_uk_x32b_opcode_t;

typedef enum {
  IREE_UK_X32B_UI = 0,  // unsigned integer
  IREE_UK_X32B_SI = 1,  // signed integer
  IREE_UK_X32B_NA = 2,  // not available in RVV
} iree_uk_x32b_opcode_type_t;

typedef enum {
  IREE_UK_X32U_ABSF,
  IREE_UK_X32U_CEILF,
  IREE_UK_X32U_CTLZ,
  IREE_UK_X32U_EXPF,
  IREE_UK_X32U_FLOORF,
  IREE_UK_X32U_LOGF,
  IREE_UK_X32U_NEGF,
  IREE_UK_X32U_RSQRTF,
} iree_uk_x32u_opcode_t;

// Macros to access various typed, dereferenced pointers.
#define ASF32(ptr) (*(float*)(ptr))
#define ASUI32(ptr) (*(iree_uk_uint32_t*)(ptr))
#define ASSI32(ptr) (*(iree_uk_int32_t*)(ptr))

//===----------------------------------------------------------------------===//
// Implementation macros.
//===----------------------------------------------------------------------===//

// Defines a generic "dispatched" implementation via opcode_t by invoking
// the function iree_uk_{category}_2d.
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category)      \
  IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                   \
      const dtype* lhs, iree_uk_ssize_t lhs_offset,                        \
      iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,            \
      const dtype* rhs, iree_uk_ssize_t rhs_offset,                        \
      iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,            \
      dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,             \
      iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,            \
      iree_uk_ssize_t size0, iree_uk_ssize_t size1) {                      \
    return iree_uk_##category##_2d(opcode_t, lhs, lhs_offset, lhs_stride0, \
                                   lhs_stride1, rhs, rhs_offset,           \
                                   rhs_stride0, rhs_stride1, out,          \
                                   out_offset, out_stride0, out_stride1,   \
                                   size0, size1);                          \
  }

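// Example (illustrative only, not an exhaustive list): instantiating the
// macro above for the ADDF opcode defines the exported entry point
// iree_uk_x32b_addf_2d, which forwards to iree_uk_x32b_2d below with
// IREE_UK_X32B_ADDF baked in. The instantiation is expected to mirror a
// DECLARE_UKERNEL_BINARY_2D declaration in the corresponding header:
//
//   DISPATCH_UKERNEL_BINARY_2D(addf, IREE_UK_X32B_ADDF, iree_uk_uint32_t,
//                              x32b)
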
// Defines a generic "dispatched" implementation via opcode_t by invoking
// the function iree_uk_generic_{category}_2d.
// Corresponds to the header macro DECLARE_UKERNEL_UNARY_2D.
#define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category)          \
  IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                      \
      const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
      iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out,                \
      iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0,                \
      iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0,                     \
      iree_uk_ssize_t size1) {                                                \
    return iree_uk_generic_##category##_2d(                                   \
        opcode_t, in, in_offset, in_stride0, in_stride1, out, out_offset,     \
        out_stride0, out_stride1, size0, size1);                              \
  }

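// Example (illustrative only): instantiating the unary macro for the ABSF
// opcode defines the exported entry point iree_uk_x32u_absf_2d, which
// forwards to iree_uk_generic_x32u_2d below with IREE_UK_X32U_ABSF baked in:
//
//   DISPATCH_UKERNEL_UNARY_2D(absf, IREE_UK_X32U_ABSF, iree_uk_uint32_t,
//                             x32u)
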
//===----------------------------------------------------------------------===//
// Internal helpers.
//===----------------------------------------------------------------------===//

static iree_uk_x32b_opcode_type_t get_iree_uk_x32b_op_type(
    iree_uk_x32b_opcode_t opcode) {
  switch (opcode) {
    case IREE_UK_X32B_ADDI:
    case IREE_UK_X32B_ANDI:
    case IREE_UK_X32B_DIVUI:
    case IREE_UK_X32B_MULI:
    case IREE_UK_X32B_ORI:
    case IREE_UK_X32B_SHLI:
    case IREE_UK_X32B_SHRUI:
    case IREE_UK_X32B_XORI:
    case IREE_UK_X32B_SUBI:
      return IREE_UK_X32B_UI;
    case IREE_UK_X32B_DIVSI:
      return IREE_UK_X32B_SI;
    default:
      return IREE_UK_X32B_NA;
  }
}

// Computes |vl| elements of an x32b opcode using RVV strided loads and
// stores; strides are in bytes. On unsupported opcodes, sets |*result_code|
// to a non-zero value (but does not touch it otherwise).
static void iree_uk_rvv_x32b_op(iree_uk_x32b_opcode_t opcode, int* result_code,
                                const iree_uk_uint32_t* lhs,
                                iree_uk_ssize_t lhs_stride,
                                const iree_uk_uint32_t* rhs,
                                iree_uk_ssize_t rhs_stride,
                                iree_uk_uint32_t* out,
                                iree_uk_ssize_t out_stride, size_t vl) {
  iree_uk_x32b_opcode_type_t op_type = get_iree_uk_x32b_op_type(opcode);
  if (op_type == IREE_UK_X32B_UI) {
    vuint32m8_t vx = vlse32_v_u32m8(lhs, lhs_stride, vl);  // strided load
    vuint32m8_t vy = vlse32_v_u32m8(rhs, rhs_stride, vl);  // strided load
    switch (opcode) {
      case IREE_UK_X32B_ADDI:
        vx = vadd(vx, vy, vl);
        break;
      case IREE_UK_X32B_ANDI:
        vx = vand(vx, vy, vl);
        break;
      case IREE_UK_X32B_DIVUI:
        vx = vdivu(vx, vy, vl);
        break;
      case IREE_UK_X32B_MULI:
        vx = vmul(vx, vy, vl);
        break;
      case IREE_UK_X32B_ORI:
        vx = vor(vx, vy, vl);
        break;
      case IREE_UK_X32B_SHLI:
        vx = vsll(vx, vy, vl);
        break;
      case IREE_UK_X32B_SHRUI:
        vx = vsrl(vx, vy, vl);
        break;
      case IREE_UK_X32B_XORI:
        vx = vxor(vx, vy, vl);
        break;
      case IREE_UK_X32B_SUBI:
        vx = vsub(vx, vy, vl);
        break;
      default:
        *result_code = 1;
    }
    vsse32(out, out_stride, vx, vl);  // strided store
  } else if (op_type == IREE_UK_X32B_SI) {
    vint32m8_t vx =
        vlse32_v_i32m8((iree_uk_int32_t*)lhs, lhs_stride, vl);  // strided load
    vint32m8_t vy =
        vlse32_v_i32m8((iree_uk_int32_t*)rhs, rhs_stride, vl);  // strided load
    switch (opcode) {
      case IREE_UK_X32B_DIVSI:
        vx = vdiv(vx, vy, vl);
        break;
      default:
        *result_code = 1;
    }
    vsse32((iree_uk_int32_t*)out, out_stride, vx, vl);  // strided store
  } else {
    *result_code = 1;
  }
}

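// For orientation (hypothetical numbers): a single iree_uk_rvv_x32b_op call
// with vl = 4 and lhs_stride = 8 bytes gathers lhs[0], lhs[2], lhs[4], and
// lhs[6] (every second 32-bit element), applies the selected op lane-wise,
// and scatters the 4 results back through out at out_stride-byte steps.
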
// Computes a single element of an x32b opcode. On error, should set
// |*result_code| to a non-zero value (but should not touch it otherwise).
static void iree_uk_generic_x32b_op(iree_uk_x32b_opcode_t opcode,
                                    int* result_code,
                                    const iree_uk_uint32_t* lhs,
                                    const iree_uk_uint32_t* rhs,
                                    iree_uk_uint32_t* out) {
  switch (opcode) {
    case IREE_UK_X32B_ADDF:
      ASF32(out) = ASF32(lhs) + ASF32(rhs);
      return;
    case IREE_UK_X32B_ADDI:
      ASUI32(out) = ASUI32(lhs) + ASUI32(rhs);
      return;
    case IREE_UK_X32B_ANDI:
      ASUI32(out) = ASUI32(lhs) & ASUI32(rhs);
      return;
    case IREE_UK_X32B_DIVF:
      ASF32(out) = ASF32(lhs) / ASF32(rhs);
      return;
    case IREE_UK_X32B_DIVSI:
      ASSI32(out) = ASSI32(lhs) / ASSI32(rhs);
      return;
    case IREE_UK_X32B_DIVUI:
      ASUI32(out) = ASUI32(lhs) / ASUI32(rhs);
      return;
    case IREE_UK_X32B_MULF:
      ASF32(out) = ASF32(lhs) * ASF32(rhs);
      return;
    case IREE_UK_X32B_MULI:
      ASUI32(out) = ASUI32(lhs) * ASUI32(rhs);
      return;
    case IREE_UK_X32B_ORI:
      ASUI32(out) = ASUI32(lhs) | ASUI32(rhs);
      return;
    case IREE_UK_X32B_SHLI:
      ASUI32(out) = ASUI32(lhs) << ASUI32(rhs);
      return;
    case IREE_UK_X32B_SHRSI:
      ASSI32(out) = ASSI32(lhs) >> ASSI32(rhs);
      return;
    case IREE_UK_X32B_SHRUI:
      ASUI32(out) = ASUI32(lhs) >> ASUI32(rhs);
      return;
    case IREE_UK_X32B_XORI:
      ASUI32(out) = ASUI32(lhs) ^ ASUI32(rhs);
      return;
    case IREE_UK_X32B_SUBF:
      ASF32(out) = ASF32(lhs) - ASF32(rhs);
      return;
    case IREE_UK_X32B_SUBI:
      ASUI32(out) = ASUI32(lhs) - ASUI32(rhs);
      return;
    default:
      *result_code = 1;
  }
}

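// Note that the float opcodes (ADDF, SUBF, MULF, DIVF) and SHRSI are
// classified IREE_UK_X32B_NA by get_iree_uk_x32b_op_type above, so they are
// always handled by this scalar path rather than by iree_uk_rvv_x32b_op.
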
// Computes a single element of an x32u opcode. Most are float ops. On error,
// should set |*result_code| to a non-zero value (but should not touch it
// otherwise).
static void iree_uk_generic_x32u_op(iree_uk_x32u_opcode_t opcode,
                                    int* result_code,
                                    const iree_uk_uint32_t* in,
                                    iree_uk_uint32_t* out) {
  switch (opcode) {
    case IREE_UK_X32U_ABSF:
      ASF32(out) = fabsf(ASF32(in));
      return;
    case IREE_UK_X32U_CEILF:
      ASF32(out) = ceilf(ASF32(in));
      return;
    case IREE_UK_X32U_CTLZ:
      ASUI32(out) = iree_uk_count_leading_zeros_u32(ASUI32(in));
      return;
    case IREE_UK_X32U_EXPF:
      ASF32(out) = expf(ASF32(in));
      return;
    case IREE_UK_X32U_FLOORF:
      ASF32(out) = floorf(ASF32(in));
      return;
    case IREE_UK_X32U_LOGF:
      ASF32(out) = logf(ASF32(in));
      return;
    case IREE_UK_X32U_NEGF:
      ASF32(out) = -ASF32(in);
      return;
    case IREE_UK_X32U_RSQRTF:
      ASF32(out) = 1.0f / sqrtf(ASF32(in));
      return;
    default:
      *result_code = 1;
  }
}

//===----------------------------------------------------------------------===//
// Opcode dispatch entry points.
//===----------------------------------------------------------------------===//

// 32-bit binary kernels.
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_x32b_2d(
    iree_uk_x32b_opcode_t opcode,
    // LHS.
    const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
    iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
    // RHS.
    const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
    iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
    // OUT.
    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
    // Sizes.
    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
  int result_code = 0;

  if (get_iree_uk_x32b_op_type(opcode) != IREE_UK_X32B_NA) {
    // |vl| is assigned by vsetvl_e32m8 in the loop body before the first
    // increment reads it.
    size_t vl;
    // Make the most of vectorization by iterating along the larger dimension.
    if (size0 < size1) {
      for (iree_uk_ssize_t i = 0; i < size0; ++i) {
        for (iree_uk_ssize_t j = 0; j < size1; j += vl) {
          vl = vsetvl_e32m8(size1 - j);
          iree_uk_rvv_x32b_op(opcode, &result_code,
                              &lhs[i * lhs_stride0 + j * lhs_stride1],
                              lhs_stride1 * sizeof(uint32_t),
                              &rhs[i * rhs_stride0 + j * rhs_stride1],
                              rhs_stride1 * sizeof(uint32_t),
                              &out[i * out_stride0 + j * out_stride1],
                              out_stride1 * sizeof(uint32_t), vl);
        }
      }
    } else {
      for (iree_uk_ssize_t j = 0; j < size1; ++j) {
        for (iree_uk_ssize_t i = 0; i < size0; i += vl) {
          vl = vsetvl_e32m8(size0 - i);
          iree_uk_rvv_x32b_op(opcode, &result_code,
                              &lhs[i * lhs_stride0 + j * lhs_stride1],
                              lhs_stride0 * sizeof(uint32_t),
                              &rhs[i * rhs_stride0 + j * rhs_stride1],
                              rhs_stride0 * sizeof(uint32_t),
                              &out[i * out_stride0 + j * out_stride1],
                              out_stride0 * sizeof(uint32_t), vl);
        }
      }
    }
  } else {
    for (iree_uk_ssize_t i = 0; i < size0; ++i) {
      for (iree_uk_ssize_t j = 0; j < size1; ++j) {
        iree_uk_generic_x32b_op(opcode, &result_code,
                                &lhs[i * lhs_stride0 + j * lhs_stride1],
                                &rhs[i * rhs_stride0 + j * rhs_stride1],
                                &out[i * out_stride0 + j * out_stride1]);
      }
    }
  }
  return result_code;
}

// Generic 32-bit unary kernels.
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d(
    iree_uk_x32u_opcode_t opcode,
    // IN.
    const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
    iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
    // OUT.
    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
    // Sizes.
    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
  int result_code = 0;
  // TODO: Manually unroll to x4 to trigger vectorization.
  for (iree_uk_ssize_t i = 0; i < size0; ++i) {
    for (iree_uk_ssize_t j = 0; j < size1; ++j) {
      iree_uk_generic_x32u_op(opcode, &result_code,
                              &in[i * in_stride0 + j * in_stride1],
                              &out[i * out_stride0 + j * out_stride1]);
    }
  }
  return result_code;
}
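
// Usage sketch (hypothetical caller, not part of this file): with the ADDF
// instantiation shown earlier, adding two contiguous 4x8 row-major f32
// buffers would look like:
//
//   float lhs[32], rhs[32], out[32];  // filled by the caller
//   int rc = iree_uk_x32b_addf_2d(
//       (const iree_uk_uint32_t*)lhs, /*lhs_offset=*/0, /*lhs_stride0=*/8,
//       /*lhs_stride1=*/1, (const iree_uk_uint32_t*)rhs, 0, 8, 1,
//       (iree_uk_uint32_t*)out, 0, 8, 1, /*size0=*/4, /*size1=*/8);
//   // rc == 0 on success; non-zero if the opcode was unsupported.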