/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "iree/builtins/ukernel/common.h"

#include <math.h>
#include <riscv_vector.h>

//===----------------------------------------------------------------------===//
// Helpers for defining generic implementations of elementwise functions.
// The entry point is dispatched based on an opcode, since that affords the
// best code-size tradeoff.
//===----------------------------------------------------------------------===//

// Opcodes for generic functions operating on 32-bit operands and results.
// Since the outer dispatcher only differentiates based on width, all other
// type specificity is carried by the opcode.
// Binary opcodes are named "X32B" and unary opcodes "X32U".
// The initial list was sorted, and extensions are encouraged to keep it
// sorted, but each opcode's numeric value must stay stable, so the list is
// not expected to remain sorted over time.
typedef enum {
  IREE_UK_X32B_ADDF = 0,
  IREE_UK_X32B_ADDI = 1,
  IREE_UK_X32B_ANDI = 2,
  IREE_UK_X32B_DIVF = 3,
  IREE_UK_X32B_DIVSI = 4,
  IREE_UK_X32B_DIVUI = 5,
  IREE_UK_X32B_MULF = 6,
  IREE_UK_X32B_MULI = 7,
  IREE_UK_X32B_ORI = 8,
  IREE_UK_X32B_SHLI = 9,
  IREE_UK_X32B_SHRSI = 10,
  IREE_UK_X32B_SHRUI = 11,
  IREE_UK_X32B_SUBF = 12,
  IREE_UK_X32B_SUBI = 13,
  IREE_UK_X32B_XORI = 14,
} iree_uk_x32b_opcode_t;

typedef enum {
  IREE_UK_X32B_UI = 0,  // unsigned integer
  IREE_UK_X32B_SI = 1,  // signed integer
  IREE_UK_X32B_NA = 2,  // not available in RVV
} iree_uk_x32b_opcode_type_t;

typedef enum {
  IREE_UK_X32U_ABSF,
  IREE_UK_X32U_CEILF,
  IREE_UK_X32U_CTLZ,
  IREE_UK_X32U_EXPF,
  IREE_UK_X32U_FLOORF,
  IREE_UK_X32U_LOGF,
  IREE_UK_X32U_NEGF,
  IREE_UK_X32U_RSQRTF,
} iree_uk_x32u_opcode_t;

// Macros to access various typed, dereferenced pointers.
#define ASF32(ptr) *((float*)ptr)
#define ASUI32(ptr) *((iree_uk_uint32_t*)ptr)
#define ASSI32(ptr) *((iree_uk_int32_t*)ptr)

//===----------------------------------------------------------------------===//
// Implementation macros.
//===----------------------------------------------------------------------===//

// Defines a generic "dispatched" implementation via opcode_t by invoking
// the function iree_uk_{category}_2d.
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category)         \
  IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                      \
      const dtype* lhs, iree_uk_ssize_t lhs_offset,                           \
      iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,               \
      const dtype* rhs, iree_uk_ssize_t rhs_offset,                           \
      iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,               \
      dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,                \
      iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,               \
      iree_uk_ssize_t size0, iree_uk_ssize_t size1) {                         \
    return iree_uk_##category##_2d(opcode_t, lhs, lhs_offset, lhs_stride0,    \
                                   lhs_stride1, rhs, rhs_offset, rhs_stride0, \
                                   rhs_stride1, out, out_offset, out_stride0, \
                                   out_stride1, size0, size1);                \
  }
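
// Illustrative expansion (a sketch for exposition; whether and where this
// macro is instantiated is outside this excerpt):
//   DISPATCH_UKERNEL_BINARY_2D(addf, IREE_UK_X32B_ADDF, iree_uk_uint32_t, x32b)
// defines the exported entry point
//   int iree_uk_x32b_addf_2d(const iree_uk_uint32_t* lhs, ...);
// whose body simply forwards all operands to iree_uk_x32b_2d with
// IREE_UK_X32B_ADDF as the opcode argument.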

// Defines a generic "dispatched" implementation via opcode_t by invoking
// the function iree_uk_generic_{category}_2d.
// Corresponds to the header macro DECLARE_UKERNEL_UNARY_2D.
#define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category)          \
  IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                      \
      const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
      iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out,                \
      iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0,                \
      iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0,                     \
      iree_uk_ssize_t size1) {                                                \
    return iree_uk_generic_##category##_2d(                                   \
        opcode_t, in, in_offset, in_stride0, in_stride1, out, out_offset,     \
        out_stride0, out_stride1, size0, size1);                              \
  }
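
// Likewise, a hypothetical instantiation (for exposition only):
//   DISPATCH_UKERNEL_UNARY_2D(absf, IREE_UK_X32U_ABSF, iree_uk_uint32_t, x32u)
// defines iree_uk_x32u_absf_2d, forwarding to iree_uk_generic_x32u_2d with
// IREE_UK_X32U_ABSF as the opcode argument.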

//===----------------------------------------------------------------------===//
// Internal helpers.
//===----------------------------------------------------------------------===//

static iree_uk_x32b_opcode_type_t get_iree_uk_x32b_op_type(
    iree_uk_x32b_opcode_t opcode) {
  switch (opcode) {
    case IREE_UK_X32B_ADDI:
    case IREE_UK_X32B_ANDI:
    case IREE_UK_X32B_DIVUI:
    case IREE_UK_X32B_MULI:
    case IREE_UK_X32B_ORI:
    case IREE_UK_X32B_SHLI:
    case IREE_UK_X32B_SHRUI:
    case IREE_UK_X32B_XORI:
    case IREE_UK_X32B_SUBI:
      return IREE_UK_X32B_UI;
    case IREE_UK_X32B_DIVSI:
      return IREE_UK_X32B_SI;
    default:
      return IREE_UK_X32B_NA;
  }
}

// Computes |vl| elements of an x32b opcode using RVV strided loads and
// stores. On error, sets |*result_code| to a non-zero value (but does not
// touch it otherwise).
static void iree_uk_rvv_x32b_op(iree_uk_x32b_opcode_t opcode, int* result_code,
                                const iree_uk_uint32_t* lhs,
                                iree_uk_ssize_t lhs_stride,
                                const iree_uk_uint32_t* rhs,
                                iree_uk_ssize_t rhs_stride,
                                iree_uk_uint32_t* out,
                                iree_uk_ssize_t out_stride, size_t vl) {
  iree_uk_x32b_opcode_type_t op_type = get_iree_uk_x32b_op_type(opcode);
  if (op_type == IREE_UK_X32B_UI) {
    vuint32m8_t vx = vlse32_v_u32m8(lhs, lhs_stride, vl);  // load
    vuint32m8_t vy = vlse32_v_u32m8(rhs, rhs_stride, vl);  // load
    switch (opcode) {
      case IREE_UK_X32B_ADDI:
        vx = vadd(vx, vy, vl);
        break;
      case IREE_UK_X32B_ANDI:
        vx = vand(vx, vy, vl);
        break;
      case IREE_UK_X32B_DIVUI:
        vx = vdivu(vx, vy, vl);
        break;
      case IREE_UK_X32B_MULI:
        vx = vmul(vx, vy, vl);
        break;
      case IREE_UK_X32B_ORI:
        vx = vor(vx, vy, vl);
        break;
      case IREE_UK_X32B_SHLI:
        vx = vsll(vx, vy, vl);
        break;
      case IREE_UK_X32B_SHRUI:
        vx = vsrl(vx, vy, vl);
        break;
      case IREE_UK_X32B_XORI:
        vx = vxor(vx, vy, vl);
        break;
      case IREE_UK_X32B_SUBI:
        vx = vsub(vx, vy, vl);
        break;
      default:
        *result_code = 1;
    }
    vsse32(out, out_stride, vx, vl);  // store
  } else if (op_type == IREE_UK_X32B_SI) {
    vint32m8_t vx =
        vlse32_v_i32m8((iree_uk_int32_t*)lhs, lhs_stride, vl);  // load
    vint32m8_t vy =
        vlse32_v_i32m8((iree_uk_int32_t*)rhs, rhs_stride, vl);  // load
    switch (opcode) {
      case IREE_UK_X32B_DIVSI:
        vx = vdiv(vx, vy, vl);
        break;
      default:
        *result_code = 1;
    }
    vsse32((iree_uk_int32_t*)out, out_stride, vx, vl);  // store
  } else {
    *result_code = 1;
  }
}
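
// For reference, a scalar sketch of what one call to the RVV helper above
// computes (strides are in bytes here, as passed by iree_uk_x32b_2d below;
// OP stands for the operation selected by |opcode|):
//   for (size_t k = 0; k < vl; ++k) {
//     uint32_t x = *(const uint32_t*)((const char*)lhs + k * lhs_stride);
//     uint32_t y = *(const uint32_t*)((const char*)rhs + k * rhs_stride);
//     *(uint32_t*)((char*)out + k * out_stride) = x OP y;
//   }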

// Computes a single element of an x32b opcode. On error, should set
// |*result_code| to a non-zero value (but should not touch it otherwise).
static void iree_uk_generic_x32b_op(iree_uk_x32b_opcode_t opcode,
                                    int* result_code,
                                    const iree_uk_uint32_t* lhs,
                                    const iree_uk_uint32_t* rhs,
                                    iree_uk_uint32_t* out) {
  switch (opcode) {
    case IREE_UK_X32B_ADDF:
      ASF32(out) = ASF32(lhs) + ASF32(rhs);
      return;
    case IREE_UK_X32B_ADDI:
      ASUI32(out) = ASUI32(lhs) + ASUI32(rhs);
      return;
    case IREE_UK_X32B_ANDI:
      ASUI32(out) = ASUI32(lhs) & ASUI32(rhs);
      return;
    case IREE_UK_X32B_DIVF:
      ASF32(out) = ASF32(lhs) / ASF32(rhs);
      return;
    case IREE_UK_X32B_DIVSI:
      ASSI32(out) = ASSI32(lhs) / ASSI32(rhs);
      return;
    case IREE_UK_X32B_DIVUI:
      ASUI32(out) = ASUI32(lhs) / ASUI32(rhs);
      return;
    case IREE_UK_X32B_MULF:
      ASF32(out) = ASF32(lhs) * ASF32(rhs);
      return;
    case IREE_UK_X32B_MULI:
      ASUI32(out) = ASUI32(lhs) * ASUI32(rhs);
      return;
    case IREE_UK_X32B_ORI:
      ASUI32(out) = ASUI32(lhs) | ASUI32(rhs);
      return;
    case IREE_UK_X32B_SHLI:
      ASUI32(out) = ASUI32(lhs) << ASUI32(rhs);
      return;
    case IREE_UK_X32B_SHRSI:
      ASSI32(out) = ASSI32(lhs) >> ASSI32(rhs);
      return;
    case IREE_UK_X32B_SHRUI:
      ASUI32(out) = ASUI32(lhs) >> ASUI32(rhs);
      return;
    case IREE_UK_X32B_XORI:
      ASUI32(out) = ASUI32(lhs) ^ ASUI32(rhs);
      return;
    case IREE_UK_X32B_SUBF:
      ASF32(out) = ASF32(lhs) - ASF32(rhs);
      return;
    case IREE_UK_X32B_SUBI:
      ASUI32(out) = ASUI32(lhs) - ASUI32(rhs);
      return;
    default:
      *result_code = 1;
  }
}

// Computes a single element of an x32u opcode. Most are float ops. On error,
// should set |*result_code| to a non-zero value (but should not touch it
// otherwise).
static void iree_uk_generic_x32u_op(iree_uk_x32u_opcode_t opcode,
                                    int* result_code,
                                    const iree_uk_uint32_t* in,
                                    iree_uk_uint32_t* out) {
  switch (opcode) {
    case IREE_UK_X32U_ABSF:
      ASF32(out) = fabsf(ASF32(in));
      return;
    case IREE_UK_X32U_CEILF:
      ASF32(out) = ceilf(ASF32(in));
      return;
    case IREE_UK_X32U_CTLZ:
      ASUI32(out) = iree_uk_count_leading_zeros_u32(ASUI32(in));
      return;
    case IREE_UK_X32U_EXPF:
      ASF32(out) = expf(ASF32(in));
      return;
    case IREE_UK_X32U_FLOORF:
      ASF32(out) = floorf(ASF32(in));
      return;
    case IREE_UK_X32U_LOGF:
      ASF32(out) = logf(ASF32(in));
      return;
    case IREE_UK_X32U_NEGF:
      ASF32(out) = -ASF32(in);
      return;
    case IREE_UK_X32U_RSQRTF:
      ASF32(out) = 1.0f / sqrtf(ASF32(in));
      return;
    default:
      *result_code = 1;
  }
}

//===----------------------------------------------------------------------===//
// Opcode dispatch entry points.
//===----------------------------------------------------------------------===//

// 32-bit binary kernels.
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_x32b_2d(
    iree_uk_x32b_opcode_t opcode,
    // LHS.
    const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
    iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
    // RHS.
    const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
    iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
    // OUT.
    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
    // Sizes.
    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
  int result_code = 0;

  if (get_iree_uk_x32b_op_type(opcode) != IREE_UK_X32B_NA) {
    size_t vl = 0;
    // Make the most of vectorization by vectorizing along the larger
    // dimension.
    if (size0 < size1) {
      for (iree_uk_ssize_t i = 0; i < size0; ++i) {
        for (iree_uk_ssize_t j = 0; j < size1; j += vl) {
          vl = vsetvl_e32m8(size1 - j);
          iree_uk_rvv_x32b_op(opcode, &result_code,
                              &lhs[i * lhs_stride0 + j * lhs_stride1],
                              lhs_stride1 * sizeof(uint32_t),
                              &rhs[i * rhs_stride0 + j * rhs_stride1],
                              rhs_stride1 * sizeof(uint32_t),
                              &out[i * out_stride0 + j * out_stride1],
                              out_stride1 * sizeof(uint32_t), vl);
        }
      }
    } else {
      for (iree_uk_ssize_t j = 0; j < size1; ++j) {
        for (iree_uk_ssize_t i = 0; i < size0; i += vl) {
          vl = vsetvl_e32m8(size0 - i);
          iree_uk_rvv_x32b_op(opcode, &result_code,
                              &lhs[i * lhs_stride0 + j * lhs_stride1],
                              lhs_stride0 * sizeof(uint32_t),
                              &rhs[i * rhs_stride0 + j * rhs_stride1],
                              rhs_stride0 * sizeof(uint32_t),
                              &out[i * out_stride0 + j * out_stride1],
                              out_stride0 * sizeof(uint32_t), vl);
        }
      }
    }
  } else {
    for (iree_uk_ssize_t i = 0; i < size0; ++i) {
      for (iree_uk_ssize_t j = 0; j < size1; ++j) {
        iree_uk_generic_x32b_op(opcode, &result_code,
                                &lhs[i * lhs_stride0 + j * lhs_stride1],
                                &rhs[i * rhs_stride0 + j * rhs_stride1],
                                &out[i * out_stride0 + j * out_stride1]);
      }
    }
  }
  return result_code;
}
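
// Strip-mining note: vsetvl_e32m8(n) returns how many 32-bit elements (at
// LMUL=8) the hardware will process in this pass, at most n, and the loop
// advances by that amount, so remainders need no scalar epilogue. A minimal
// usage sketch (hypothetical caller and buffers, for exposition):
//   // Row-major 8x16 i32 add: stride0 = 16 elements, stride1 = 1.
//   int rc = iree_uk_x32b_2d(IREE_UK_X32B_ADDI, lhs, /*lhs_offset=*/0,
//                            /*lhs_stride0=*/16, /*lhs_stride1=*/1, rhs, 0, 16,
//                            1, out, 0, 16, 1, /*size0=*/8, /*size1=*/16);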

// Generic 32-bit unary kernels.
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d(
    iree_uk_x32u_opcode_t opcode,
    // IN.
    const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
    iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
    // OUT.
    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
    // Sizes.
    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
  int result_code = 0;
  // TODO: Manually unroll to x4 to trigger vectorization.
  for (iree_uk_ssize_t i = 0; i < size0; ++i) {
    for (iree_uk_ssize_t j = 0; j < size1; ++j) {
      iree_uk_generic_x32u_op(opcode, &result_code,
                              &in[i * in_stride0 + j * in_stride1],
                              &out[i * out_stride0 + j * out_stride1]);
    }
  }
  return result_code;
}
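
// Usage sketch (hypothetical buffers, for exposition): applying a unary op
// over a row-major 4x8 f32 buffer viewed as 32-bit words:
//   int rc = iree_uk_generic_x32u_2d(IREE_UK_X32U_ABSF, in, /*in_offset=*/0,
//                                    /*in_stride0=*/8, /*in_stride1=*/1, out,
//                                    /*out_offset=*/0, /*out_stride0=*/8,
//                                    /*out_stride1=*/1, /*size0=*/4,
//                                    /*size1=*/8);
// A zero return code means every element hit a known opcode.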