Adapt IREE to upcoming ruy::MulParams API change, whereby only the members that are meaningful in each case exist, e.g. no quantization multipliers in floating point cases.

PiperOrigin-RevId: 318855666
diff --git a/iree/hal/vmla/op_kernels_ruy.h b/iree/hal/vmla/op_kernels_ruy.h
index 981ee72..003fb80 100644
--- a/iree/hal/vmla/op_kernels_ruy.h
+++ b/iree/hal/vmla/op_kernels_ruy.h
@@ -15,10 +15,13 @@
 #ifndef IREE_HAL_VMLA_OP_KERNELS_RUY_H_
 #define IREE_HAL_VMLA_OP_KERNELS_RUY_H_
 
+#include <type_traits>
+
 #include "absl/base/thread_annotations.h"
 #include "absl/memory/memory.h"
 #include "iree/base/status.h"
 #include "ruy/context.h"
+#include "ruy/mul_params.h"
 #include "ruy/ruy.h"
 
 namespace iree {
@@ -37,6 +40,56 @@
   return absl::make_unique<RuntimeState>();
 }
 
+// Floating-point case.
+template <typename ACC, typename T>
+struct MakeRuyMulParamsImpl {
+  static_assert(std::is_floating_point<ACC>::value, "");
+  static_assert(std::is_floating_point<T>::value, "");
+  static void Run(const MatMul::Buffers<T, ACC>& buffers,
+                  ruy::MulParams<ACC, T>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+  }
+};
+
+// Integer quantized case with downquantization to a destination T narrower than
+// int32.
+template <typename T>
+struct MakeRuyMulParamsImpl<std::int32_t, T> {
+  static_assert(std::is_integral<T>::value, "");
+  static_assert(sizeof(T) < sizeof(std::int32_t), "");
+  static void Run(const MatMul::Buffers<T, std::int32_t>& buffers,
+                  ruy::MulParams<std::int32_t, T>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+    if (buffers.multiplier_mantissa_buffer.size() == 1) {
+      mul_params->set_multiplier_fixedpoint(
+          buffers.multiplier_mantissa_buffer[0]);
+      mul_params->set_multiplier_exponent(
+          buffers.multiplier_exponent_buffer[0]);
+    } else {
+      mul_params->set_multiplier_fixedpoint_perchannel(
+          buffers.multiplier_mantissa_buffer.data());
+      mul_params->set_multiplier_exponent_perchannel(
+          buffers.multiplier_exponent_buffer.data());
+    }
+  }
+};
+
+// Raw integer case with int32 destination. This case does not support any
+// output operation besides bias-addition.
+template <>
+struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t> {
+  static void Run(const MatMul::Buffers<std::int32_t, std::int32_t>& buffers,
+                  ruy::MulParams<std::int32_t, std::int32_t>* mul_params) {
+    mul_params->set_bias(buffers.bias_buffer.data());
+  }
+};
+
+template <typename ACC, typename T>
+void MakeRuyMulParams(const MatMul::Buffers<T, ACC>& buffers,
+                      ruy::MulParams<ACC, T>* mul_params) {
+  MakeRuyMulParamsImpl<ACC, T>::Run(buffers, mul_params);
+}
+
 template <typename T, typename ACC>
 Status MatMul::Execute(RuntimeState* runtime_state,
                        const Buffers<T, ACC>& buffers) {
@@ -56,17 +109,7 @@
                         ruy::Order::kColMajor, dst.mutable_layout());
 
   ruy::MulParams<ACC, T> mul_params;
-  mul_params.set_bias(buffers.bias_buffer.data());
-
-  if (buffers.multiplier_mantissa_buffer.size() == 1) {
-    mul_params.set_multiplier_fixedpoint(buffers.multiplier_mantissa_buffer[0]);
-    mul_params.set_multiplier_exponent(buffers.multiplier_exponent_buffer[0]);
-  } else {
-    mul_params.set_multiplier_fixedpoint_perchannel(
-        buffers.multiplier_mantissa_buffer.data());
-    mul_params.set_multiplier_exponent_perchannel(
-        buffers.multiplier_exponent_buffer.data());
-  }
+  MakeRuyMulParams(buffers, &mul_params);
 
   ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);