Use RVV for ukernel
Use RVV intrinsics for int ops in ukernel. This is an intermediate step
for implementing Kelvin in ukernel.
The main work is pretty much done. All available vmvx targets passed.
Tests cases for verifying the correctness will be added next.
Change-Id: Iccb6cce9464233be838bc24a42dd697327fb6b81
diff --git a/vmvx_ukernel/CMakeLists.txt b/vmvx_ukernel/CMakeLists.txt
index eff2b5b..3b9c0e8 100644
--- a/vmvx_ukernel/CMakeLists.txt
+++ b/vmvx_ukernel/CMakeLists.txt
@@ -24,8 +24,8 @@
"${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d_generic.h"
"${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack_generic.h"
SRCS
- "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/elementwise_generic.c"
- "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/elementwise_impl.c.inc"
+ "elementwise.c"
+ "elementwise_impl.c.inc"
"${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d.c"
"${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d_generic.c"
"${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack.c"
diff --git a/vmvx_ukernel/elementwise.c b/vmvx_ukernel/elementwise.c
new file mode 100644
index 0000000..d0e4683
--- /dev/null
+++ b/vmvx_ukernel/elementwise.c
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "iree/builtins/ukernel/elementwise.h"
+
+// Include the implementation helpers.
+#include "vmvx_ukernel/elementwise_impl.c.inc"
+
+DISPATCH_UKERNEL_BINARY_2D(addf, IREE_UK_X32B_ADDF, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(addi, IREE_UK_X32B_ADDI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(andi, IREE_UK_X32B_ANDI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(divf, IREE_UK_X32B_DIVF, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(divsi, IREE_UK_X32B_DIVSI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(divui, IREE_UK_X32B_DIVUI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(mulf, IREE_UK_X32B_MULF, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(muli, IREE_UK_X32B_MULI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(ori, IREE_UK_X32B_ORI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(shli, IREE_UK_X32B_SHLI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(shrsi, IREE_UK_X32B_SHRSI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(shrui, IREE_UK_X32B_SHRUI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(subf, IREE_UK_X32B_SUBF, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(subi, IREE_UK_X32B_SUBI, iree_uk_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(xori, IREE_UKENREL_X32B_XORI, iree_uk_uint32_t,
+ x32b);
+
+DISPATCH_UKERNEL_UNARY_2D(absf, IREE_UK_X32U_ABSF, iree_uk_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(ceilf, IREE_UK_X32U_CEILF, iree_uk_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(ctlz, IREE_UK_X32U_CTLZ, iree_uk_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(expf, IREE_UK_X32U_EXPF, iree_uk_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(floorf, IREE_UK_X32U_FLOORF, iree_uk_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(logf, IREE_UK_X32U_LOGF, iree_uk_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(negf, IREE_UK_X32U_NEGF, iree_uk_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(rsqrtf, IREE_UK_X32U_RSQRTF, iree_uk_uint32_t, x32u);
diff --git a/vmvx_ukernel/elementwise_impl.c.inc b/vmvx_ukernel/elementwise_impl.c.inc
new file mode 100644
index 0000000..baf1efb
--- /dev/null
+++ b/vmvx_ukernel/elementwise_impl.c.inc
@@ -0,0 +1,395 @@
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "iree/builtins/ukernel/common.h"
+
+#include <math.h>
+#include <riscv_vector.h>
+
+//===----------------------------------------------------------------------===//
+// Helpers for defining generic implementations of elementwise functions.
+// Since it affords the best code size tradeoff options, the entrypoint
+// is dispatched based on an opcode.
+//===----------------------------------------------------------------------===//
+
+// Opcodes for generic functions operating on 32-bit operands and result.
+// Since the outer dispatcher only differentiates based on width, all other
+// type specificity is carried by the opcode.
+// Binary opcodes are named "X32B" and unary opcodes "X32U".
+// The initial list was sorted, and it is encouraged to sort extensions, but
+// each opcode must be numerically stable, so the list is not expected to
+// be sorted over time.
+typedef enum {
+ IREE_UK_X32B_ADDF = 0,
+ IREE_UK_X32B_ADDI = 1,
+ IREE_UK_X32B_ANDI = 2,
+ IREE_UK_X32B_DIVF = 3,
+ IREE_UK_X32B_DIVSI = 4,
+ IREE_UK_X32B_DIVUI = 5,
+ IREE_UK_X32B_MULF = 6,
+ IREE_UK_X32B_MULI = 7,
+ IREE_UK_X32B_ORI = 8,
+ IREE_UK_X32B_SHLI = 9,
+ IREE_UK_X32B_SHRSI = 10,
+ IREE_UK_X32B_SHRUI = 11,
+ IREE_UK_X32B_SUBF = 12,
+ IREE_UK_X32B_SUBI = 13,
+ IREE_UKENREL_X32B_XORI = 14,
+} iree_uk_x32b_opcode_t;
+
+typedef enum {
+ IREE_UK_X32B_UI = 0, // unsigned integer
+ IREE_UK_X32B_SI = 1, // signed integer
+ IREE_UK_X32B_NA = 2, // not available in RVV
+} iree_uk_x32b_opcode_type_t;
+
+typedef enum {
+ IREE_UK_X32U_ABSF,
+ IREE_UK_X32U_CEILF,
+ IREE_UK_X32U_CTLZ,
+ IREE_UK_X32U_EXPF,
+ IREE_UK_X32U_FLOORF,
+ IREE_UK_X32U_LOGF,
+ IREE_UK_X32U_NEGF,
+ IREE_UK_X32U_RSQRTF,
+} iree_uk_x32u_opcode_t;
+
+// Macros to access various typed, dereferenced pointers.
+#define ASF32(ptr) *((float*)ptr)
+#define ASUI32(ptr) *((iree_uk_uint32_t*)ptr)
+#define ASSI32(ptr) *((iree_uk_int32_t*)ptr)
+
+//===----------------------------------------------------------------------===//
+// Math helper functions (extracted from base/internal/math.h and adapted
+// to be able to be used standalone).
+//===----------------------------------------------------------------------===//
+static inline int iree_uk_count_leading_zeros_u32(const iree_uk_uint32_t n) {
+#if defined(__GNUC__) || defined(__clang__)
+ // Handle 0 as a special case because __builtin_clz(0) is undefined.
+ if (n == 0) return 32;
+ // Use __builtin_clz, which uses the following instructions:
+ // x86: bsr
+ // ARM64: clz
+ // PPC: cntlzd
+ return (int)__builtin_clz(n);
+#else
+#error No clz for this arch.
+#endif // GCC / CLANG
+}
+
+//===----------------------------------------------------------------------===//
+// Implementation macros.
+//===----------------------------------------------------------------------===//
+
+// Defines a generic "dispatched" implementation via opcode_t by invoking
+// the function iree_uk_generic_{category}_2d.
+// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
+#define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category) \
+ IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
+ const dtype* lhs, iree_uk_ssize_t lhs_offset, \
+ iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1, \
+ const dtype* rhs, iree_uk_ssize_t rhs_offset, \
+ iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1, \
+ dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset, \
+ iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1, \
+ iree_uk_ssize_t size0, iree_uk_ssize_t size1) { \
+ return iree_uk_##category##_2d(opcode_t, lhs, lhs_offset, lhs_stride0, \
+ lhs_stride1, rhs, rhs_offset, rhs_stride0, \
+ rhs_stride1, out, out_offset, out_stride0, \
+ out_stride1, size0, size1); \
+ }
+
+// Defines a generic "dispatched" implementation via opcode_t by invoking
+// the function iree_uk_generic_{category}_2d.
+// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
+#define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category) \
+ IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
+ const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
+ iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out, \
+ iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0, \
+ iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0, \
+ iree_uk_ssize_t size1) { \
+ return iree_uk_generic_##category##_2d( \
+ opcode_t, in, in_offset, in_stride0, in_stride1, out, out_offset, \
+ out_stride0, out_stride1, size0, size1); \
+ }
+
+//===----------------------------------------------------------------------===//
+// Internal helpers.
+//===----------------------------------------------------------------------===//
+
+static iree_uk_x32b_opcode_type_t get_iree_uk_x32b_op_type(
+ iree_uk_x32b_opcode_t opcode) {
+ switch (opcode) {
+ case IREE_UK_X32B_ADDI:
+ case IREE_UK_X32B_ANDI:
+ case IREE_UK_X32B_DIVUI:
+ case IREE_UK_X32B_MULI:
+ case IREE_UK_X32B_ORI:
+ case IREE_UK_X32B_SHLI:
+ case IREE_UK_X32B_SHRUI:
+ case IREE_UKENREL_X32B_XORI:
+ case IREE_UK_X32B_SUBI:
+ return IREE_UK_X32B_UI;
+ case IREE_UK_X32B_DIVSI:
+ return IREE_UK_X32B_SI;
+ default:
+ return IREE_UK_X32B_NA;
+ }
+}
+
+// Computes a single element of an x32b opcode usinbg RVV.
+static void iree_uk_rvv_x32b_op(iree_uk_x32b_opcode_t opcode, int* result_code,
+ const iree_uk_uint32_t* lhs,
+ iree_uk_ssize_t lhs_stride,
+ const iree_uk_uint32_t* rhs,
+ iree_uk_ssize_t rhs_stride,
+ iree_uk_uint32_t* out,
+ iree_uk_ssize_t out_stride, size_t vl) {
+ iree_uk_x32b_opcode_type_t op_type = get_iree_uk_x32b_op_type(opcode);
+ if (op_type == IREE_UK_X32B_UI) {
+ vuint32m8_t vx = vlse32_v_u32m8(lhs, lhs_stride, vl); // load
+ vuint32m8_t vy = vlse32_v_u32m8(rhs, rhs_stride, vl); // load
+ switch (opcode) {
+ case IREE_UK_X32B_ADDI:
+ vx = vadd(vx, vy, vl);
+ break;
+ case IREE_UK_X32B_ANDI:
+ vx = vand(vx, vy, vl);
+ break;
+ case IREE_UK_X32B_DIVUI:
+ vx = vdivu(vx, vy, vl);
+ break;
+ case IREE_UK_X32B_MULI:
+ vx = vmul(vx, vy, vl);
+ break;
+ case IREE_UK_X32B_ORI:
+ vx = vor(vx, vy, vl);
+ break;
+ case IREE_UK_X32B_SHLI:
+ vx = vsll(vx, vy, vl);
+ break;
+ case IREE_UK_X32B_SHRUI:
+ vx = vsrl(vx, vy, vl);
+ break;
+ case IREE_UKENREL_X32B_XORI:
+ vx = vor(vx, vy, vl);
+ break;
+ case IREE_UK_X32B_SUBI:
+ vx = vsub(vx, vy, vl);
+ break;
+ default:
+ *result_code = 1;
+ }
+ vsse32(out, out_stride, vx, vl); // save
+ } else if (op_type == IREE_UK_X32B_SI) {
+ vint32m8_t vx =
+ vlse32_v_i32m8((iree_uk_int32_t*)lhs, lhs_stride, vl); // load
+ vint32m8_t vy =
+ vlse32_v_i32m8((iree_uk_int32_t*)rhs, rhs_stride, vl); // load
+ switch (opcode) {
+ case IREE_UK_X32B_DIVSI:
+ vx = vdiv(vx, vy, vl);
+ break;
+ default:
+ *result_code = 1;
+ }
+ vsse32((iree_uk_int32_t*)out, out_stride, vx, vl); // save
+ } else {
+ *result_code = 1;
+ }
+}
+
+// Computes a single element of an x32b opcode. On error, should set
+// |*result_code| to a non-zero value (but should not touch it otherwise).
+static void iree_uk_generic_x32b_op(iree_uk_x32b_opcode_t opcode,
+ int* result_code,
+ const iree_uk_uint32_t* lhs,
+ const iree_uk_uint32_t* rhs,
+ iree_uk_uint32_t* out) {
+ switch (opcode) {
+ case IREE_UK_X32B_ADDF:
+ ASF32(out) = ASF32(lhs) + ASF32(rhs);
+ return;
+ case IREE_UK_X32B_ADDI:
+ ASUI32(out) = ASUI32(lhs) + ASUI32(rhs);
+ return;
+ case IREE_UK_X32B_ANDI:
+ ASUI32(out) = ASUI32(lhs) & ASUI32(rhs);
+ return;
+ case IREE_UK_X32B_DIVF:
+ ASF32(out) = ASF32(lhs) / ASF32(rhs);
+ return;
+ case IREE_UK_X32B_DIVSI:
+ ASSI32(out) = ASSI32(lhs) / ASSI32(rhs);
+ return;
+ case IREE_UK_X32B_DIVUI:
+ ASUI32(out) = ASUI32(lhs) / ASUI32(rhs);
+ return;
+ case IREE_UK_X32B_MULF:
+ ASF32(out) = ASF32(lhs) * ASF32(rhs);
+ return;
+ case IREE_UK_X32B_MULI:
+ ASUI32(out) = ASUI32(lhs) * ASUI32(rhs);
+ return;
+ case IREE_UK_X32B_ORI:
+ ASUI32(out) = ASUI32(lhs) | ASUI32(rhs);
+ return;
+ case IREE_UK_X32B_SHLI:
+ ASUI32(out) = ASUI32(lhs) << ASUI32(rhs);
+ return;
+ case IREE_UK_X32B_SHRSI:
+ ASSI32(out) = ASSI32(lhs) >> ASSI32(rhs);
+ return;
+ case IREE_UK_X32B_SHRUI:
+ ASUI32(out) = ASUI32(lhs) >> ASUI32(rhs);
+ return;
+ case IREE_UKENREL_X32B_XORI:
+ ASUI32(out) = ASUI32(lhs) ^ ASUI32(rhs);
+ return;
+ case IREE_UK_X32B_SUBF:
+ ASF32(out) = ASF32(lhs) - ASF32(rhs);
+ return;
+ case IREE_UK_X32B_SUBI:
+ ASSI32(out) = ASUI32(lhs) - ASUI32(rhs);
+ return;
+ default:
+ *result_code = 1;
+ }
+}
+
+// Computes a single element of an x32u opcode. Most are float ops. On error,
+// should set |*result_code| to a non-zero value (but should not touch it
+// otherwise).
+static void iree_uk_generic_x32u_op(iree_uk_x32u_opcode_t opcode,
+ int* result_code,
+ const iree_uk_uint32_t* in,
+ iree_uk_uint32_t* out) {
+ switch (opcode) {
+ case IREE_UK_X32U_ABSF:
+ ASF32(out) = fabsf(ASF32(in));
+ return;
+ case IREE_UK_X32U_CEILF:
+ ASF32(out) = ceilf(ASF32(in));
+ return;
+ case IREE_UK_X32U_CTLZ:
+ ASUI32(out) = iree_uk_count_leading_zeros_u32(ASUI32(in));
+ return;
+ case IREE_UK_X32U_EXPF:
+ ASF32(out) = expf(ASF32(in));
+ return;
+ case IREE_UK_X32U_FLOORF:
+ ASF32(out) = floorf(ASF32(in));
+ return;
+ case IREE_UK_X32U_LOGF:
+ ASF32(out) = logf(ASF32(in));
+ return;
+ case IREE_UK_X32U_NEGF:
+ ASF32(out) = -ASF32(in);
+ return;
+ case IREE_UK_X32U_RSQRTF:
+ ASF32(out) = 1.0f / sqrtf(ASF32(in));
+ return;
+ default:
+ *result_code = 1;
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Opcode dispatch entry points.
+//===----------------------------------------------------------------------===//
+
+// 32bit binary kernels.
+IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_x32b_2d(
+ iree_uk_x32b_opcode_t opcode,
+ // LHS.
+ const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
+ iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
+ // RHS
+ const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
+ iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
+ // OUT.
+ iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
+ iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+ // Sizes.
+ iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+ int result_code = 0;
+
+ if (get_iree_uk_x32b_op_type(opcode) != IREE_UK_X32B_NA) {
+ size_t vl;
+ // make most use of vectorization by swiching dimension
+ if (size0 < size1) {
+ for (iree_uk_ssize_t i = 0; i < size0; ++i) {
+ for (iree_uk_ssize_t j = 0; j < size1; j += vl) {
+ vl = vsetvl_e32m8(size1 - j);
+ iree_uk_rvv_x32b_op(opcode, &result_code,
+ &lhs[i * lhs_stride0 + j * lhs_stride1],
+ lhs_stride1 * sizeof(uint32_t),
+ &rhs[i * rhs_stride0 + j * rhs_stride1],
+ rhs_stride1 * sizeof(uint32_t),
+ &out[i * out_stride0 + j * out_stride1],
+ out_stride1 * sizeof(uint32_t), vl);
+ }
+ }
+ } else {
+ for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+ for (iree_uk_ssize_t i = 0; i < size0; i += vl) {
+ vl = vsetvl_e32m8(size0 - i);
+ iree_uk_rvv_x32b_op(opcode, &result_code,
+ &lhs[i * lhs_stride0 + j * lhs_stride1],
+ lhs_stride0 * sizeof(uint32_t),
+ &rhs[i * rhs_stride0 + j * rhs_stride1],
+ rhs_stride0 * sizeof(uint32_t),
+ &out[i * out_stride0 + j * out_stride1],
+ out_stride0 * sizeof(uint32_t), vl);
+ }
+ }
+ }
+ } else {
+ for (iree_uk_ssize_t i = 0; i < size0; ++i) {
+ for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+ iree_uk_generic_x32b_op(opcode, &result_code,
+ &lhs[i * lhs_stride0 + j * lhs_stride1],
+ &rhs[i * rhs_stride0 + j * rhs_stride1],
+ &out[i * out_stride0 + j * out_stride1]);
+ }
+ }
+ }
+ return result_code;
+}
+
+// Generic 32bit unary kernels.
+IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d(
+ iree_uk_x32u_opcode_t opcode,
+ // IN.
+ const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
+ iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
+ // OUT.
+ iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
+ iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+ // Sizes.
+ iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+ int result_code = 0;
+ // TODO: Manually unroll to x4 to trigger vectorization.
+ for (iree_uk_ssize_t i = 0; i < size0; ++i) {
+ for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+ iree_uk_generic_x32u_op(opcode, &result_code,
+ &in[i * in_stride0 + j * in_stride1],
+ &out[i * out_stride0 + j * out_stride1]);
+ }
+ }
+ return result_code;
+}
diff --git a/vmvx_ukernel/mmt4d_arch.c b/vmvx_ukernel/mmt4d_arch.c
index 9c8aecf..92e19ce 100644
--- a/vmvx_ukernel/mmt4d_arch.c
+++ b/vmvx_ukernel/mmt4d_arch.c
@@ -16,8 +16,63 @@
#include "iree/builtins/ukernel/arch/mmt4d_arch.h"
+#include <riscv_vector.h>
+#include <string.h>
+
+// Calculate the dot product of two int8 vectors using RVV
+static iree_uk_int32_t dot_product_rvv(const iree_uk_int8_t* u,
+ const iree_uk_int8_t* w, int n) {
+ size_t vl;
+ // auxiliary variables
+ vint8m4_t vu, vw;
+ vint16m8_t vx;
+ vint32m1_t v_sum;
+ iree_uk_int32_t sum = 0;
+ for (size_t i = 0; i < n; i += vl) {
+ vl = vsetvl_e8m4(n - i);
+ vu = vle8_v_i8m4(u + i, vl); // load
+ vw = vle8_v_i8m4(w + i, vl); // load
+ vx = vwmul(vu, vw, vl); // multiply
+ v_sum = vmv_s(v_sum, 0, vl); // init
+ v_sum = vwredsum(v_sum, vx, v_sum, vl); // sum
+ sum += vmv_x(v_sum);
+ }
+ return sum;
+}
+
+// RVV implementation of matmul tile, i8*i8->i32 case.
+static void iree_uk_mmt4d_tile_i8i8i32_rvv(
+ void* out_tile_untyped, const void* lhs_panel_untyped,
+ const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags,
+ const iree_uk_mmt4d_params_t* params) {
+ iree_uk_int32_t* out_tile = out_tile_untyped;
+ const iree_uk_int8_t* lhs_panel = lhs_panel_untyped;
+ const iree_uk_int8_t* rhs_panel = rhs_panel_untyped;
+ iree_uk_int16_t M0 = params->M0;
+ iree_uk_int16_t N0 = params->N0;
+ iree_uk_int16_t K0 = params->K0;
+ // Initialize the accumulator tile.
+ if (!(flags & IREE_UK_FLAG_ACCUMULATE)) {
+ memset(out_tile, 0, M0 * N0 * sizeof(iree_uk_int32_t));
+ }
+ // Accumulation loop.
+ for (iree_uk_ssize_t k = 0; k < K; ++k) {
+ for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) {
+ out_tile[i0 * N0 + j0] +=
+ dot_product_rvv(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0);
+ }
+ }
+ lhs_panel += M0 * K0;
+ rhs_panel += N0 * K0;
+ }
+}
+
iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
const iree_uk_mmt4d_params_t* params) {
// TODO(lundong): to be replaced with Kelvin
+ if (params->type == iree_uk_mmt4d_type_i8i8i32) {
+ return iree_uk_mmt4d_tile_i8i8i32_rvv;
+ }
return 0;
}
diff --git a/vmvx_ukernel/query_tile_sizes_arch.c b/vmvx_ukernel/query_tile_sizes_arch.c
index 90b3502..a8d2db7 100644
--- a/vmvx_ukernel/query_tile_sizes_arch.c
+++ b/vmvx_ukernel/query_tile_sizes_arch.c
@@ -1,9 +1,18 @@
-// Copyright 2023 Google LLC.
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+/*
+ * Copyright 2023 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
#include "iree/builtins/ukernel/arch/query_tile_sizes_arch.h"