Adapt IREE to the ruy API changes.
PiperOrigin-RevId: 306944042
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index a9f14f1..1d66705 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -7,7 +7,7 @@
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
129cf84e69537ae5c184550f94be18da738d9261 third_party/llvm-project
80d452484c5409444b0ec19383faa84bb7a4d351 third_party/pybind11
-2b11bd49a84f8c3655b4ba14b420f5cc17782db4 third_party/ruy
+9f53ba413e6fc879236dcaa3e008915973d67a4f third_party/ruy
b73f111094da3e380a1774b56b15f16c90ae8e23 third_party/sdl2
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
1346dd5de119d603686e260daf08f36958909a23 third_party/spirv_tools
diff --git a/build_tools/third_party/ruy/CMakeLists.txt b/build_tools/third_party/ruy/CMakeLists.txt
index dd6136e..f898da6 100644
--- a/build_tools/third_party/ruy/CMakeLists.txt
+++ b/build_tools/third_party/ruy/CMakeLists.txt
@@ -382,11 +382,11 @@
PACKAGE
ruy
NAME
- spec
+ mul_params
ROOT
${RUY_SRC_ROOT}
HDRS
- "spec.h"
+ "mul_params.h"
COPTS
${RUY_COPTS_BASE}
DEPS
@@ -457,7 +457,7 @@
ruy::platform
ruy::side_pair
ruy::size_util
- ruy::spec
+ ruy::mul_params
ruy::tune
)
@@ -800,7 +800,7 @@
ruy::platform
ruy::side_pair
ruy::size_util
- ruy::spec
+ ruy::mul_params
ruy::tune
)
@@ -858,8 +858,6 @@
context
ROOT
${RUY_SRC_ROOT}
- SRCS
- "context.cc"
HDRS
"context.h"
COPTS
@@ -883,6 +881,34 @@
PACKAGE
ruy
NAME
+ context_internal
+ ROOT
+ ${RUY_SRC_ROOT}
+ SRCS
+ "context_internal.cc"
+ HDRS
+ "context_internal.h"
+ COPTS
+ ${RUY_COPTS_BASE}
+ DEPS
+ ruy::allocator
+ ruy::check_macros
+ ruy::detect_arm
+ ruy::detect_x86
+ ruy::have_built_path_for
+ ruy::path
+ ruy::platform
+ ruy::prepacked_cache
+ ruy::thread_pool
+ ruy::trace
+ ruy::tune
+ PUBLIC
+)
+
+external_cc_library(
+ PACKAGE
+ ruy
+ NAME
trmul_params
ROOT
${RUY_SRC_ROOT}
@@ -916,12 +942,13 @@
ruy::check_macros
ruy::common
ruy::context
+ ruy::context_internal
ruy::internal_matrix
ruy::matrix
ruy::opt_set
ruy::side_pair
ruy::size_util
- ruy::spec
+ ruy::mul_params
ruy::thread_pool
ruy::trace
ruy::trmul_params
@@ -948,6 +975,7 @@
ruy::check_macros
ruy::common
ruy::context
+ ruy::context_internal
ruy::internal_matrix
ruy::kernel
ruy::matrix
@@ -957,7 +985,7 @@
ruy::prepacked_cache
ruy::side_pair
ruy::size_util
- ruy::spec
+ ruy::mul_params
ruy::trmul
ruy::trmul_params
ruy::tune
diff --git a/iree/hal/vmla/op_kernels_ruy.h b/iree/hal/vmla/op_kernels_ruy.h
index 379d17e..981ee72 100644
--- a/iree/hal/vmla/op_kernels_ruy.h
+++ b/iree/hal/vmla/op_kernels_ruy.h
@@ -41,34 +41,34 @@
Status MatMul::Execute(RuntimeState* runtime_state,
const Buffers<T, ACC>& buffers) {
ruy::Matrix<T> lhs;
- lhs.data.set(buffers.lhs_buffer.data());
+ lhs.set_data(buffers.lhs_buffer.data());
ruy::MakeSimpleLayout(buffers.lhs_shape[0], buffers.lhs_shape[1],
- ruy::Order::kRowMajor, &lhs.layout);
+ ruy::Order::kRowMajor, lhs.mutable_layout());
ruy::Matrix<T> rhs;
- rhs.data.set(buffers.rhs_buffer.data());
+ rhs.set_data(buffers.rhs_buffer.data());
ruy::MakeSimpleLayout(buffers.rhs_shape[1], buffers.rhs_shape[0],
- ruy::Order::kColMajor, &rhs.layout);
+ ruy::Order::kColMajor, rhs.mutable_layout());
ruy::Matrix<T> dst;
- dst.data.set(buffers.dst_buffer.data());
+ dst.set_data(buffers.dst_buffer.data());
ruy::MakeSimpleLayout(buffers.dst_shape[1], buffers.dst_shape[0],
- ruy::Order::kColMajor, &dst.layout);
+ ruy::Order::kColMajor, dst.mutable_layout());
- ruy::BasicSpec<ACC, T> spec;
- spec.bias = buffers.bias_buffer.data();
+ ruy::MulParams<ACC, T> mul_params;
+ mul_params.set_bias(buffers.bias_buffer.data());
if (buffers.multiplier_mantissa_buffer.size() == 1) {
- spec.multiplier_fixedpoint = buffers.multiplier_mantissa_buffer[0];
- spec.multiplier_exponent = buffers.multiplier_exponent_buffer[0];
+ mul_params.set_multiplier_fixedpoint(buffers.multiplier_mantissa_buffer[0]);
+ mul_params.set_multiplier_exponent(buffers.multiplier_exponent_buffer[0]);
} else {
- spec.multiplier_fixedpoint_perchannel =
- buffers.multiplier_mantissa_buffer.data();
- spec.multiplier_exponent_perchannel =
- buffers.multiplier_exponent_buffer.data();
+ mul_params.set_multiplier_fixedpoint_perchannel(
+ buffers.multiplier_mantissa_buffer.data());
+ mul_params.set_multiplier_exponent_perchannel(
+ buffers.multiplier_exponent_buffer.data());
}
- ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &runtime_state->context, &dst);
+ ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);
return OkStatus();
}