Adapt IREE to upcoming ruy::MulParams API change, whereby only the members that are meaningful in each case exist, e.g. no quantization multipliers in floating point cases. PiperOrigin-RevId: 318855666
diff --git a/iree/hal/vmla/op_kernels_ruy.h b/iree/hal/vmla/op_kernels_ruy.h index 981ee72..003fb80 100644 --- a/iree/hal/vmla/op_kernels_ruy.h +++ b/iree/hal/vmla/op_kernels_ruy.h
@@ -15,10 +15,13 @@ #ifndef IREE_HAL_VMLA_OP_KERNELS_RUY_H_ #define IREE_HAL_VMLA_OP_KERNELS_RUY_H_ +#include <type_traits> + #include "absl/base/thread_annotations.h" #include "absl/memory/memory.h" #include "iree/base/status.h" #include "ruy/context.h" +#include "ruy/mul_params.h" #include "ruy/ruy.h" namespace iree { @@ -37,6 +40,56 @@ return absl::make_unique<RuntimeState>(); } +// Floating-point case. +template <typename ACC, typename T> +struct MakeRuyMulParamsImpl { + static_assert(std::is_floating_point<ACC>::value, ""); + static_assert(std::is_floating_point<T>::value, ""); + static void Run(const MatMul::Buffers<T, ACC>& buffers, + ruy::MulParams<ACC, T>* mul_params) { + mul_params->set_bias(buffers.bias_buffer.data()); + } +}; + +// Integer quantized case with downquantization to a destination T narrower than +// int32. +template <typename T> +struct MakeRuyMulParamsImpl<std::int32_t, T> { + static_assert(std::is_integral<T>::value, ""); + static_assert(sizeof(T) < sizeof(std::int32_t), ""); + static void Run(const MatMul::Buffers<T, std::int32_t>& buffers, + ruy::MulParams<std::int32_t, T>* mul_params) { + mul_params->set_bias(buffers.bias_buffer.data()); + if (buffers.multiplier_mantissa_buffer.size() == 1) { + mul_params->set_multiplier_fixedpoint( + buffers.multiplier_mantissa_buffer[0]); + mul_params->set_multiplier_exponent( + buffers.multiplier_exponent_buffer[0]); + } else { + mul_params->set_multiplier_fixedpoint_perchannel( + buffers.multiplier_mantissa_buffer.data()); + mul_params->set_multiplier_exponent_perchannel( + buffers.multiplier_exponent_buffer.data()); + } + } +}; + +// Raw integer case with int32 destination. This case does not support any +// output operation besides bias-addition. +template <> +struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t> { + static void Run(const MatMul::Buffers<std::int32_t, std::int32_t>& buffers, + ruy::MulParams<std::int32_t, std::int32_t>* mul_params) { + mul_params->set_bias(buffers.bias_buffer.data()); + } +}; + +template <typename ACC, typename T> +void MakeRuyMulParams(const MatMul::Buffers<T, ACC>& buffers, + ruy::MulParams<ACC, T>* mul_params) { + MakeRuyMulParamsImpl<ACC, T>::Run(buffers, mul_params); +} + template <typename T, typename ACC> Status MatMul::Execute(RuntimeState* runtime_state, const Buffers<T, ACC>& buffers) { @@ -56,17 +109,7 @@ ruy::Order::kColMajor, dst.mutable_layout()); ruy::MulParams<ACC, T> mul_params; - mul_params.set_bias(buffers.bias_buffer.data()); - - if (buffers.multiplier_mantissa_buffer.size() == 1) { - mul_params.set_multiplier_fixedpoint(buffers.multiplier_mantissa_buffer[0]); - mul_params.set_multiplier_exponent(buffers.multiplier_exponent_buffer[0]); - } else { - mul_params.set_multiplier_fixedpoint_perchannel( - buffers.multiplier_mantissa_buffer.data()); - mul_params.set_multiplier_exponent_perchannel( - buffers.multiplier_exponent_buffer.data()); - } + MakeRuyMulParams(buffers, &mul_params); ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);