// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef IREE_MODULES_VMLA_OP_KERNELS_RUY_H_
#define IREE_MODULES_VMLA_OP_KERNELS_RUY_H_

#include <type_traits>

#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "iree/base/status.h"
#include "ruy/context.h"
#include "ruy/mul_params.h"
#include "ruy/ruy.h"

namespace iree {
namespace hal {
namespace vmla {
namespace kernels {
// TODO(benvanik): something more clever for making this shareable.
// Maybe a factory fn based on the impl selected?
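// Holds the ruy::Context (thread pool and internal caches) that is reused
// across MatMul::Execute invocations.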
struct MatMul::RuntimeState {
// TODO(benvanik): share the thread pool but keep context per-fiber?
ruy::Context context;
};

inline std::unique_ptr<MatMul::RuntimeState> MatMul::CreateRuntimeState() {
return absl::make_unique<RuntimeState>();
}

// Floating-point case.
template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
struct MakeRuyMulParamsImpl {
static_assert(std::is_floating_point<LhsEl>::value, "");
static_assert(std::is_floating_point<RhsEl>::value, "");
static_assert(std::is_floating_point<AccumEl>::value, "");
static_assert(std::is_floating_point<DstEl>::value, "");
static void Run(const MatMul::Buffers<LhsEl, RhsEl, AccumEl, DstEl>& buffers,
ruy::MulParams<AccumEl, DstEl>* mul_params) {
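    // Bias addition is the only fused output operation configured on the
    // float path.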
mul_params->set_bias(buffers.bias_buffer.data());
}
};

// Raw integer case with int32 destination. This case does not support any
// output operation besides bias-addition.
template <typename LhsEl, typename RhsEl>
struct MakeRuyMulParamsImpl<LhsEl, RhsEl, std::int32_t, std::int32_t> {
static void Run(
const MatMul::Buffers<LhsEl, RhsEl, std::int32_t, std::int32_t>& buffers,
ruy::MulParams<std::int32_t, std::int32_t>* mul_params) {
mul_params->set_bias(buffers.bias_buffer.data());
}
};

// Integer quantized case with downquantization to a destination DstEl narrower
// than int32.
template <typename LhsEl, typename RhsEl, typename DstEl>
struct MakeRuyMulParamsImpl<LhsEl, RhsEl, std::int32_t, DstEl> {
static_assert(std::is_integral<LhsEl>::value, "");
static_assert(std::is_integral<RhsEl>::value, "");
static_assert(std::is_integral<DstEl>::value, "");
static_assert(sizeof(DstEl) < sizeof(std::int32_t), "");
static void Run(
const MatMul::Buffers<LhsEl, RhsEl, std::int32_t, DstEl>& buffers,
ruy::MulParams<std::int32_t, DstEl>* mul_params) {
mul_params->set_bias(buffers.bias_buffer.data());
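    // A single multiplier entry configures one uniform requantization scale
    // (fixed-point mantissa plus exponent); otherwise the per-channel arrays
    // are handed to ruy directly.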
if (buffers.multiplier_mantissa_buffer.size() == 1) {
mul_params->set_multiplier_fixedpoint(
buffers.multiplier_mantissa_buffer[0]);
mul_params->set_multiplier_exponent(
buffers.multiplier_exponent_buffer[0]);
} else {
mul_params->set_multiplier_fixedpoint_perchannel(
buffers.multiplier_mantissa_buffer.data());
mul_params->set_multiplier_exponent_perchannel(
buffers.multiplier_exponent_buffer.data());
}
}
};
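// Dispatches to the MakeRuyMulParamsImpl specialization matching the
// accumulator and destination element types.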
template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
void MakeRuyMulParams(
const MatMul::Buffers<LhsEl, RhsEl, AccumEl, DstEl>& buffers,
ruy::MulParams<AccumEl, DstEl>* mul_params) {
MakeRuyMulParamsImpl<LhsEl, RhsEl, AccumEl, DstEl>::Run(buffers, mul_params);
}

template <typename LhsEl, typename RhsEl, typename AccumEl, typename DstEl>
Status MatMul::Execute(RuntimeState* runtime_state,
const Buffers<LhsEl, RhsEl, AccumEl, DstEl>& buffers) {
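  // Wrap the VMLA buffers in ruy matrix views over their existing storage.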
ruy::Matrix<LhsEl> lhs;
lhs.set_data(buffers.lhs_buffer.data());
ruy::MakeSimpleLayout(buffers.lhs_shape[0], buffers.lhs_shape[1],
ruy::Order::kRowMajor, lhs.mutable_layout());
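  // rhs is described as column-major with its two shape dimensions swapped,
  // which presents the row-major storage to ruy as its transpose.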
ruy::Matrix<RhsEl> rhs;
rhs.set_data(buffers.rhs_buffer.data());
ruy::MakeSimpleLayout(buffers.rhs_shape[1], buffers.rhs_shape[0],
ruy::Order::kColMajor, rhs.mutable_layout());
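  // dst gets the same column-major, dimensions-swapped view.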
ruy::Matrix<DstEl> dst;
dst.set_data(buffers.dst_buffer.data());
ruy::MakeSimpleLayout(buffers.dst_shape[1], buffers.dst_shape[0],
ruy::Order::kColMajor, dst.mutable_layout());
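  // Configure the fused output stage (bias and, on quantized paths, the
  // requantization multipliers) and run the multiplication on the shared
  // ruy context.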
ruy::MulParams<AccumEl, DstEl> mul_params;
MakeRuyMulParams(buffers, &mul_params);
ruy::Mul(lhs, rhs, mul_params, &runtime_state->context, &dst);
return OkStatus();
}
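// Illustrative call sequence (a sketch only; the Buffers<> struct and its
// field types are declared alongside MatMul outside this file):
//
//   auto runtime_state = MatMul::CreateRuntimeState();
//   MatMul::Buffers<float, float, float, float> buffers;
//   // ... point lhs_buffer/rhs_buffer/dst_buffer, the *_shape fields, and
//   // bias_buffer at the caller's data ...
//   auto status = MatMul::Execute(runtime_state.get(), buffers);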
} // namespace kernels
} // namespace vmla
} // namespace hal
} // namespace iree

#endif // IREE_MODULES_VMLA_OP_KERNELS_RUY_H_