Fix upstream IREE breakage:
(1) Rename iree_uk_ssize_t to iree_uk_index_t due to
https://github.com/openxla/iree/pull/13878.
(2) Rename the mhlo input type/dialect to stablehlo due to
https://github.com/openxla/iree/pull/13870.
Change-Id: I99c1a249d2a83cf25ec91305639c191fd838fdba
diff --git a/samples/simple_vec_mul/CMakeLists.txt b/samples/simple_vec_mul/CMakeLists.txt
index 071461f..74e9dc8 100644
--- a/samples/simple_vec_mul/CMakeLists.txt
+++ b/samples/simple_vec_mul/CMakeLists.txt
@@ -12,7 +12,7 @@
C_IDENTIFIER
"samples_simple_vec_mul_simple_float_mul"
FLAGS
- "-iree-input-type=mhlo"
+ "-iree-input-type=stablehlo"
"-riscv-v-fixed-length-vector-lmul-max=8"
VMVX
INLINE_HAL
@@ -26,7 +26,7 @@
C_IDENTIFIER
"samples_simple_vec_mul_simple_int_mul"
FLAGS
- "-iree-input-type=mhlo"
+ "-iree-input-type=stablehlo"
"-riscv-v-fixed-length-vector-lmul-max=8"
VMVX
INLINE_HAL
diff --git a/samples/simple_vec_mul/simple_float_mul.mlir b/samples/simple_vec_mul/simple_float_mul.mlir
index 0b2f7df..39da051 100644
--- a/samples/simple_vec_mul/simple_float_mul.mlir
+++ b/samples/simple_vec_mul/simple_float_mul.mlir
@@ -1,5 +1,5 @@
func.func @simple_mul(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32>
{
- %0 = "mhlo.multiply"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
+ %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
return %0 : tensor<1024xf32>
}
diff --git a/samples/simple_vec_mul/simple_int_mul.mlir b/samples/simple_vec_mul/simple_int_mul.mlir
index ee2dac0..af6e845 100644
--- a/samples/simple_vec_mul/simple_int_mul.mlir
+++ b/samples/simple_vec_mul/simple_int_mul.mlir
@@ -1,5 +1,5 @@
func.func @simple_mul(%arg0: tensor<1024xi32>, %arg1: tensor<1024xi32>) -> tensor<1024xi32>
{
- %0 = "mhlo.multiply"(%arg0, %arg1) : (tensor<1024xi32>, tensor<1024xi32>) -> tensor<1024xi32>
+ %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<1024xi32>, tensor<1024xi32>) -> tensor<1024xi32>
return %0 : tensor<1024xi32>
}
diff --git a/vmvx_ukernel/elementwise.c b/vmvx_ukernel/elementwise.c
index 5a06913..0e71d22 100644
--- a/vmvx_ukernel/elementwise.c
+++ b/vmvx_ukernel/elementwise.c
@@ -81,13 +81,13 @@
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category) \
IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
- const dtype* lhs, iree_uk_ssize_t lhs_offset, \
- iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1, \
- const dtype* rhs, iree_uk_ssize_t rhs_offset, \
- iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1, \
- dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset, \
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1, \
- iree_uk_ssize_t size0, iree_uk_ssize_t size1) { \
+ const dtype* lhs, iree_uk_index_t lhs_offset, \
+ iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1, \
+ const dtype* rhs, iree_uk_index_t rhs_offset, \
+ iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1, \
+ dtype* IREE_UK_RESTRICT out, iree_uk_index_t out_offset, \
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1, \
+ iree_uk_index_t size0, iree_uk_index_t size1) { \
return iree_uk_##category##_2d(opcode_t, lhs, lhs_offset, lhs_stride0, \
lhs_stride1, rhs, rhs_offset, rhs_stride0, \
rhs_stride1, out, out_offset, out_stride0, \
@@ -99,11 +99,11 @@
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category) \
IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
- const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
- iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out, \
- iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0, \
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0, \
- iree_uk_ssize_t size1) { \
+ const dtype* in, iree_uk_index_t in_offset, iree_uk_index_t in_stride0, \
+ iree_uk_index_t in_stride1, dtype* IREE_UK_RESTRICT out, \
+ iree_uk_index_t out_offset, iree_uk_index_t out_stride0, \
+ iree_uk_index_t out_stride1, iree_uk_index_t size0, \
+ iree_uk_index_t size1) { \
return iree_uk_generic_##category##_2d( \
opcode_t, in, in_offset, in_stride0, in_stride1, out, out_offset, \
out_stride0, out_stride1, size0, size1); \
@@ -136,11 +136,11 @@
// Computes a single element of an x32b opcode usinbg RVV.
static void iree_uk_rvv_x32b_op(iree_uk_x32b_opcode_t opcode, int* result_code,
const iree_uk_uint32_t* lhs,
- iree_uk_ssize_t lhs_stride,
+ iree_uk_index_t lhs_stride,
const iree_uk_uint32_t* rhs,
- iree_uk_ssize_t rhs_stride,
+ iree_uk_index_t rhs_stride,
iree_uk_uint32_t* out,
- iree_uk_ssize_t out_stride, size_t vl) {
+ iree_uk_index_t out_stride, size_t vl) {
iree_uk_x32b_opcode_type_t op_type = get_iree_uk_x32b_op_type(opcode);
if (op_type == IREE_UK_X32B_UI) {
vuint32m8_t vx = vlse32_v_u32m8(lhs, lhs_stride, vl); // load
@@ -298,24 +298,24 @@
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_x32b_2d(
iree_uk_x32b_opcode_t opcode,
// LHS.
- const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
- iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
+ const iree_uk_uint32_t* lhs, iree_uk_index_t lhs_offset,
+ iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1,
// RHS
- const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
- iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
+ const iree_uk_uint32_t* rhs, iree_uk_index_t rhs_offset,
+ iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1,
// OUT.
- iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+ iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
// Sizes.
- iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+ iree_uk_index_t size0, iree_uk_index_t size1) {
int result_code = 0;
if (get_iree_uk_x32b_op_type(opcode) != IREE_UK_X32B_NA) {
size_t vl;
// make most use of vectorization by swiching dimension
if (size0 < size1) {
- for (iree_uk_ssize_t i = 0; i < size0; ++i) {
- for (iree_uk_ssize_t j = 0; j < size1; j += vl) {
+ for (iree_uk_index_t i = 0; i < size0; ++i) {
+ for (iree_uk_index_t j = 0; j < size1; j += vl) {
vl = vsetvl_e32m8(size1 - j);
iree_uk_rvv_x32b_op(opcode, &result_code,
&lhs[i * lhs_stride0 + j * lhs_stride1],
@@ -327,8 +327,8 @@
}
}
} else {
- for (iree_uk_ssize_t j = 0; j < size1; ++j) {
- for (iree_uk_ssize_t i = 0; i < size0; i += vl) {
+ for (iree_uk_index_t j = 0; j < size1; ++j) {
+ for (iree_uk_index_t i = 0; i < size0; i += vl) {
vl = vsetvl_e32m8(size0 - i);
iree_uk_rvv_x32b_op(opcode, &result_code,
&lhs[i * lhs_stride0 + j * lhs_stride1],
@@ -341,8 +341,8 @@
}
}
} else {
- for (iree_uk_ssize_t i = 0; i < size0; ++i) {
- for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+ for (iree_uk_index_t i = 0; i < size0; ++i) {
+ for (iree_uk_index_t j = 0; j < size1; ++j) {
iree_uk_generic_x32b_op(opcode, &result_code,
&lhs[i * lhs_stride0 + j * lhs_stride1],
&rhs[i * rhs_stride0 + j * rhs_stride1],
@@ -357,17 +357,17 @@
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d(
iree_uk_x32u_opcode_t opcode,
// IN.
- const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
- iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
+ const iree_uk_uint32_t* in, iree_uk_index_t in_offset,
+ iree_uk_index_t in_stride0, iree_uk_index_t in_stride1,
// OUT.
- iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+ iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
// Sizes.
- iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+ iree_uk_index_t size0, iree_uk_index_t size1) {
int result_code = 0;
// TODO: Manually unroll to x4 to trigger vectorization.
- for (iree_uk_ssize_t i = 0; i < size0; ++i) {
- for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+ for (iree_uk_index_t i = 0; i < size0; ++i) {
+ for (iree_uk_index_t j = 0; j < size1; ++j) {
iree_uk_generic_x32u_op(opcode, &result_code,
&in[i * in_stride0 + j * in_stride1],
&out[i * out_stride0 + j * out_stride1]);
diff --git a/vmvx_ukernel/mmt4d_tile.c b/vmvx_ukernel/mmt4d_tile.c
index 18ad137..10105b6 100644
--- a/vmvx_ukernel/mmt4d_tile.c
+++ b/vmvx_ukernel/mmt4d_tile.c
@@ -56,9 +56,9 @@
memset(out_tile, 0, M0 * N0 * sizeof(iree_uk_int32_t));
}
// Accumulation loop.
- for (iree_uk_ssize_t k = 0; k < K; ++k) {
- for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) {
- for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) {
+ for (iree_uk_index_t k = 0; k < K; ++k) {
+ for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
out_tile[i0 * N0 + j0] +=
dot_product_rvv(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0);
}
@@ -87,10 +87,10 @@
for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
}
// Accumulation loop.
- for (iree_uk_ssize_t k = 0; k < K; ++k) {
- for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) {
- for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) {
- for (iree_uk_ssize_t k0 = 0; k0 < K0; ++k0) {
+ for (iree_uk_index_t k = 0; k < K; ++k) {
+ for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+ for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
float lhs_val = lhs_panel[i0 * K0 + k0];
float rhs_val = rhs_panel[j0 * K0 + k0];
acc[i0 * N0 + j0] += lhs_val * rhs_val;