Adapt IREE to the upcoming ruy::MulParams API change, after which only the members that are meaningful in each case exist, e.g. there are no quantization multipliers in the floating-point case.
PiperOrigin-RevId: 318855666
diff --git a/iree/hal/vmla/op_kernels_ruy.h b/iree/hal/vmla/op_kernels_ruy.h
index 981ee72..003fb80 100644
--- a/iree/hal/vmla/op_kernels_ruy.h
+++ b/iree/hal/vmla/op_kernels_ruy.h
@@ -15,10 +15,13 @@
#ifndef IREE_HAL_VMLA_OP_KERNELS_RUY_H_
#define IREE_HAL_VMLA_OP_KERNELS_RUY_H_
+#include <type_traits>
+
#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "iree/base/status.h"
#include "ruy/context.h"
+#include "ruy/mul_params.h"
#include "ruy/ruy.h"
namespace iree {
@@ -37,6 +40,56 @@
return absl::make_unique<RuntimeState>();
}
+// Floating-point case (primary template).
+//
+// Used for every <ACC, T> pair that is not matched by a more specialized
+// MakeRuyMulParamsImpl. Only the bias pointer is forwarded to ruy; the
+// quantization multiplier fields of `buffers` are not read here.
+template <typename ACC, typename T>
+struct MakeRuyMulParamsImpl {
+  // As the catch-all, this template is also where an unsupported type
+  // combination ends up, so reject those with a readable diagnostic.
+  static_assert(std::is_floating_point<ACC>::value,
+                "MakeRuyMulParamsImpl: unspecialized case requires a "
+                "floating-point accumulator type ACC");
+  static_assert(std::is_floating_point<T>::value,
+                "MakeRuyMulParamsImpl: unspecialized case requires a "
+                "floating-point destination type T");
+
+  // Sets the bias of `mul_params` to the start of `buffers.bias_buffer`.
+  static void Run(const MatMul::Buffers<T, ACC>& buffers,
+                  ruy::MulParams<ACC, T>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+  }
+};
+
+// Integer quantized case with downquantization to a destination T narrower than
+// int32.
+//
+// Forwards the bias pointer and the quantization multiplier to ruy. A
+// multiplier_mantissa_buffer holding exactly one element is passed as a single
+// fixedpoint/exponent pair; any other size is passed as per-channel pointers.
+template <typename T>
+struct MakeRuyMulParamsImpl<std::int32_t, T> {
+  static_assert(std::is_integral<T>::value,
+                "MakeRuyMulParamsImpl: with an int32 accumulator the "
+                "destination type T must be an integral type");
+  static_assert(sizeof(T) < sizeof(std::int32_t),
+                "MakeRuyMulParamsImpl: quantized destination type T must be "
+                "narrower than int32");
+
+  static void Run(const MatMul::Buffers<T, std::int32_t>& buffers,
+                  ruy::MulParams<std::int32_t, T>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+    // Only the mantissa buffer's size selects the mode; the exponent buffer is
+    // indexed or forwarded without a size check of its own.
+    if (buffers.multiplier_mantissa_buffer.size() == 1) {
+      mul_params->set_multiplier_fixedpoint(
+          buffers.multiplier_mantissa_buffer[0]);
+      mul_params->set_multiplier_exponent(
+          buffers.multiplier_exponent_buffer[0]);
+    } else {
+      mul_params->set_multiplier_fixedpoint_perchannel(
+          buffers.multiplier_mantissa_buffer.data());
+      mul_params->set_multiplier_exponent_perchannel(
+          buffers.multiplier_exponent_buffer.data());
+    }
+  }
+};
+
+// Raw integer case: int32 accumulator written to an int32 destination.
+// Bias-addition is the only output operation supported for this case, so
+// nothing but the bias pointer is handed to ruy.
+template <>
+struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t> {
+  static void Run(const MatMul::Buffers<std::int32_t, std::int32_t>& buffers,
+                  ruy::MulParams<std::int32_t, std::int32_t>* mul_params) {
+    const auto bias = buffers.bias_buffer.data();
+    mul_params->set_bias(bias);
+  }
+};
+
+// Fills `mul_params` from `buffers` by delegating to the
+// MakeRuyMulParamsImpl variant selected for the <ACC, T> pair.
+template <typename ACC, typename T>
+void MakeRuyMulParams(const MatMul::Buffers<T, ACC>& buffers,
+                      ruy::MulParams<ACC, T>* mul_params) {
+  using Impl = MakeRuyMulParamsImpl<ACC, T>;
+  Impl::Run(buffers, mul_params);
+}
+
template <typename T, typename ACC>
Status MatMul::Execute(RuntimeState* runtime_state,
const Buffers<T, ACC>& buffers) {
@@ -56,17 +109,7 @@
ruy::Order::kColMajor, dst.mutable_layout());
ruy::MulParams<ACC, T> mul_params;
- mul_params.set_bias(buffers.bias_buffer.data());
-
- if (buffers.multiplier_mantissa_buffer.size() == 1) {
- mul_params.set_multiplier_fixedpoint(buffers.multiplier_mantissa_buffer[0]);
- mul_params.set_multiplier_exponent(buffers.multiplier_exponent_buffer[0]);
- } else {
- mul_params.set_multiplier_fixedpoint_perchannel(
- buffers.multiplier_mantissa_buffer.data());
- mul_params.set_multiplier_exponent_perchannel(
- buffers.multiplier_exponent_buffer.data());
- }
+ MakeRuyMulParams(buffers, &mul_params);
ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);