// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_MLIR_EDGE_IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
#define THIRD_PARTY_MLIR_EDGE_IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
#include "third_party/absl/container/flat_hash_set.h"
#include "third_party/absl/container/inlined_vector.h"
#include "third_party/absl/types/span.h"
#include "third_party/mlir_edge/iree/base/status.h"
namespace iree {
namespace hal {
namespace kernels {
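// Most kernels below are simple elementwise reference implementations: each
// loop is driven by dst_buffer.size(), so callers are expected to pass
// operand buffers of matching element count (Broadcast being the exception,
// reading only src_buffer[0]). The comparison kernels write one byte (0 or 1)
// per element into dst_buffer.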
template <typename T>
Status CompareEQ::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<uint8_t> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] == rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status CompareNE::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<uint8_t> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] != rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status CompareLT::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<uint8_t> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] < rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status CompareLE::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<uint8_t> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] <= rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status CompareGT::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<uint8_t> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] > rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status CompareGE::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<uint8_t> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] >= rhs_buffer[i];
}
return OkStatus();
}
namespace impl {
inline absl::InlinedVector<size_t, 6> ComputeCopyStrides(const Shape& shape,
size_t element_size) {
absl::InlinedVector<size_t, 6> strides(shape.empty() ? 1 : shape.size());
strides.back() = element_size;
for (int i = shape.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
}
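// For illustration: ComputeCopyStrides({2, 3, 4}, /*element_size=*/4) yields
// byte strides {48, 16, 4}; a scalar (empty) shape yields {4}.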
inline void CopyRegion(absl::Span<const uint8_t> src_buffer,
absl::Span<const size_t> src_strides,
absl::Span<const int32_t> src_indices,
absl::Span<uint8_t> dst_buffer,
absl::Span<const size_t> dst_strides,
absl::Span<const int32_t> dst_indices,
absl::Span<const int32_t> lengths) {
if (lengths.size() > 1) {
for (int i = 0; i < lengths[0]; ++i) {
size_t src_offset = src_strides[0] * (src_indices[0] + i);
size_t dst_offset = dst_strides[0] * (dst_indices[0] + i);
CopyRegion(src_buffer.subspan(src_offset), src_strides.subspan(1),
src_indices.subspan(1), dst_buffer.subspan(dst_offset),
dst_strides.subspan(1), dst_indices.subspan(1),
lengths.subspan(1));
}
} else {
DCHECK_EQ(dst_strides.size(), 1);
DCHECK_EQ(src_strides.size(), 1);
DCHECK_EQ(src_indices.size(), 1);
DCHECK_EQ(dst_indices.size(), 1);
DCHECK_EQ(lengths.size(), 1);
auto src_offset = src_indices[0] * src_strides[0];
auto dst_offset = dst_indices[0] * dst_strides[0];
auto length = dst_strides[0] * lengths[0];
std::memcpy(dst_buffer.data() + dst_offset, src_buffer.data() + src_offset,
length);
}
}
} // namespace impl
// TODO(benvanik): replace with a real implementation once copy is defined.
template <int element_size>
Status Copy::Execute(absl::Span<const uint8_t> src_buffer,
const Shape& src_shape,
absl::Span<const int32_t> src_indices,
absl::Span<uint8_t> dst_buffer, const Shape& dst_shape,
absl::Span<const int32_t> dst_indices,
absl::Span<const int32_t> lengths) {
// TODO(gcmn) Maybe we can fast-path earlier if we detect contiguous memory
// across multiple rows.
auto src_strides = impl::ComputeCopyStrides(src_shape, element_size);
auto dst_strides = impl::ComputeCopyStrides(dst_shape, element_size);
impl::CopyRegion(src_buffer, src_strides, src_indices, dst_buffer,
dst_strides, dst_indices, lengths);
return OkStatus();
}
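// For illustration: with 4-byte elements, src and dst shapes of [4, 4],
// src_indices = {1, 1}, dst_indices = {0, 0}, and lengths = {2, 2},
// Copy::Execute copies the 2x2 block whose top-left corner is at (1, 1) in
// src into the top-left corner of dst, one contiguous row (8 bytes) per
// memcpy.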
template <typename T>
Status Select::Execute(absl::Span<const uint8_t> cond_buffer,
absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = cond_buffer[i] ? lhs_buffer[i] : rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status Transpose::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer, const Shape& src_shape,
absl::Span<const int32_t> perm) {
// This implementation is .... not fast.
int rank = src_shape.size();
absl::InlinedVector<int, 8> src_strides(rank);
absl::InlinedVector<int, 8> dst_strides(rank);
size_t src_stride = 1;
size_t dst_stride = 1;
for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
src_strides[dim_i] = src_stride;
dst_strides[dim_i] = dst_stride;
src_stride *= src_shape[dim_i];
dst_stride *= src_shape[perm[dim_i]];
}
for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
size_t src_i = 0;
size_t t = dst_i;
for (int dim_i = 0; dim_i < rank; ++dim_i) {
size_t ratio = t / dst_strides[dim_i];
t -= ratio * dst_strides[dim_i];
src_i += ratio * src_strides[perm[dim_i]];
}
dst_buffer[dst_i] = src_buffer[src_i];
}
return OkStatus();
}
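// For illustration: with src_shape = [2, 3] and perm = {1, 0}, Transpose maps
// the row-major source [[0, 1, 2], [3, 4, 5]] to the [3, 2] result
// [[0, 3], [1, 4], [2, 5]].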
namespace impl {
inline void IncrementShapeIndex(absl::Span<int32_t> indices,
const Shape& shape) {
for (int i = indices.size() - 1; i >= 0; --i) {
if (++indices[i] < shape[i]) return;
indices[i] = 0;
}
}
inline bool IsPadding(absl::Span<const int32_t> indices, const Shape& shape,
absl::Span<const int32_t> edge_padding_low,
absl::Span<const int32_t> edge_padding_high,
absl::Span<const int32_t> interior_padding) {
for (int i = 0; i < indices.size(); ++i) {
auto index = indices[i];
if (index < edge_padding_low[i] ||
index >= shape[i] - edge_padding_high[i] ||
(index - edge_padding_low[i]) % (interior_padding[i] + 1) != 0) {
return true;
}
}
return false;
}
} // namespace impl
template <typename T>
Status Pad::Execute(absl::Span<const T> src_buffer,
absl::Span<const T> padding_value_buffer,
absl::Span<T> dst_buffer, const Shape& src_shape,
const Shape& dst_shape,
absl::Span<const int32_t> edge_padding_low,
absl::Span<const int32_t> edge_padding_high,
absl::Span<const int32_t> interior_padding) {
// This implementation is not at all fast, as it iterates over every index in
// the destination buffer individually. Potential improvements:
// 1. Fill the dst buffer with the padding value up front; then we only need
//    to iterate through the source buffer and can exit early.
// 2. Use striding to advance through larger swaths of the buffer with a
//    memcpy from src, filling (or skipping) padded indices. Especially
//    useful when e.g. entire rows are padded.
// TODO(b/140836672) support negative padding
if (padding_value_buffer.size() != 1) {
return InvalidArgumentErrorBuilder(ABSL_LOC)
<< "Padding value buffer is larger than one element.";
}
auto padding_value = padding_value_buffer.front();
absl::InlinedVector<int, 8> dst_indices(src_shape.size(), 0);
const T* src_ptr = src_buffer.begin();
T* dst_ptr = dst_buffer.begin();
while (dst_ptr != dst_buffer.end()) {
if (impl::IsPadding(dst_indices, dst_shape, edge_padding_low,
edge_padding_high, interior_padding)) {
*dst_ptr++ = padding_value;
} else {
DCHECK(src_ptr != src_buffer.end());
*dst_ptr++ = *src_ptr++;
}
impl::IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape);
}
return OkStatus();
}
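// For illustration: padding a 1-D source [a, b] with edge_padding_low = {2},
// edge_padding_high = {1}, interior_padding = {1}, and padding value x
// produces a destination of shape [6]: [x, x, a, x, b, x].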
template <typename T>
Status Reverse::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer, const Shape& src_shape,
absl::Span<const int32_t> dimensions) {
// This implementation is not fast either
int rank = src_shape.size();
absl::InlinedVector<int, 8> strides(rank);
size_t stride = 1;
for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
strides[dim_i] = stride;
stride *= src_shape[dim_i];
}
absl::flat_hash_set<int32_t> dims_set(dimensions.begin(), dimensions.end());
for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
size_t src_i = 0;
size_t t = dst_i;
for (int dim_i = 0; dim_i < rank; ++dim_i) {
size_t ratio = t / strides[dim_i];
t -= ratio * strides[dim_i];
bool do_reverse = dims_set.contains(dim_i);
src_i += (do_reverse ? (src_shape[dim_i] - 1 - ratio) : ratio) *
strides[dim_i];
}
dst_buffer[dst_i] = src_buffer[src_i];
}
return OkStatus();
}
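// For illustration: reversing dimension 0 of a [2, 3] source
// [[0, 1, 2], [3, 4, 5]] yields [[3, 4, 5], [0, 1, 2]]; reversing
// dimension 1 instead yields [[2, 1, 0], [5, 4, 3]].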
template <typename T>
Status Broadcast::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = src_buffer[0];
}
return OkStatus();
}
template <typename T>
Status Tile::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer,
const Shape& src_shape, const Shape& dst_shape) {
// This implementation is .... not fast.
int rank = dst_shape.size();
absl::InlinedVector<int, 8> src_strides(rank);
absl::InlinedVector<int, 8> dst_strides(rank);
size_t src_stride = 1;
size_t dst_stride = 1;
for (int dim_i = rank - 1; dim_i >= 0; --dim_i) {
src_strides[dim_i] = src_stride;
dst_strides[dim_i] = dst_stride;
src_stride *= src_shape[dim_i];
dst_stride *= dst_shape[dim_i];
}
for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) {
size_t src_i = 0;
size_t t = dst_i;
for (int dim_i = 0; dim_i < rank; ++dim_i) {
src_i += t / dst_strides[dim_i] % src_shape[dim_i] * src_strides[dim_i];
t %= dst_strides[dim_i];
}
dst_buffer[dst_i] = src_buffer[src_i];
}
return OkStatus();
}
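// For illustration: tiling a [2]-shaped source [a, b] into a [6]-shaped
// destination produces [a, b, a, b, a, b].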
template <typename T>
Status Not::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = ~src_buffer[i];
}
return OkStatus();
}
template <typename T>
Status And::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] & rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status Or::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] | rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status Xor::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] ^ rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status ShiftLeft::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] << rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status ShiftRight::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] >> rhs_buffer[i];
}
return OkStatus();
}
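// Note: as in standard C++, shift counts that are negative or >= the bit
// width of T are undefined behavior, and right-shifting a negative signed
// value is implementation-defined before C++20; callers are assumed to pass
// in-range shift amounts.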
template <typename T>
Status Add::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] + rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status Sub::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] - rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status Abs::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::abs(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Mul::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] * rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status Div::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = lhs_buffer[i] / rhs_buffer[i];
}
return OkStatus();
}
template <typename T>
Status MulAdd::Execute(absl::Span<const T> a_buffer,
absl::Span<const T> b_buffer,
absl::Span<const T> c_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = a_buffer[i] + (b_buffer[i] * c_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Exp::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::exp(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Rsqrt::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = 1.0 / std::sqrt(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Log::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::log(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Cos::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::cos(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Sin::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::sin(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Tanh::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::tanh(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Atan2::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::atan2(lhs_buffer[i], rhs_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Min::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::min(lhs_buffer[i], rhs_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Max::Execute(absl::Span<const T> lhs_buffer,
absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::max(lhs_buffer[i], rhs_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Clamp::Execute(absl::Span<const T> src_buffer,
absl::Span<const T> min_buffer,
absl::Span<const T> max_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
T src = src_buffer[i];
T min = min_buffer[i];
T max = max_buffer[i];
dst_buffer[i] = src <= min ? min : src >= max ? max : src;
}
return OkStatus();
}
template <typename T>
Status Floor::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::floor(src_buffer[i]);
}
return OkStatus();
}
template <typename T>
Status Ceil::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = std::ceil(src_buffer[i]);
}
return OkStatus();
}
template <typename SRC, typename DST>
Status Convert::Execute(absl::Span<const SRC> src_buffer,
absl::Span<DST> dst_buffer) {
DCHECK_EQ(src_buffer.size(), dst_buffer.size());
for (size_t i = 0; i < dst_buffer.size(); ++i) {
dst_buffer[i] = static_cast<DST>(src_buffer[i]);
}
return OkStatus();
}
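// Note: Convert is a plain static_cast per element, so e.g. float-to-integer
// conversions truncate toward zero and no clamping or rounding is applied.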
namespace impl {
struct SumKernel {
template <typename T>
inline void operator()(T* value0, const T value1) {
*value0 += value1;
}
};
struct MinKernel {
template <typename T>
inline void operator()(T* value0, const T value1) {
*value0 = std::min(*value0, value1);
}
};
struct MaxKernel {
template <typename T>
inline void operator()(T* value0, const T value1) {
*value0 = std::max(*value0, value1);
}
};
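// SumKernel, MinKernel, and MaxKernel are the elementwise combiners used by
// GenericReduce below; each folds one source element into the destination
// accumulator in place.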
template <typename T, typename KernelImpl>
inline void ReduceDimension(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer, const Shape& src_shape,
absl::Span<const int32_t> reduce_dims,
absl::Span<const int> dst_strides, int dim,
absl::Span<int> src_indices, size_t flat_src_i,
size_t src_stride) {
if (dim < 0) {
// Base case of the recursion - figure out which elements should be acted
// upon and apply the reduction kernel to them.
// Derive destination indices from source indices.
// For example,
//   reduce_dims: [1, 2]
//   src_indices: [2, 1, 3, 0]
//                    ^  ^
//                    |  |
//                    +--+----- remove these dimensions
//   dst_indices: [2, 0]
//
// TODO(scotttodd): Clean this up somehow, share across recursion levels?
size_t dst_size = src_shape.size() - reduce_dims.size();
absl::InlinedVector<int, 8> dst_indices;
for (size_t i = 0; i < src_indices.size(); ++i) {
if (std::find(std::begin(reduce_dims), std::end(reduce_dims), i) ==
std::end(reduce_dims)) {
dst_indices.push_back(src_indices[i]);
}
}
// Compute the flattened index into dst_buffer at [dst_indices].
size_t dst_i = 0;
for (size_t i = 0; i < dst_indices.size(); ++i) {
dst_i += dst_indices[i] * dst_strides[dst_size - 1 - i];
}
// Flattened src and dst indices have been computed, invoke the kernel.
KernelImpl()(&dst_buffer[dst_i], src_buffer[flat_src_i]);
return;
}
// Iterate through the current dimension in the source shape, recursing
// down one dimension at a time.
//
// This touches each element in the source buffer once, tracking complete
// dimensions within the shaped source buffer and using them to compute
// the corresponding indices (shaped and flattened) within the destination
// buffer. Each element in the destination buffer will be touched multiple
// times.
//
// Note that cache locality isn't considered here, and some computations
// are redundant, so this could be optimized substantially.
for (size_t dim_i = 0; dim_i < src_shape[dim]; ++dim_i) {
src_indices[dim] = dim_i;
// Recurse down to the next dimension (e.g. 2 -> 1 -> 0 -> base case)
// * Add the current stride to flat_src_i
// * Multiply src_stride by this dimension's shape
ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape,
reduce_dims, dst_strides, dim - 1,
src_indices, flat_src_i + dim_i * src_stride,
src_stride * src_shape[dim]);
}
}
template <typename T, typename KernelImpl>
Status GenericReduce(absl::Span<const T> src_buffer,
absl::Span<const T> init_buffer, absl::Span<T> dst_buffer,
int32_t dimension, const Shape& src_shape,
const Shape& dst_shape) {
// Initialize using init_buffer, which is expected to be a scalar.
std::fill_n(dst_buffer.data(), dst_buffer.size(), init_buffer[0]);
// Precompute destination strides.
int dst_rank = dst_shape.size();
absl::InlinedVector<int, 8> dst_strides;
size_t dst_stride = 1;
for (int dim_i = dst_rank - 1; dim_i >= 0; --dim_i) {
dst_strides.push_back(dst_stride);
dst_stride *= dst_shape[dim_i];
}
// Call the helper (recursive) function, starting with:
// * source index [0, 0, ..., 0]
// * the innermost dimension (last in the shape)
// * flat_src_i of 0 (corresponds to [0, 0, ..., 0] above)
// * source stride 1
absl::InlinedVector<int, 8> src_indices(src_shape.size(), 0);
ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, {dimension},
absl::MakeSpan(dst_strides),
src_shape.size() - 1,
absl::MakeSpan(src_indices), 0, 1);
return OkStatus();
}
} // namespace impl
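// For illustration: ReduceSum over dimension 1 of a [2, 3] source
// [[1, 2, 3], [4, 5, 6]] with a scalar init value of 0 produces a [2]-shaped
// result [6, 15]; ReduceMax over the same dimension with a suitably small
// init value produces [3, 6].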
template <typename T>
Status ReduceSum::Execute(absl::Span<const T> src_buffer,
absl::Span<const T> init_buffer,
absl::Span<T> dst_buffer, int32_t dimension,
const Shape& src_shape, const Shape& dst_shape) {
return impl::GenericReduce<T, impl::SumKernel>(
src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
}
template <typename T>
Status ReduceMin::Execute(absl::Span<const T> src_buffer,
absl::Span<const T> init_buffer,
absl::Span<T> dst_buffer, int32_t dimension,
const Shape& src_shape, const Shape& dst_shape) {
return impl::GenericReduce<T, impl::MinKernel>(
src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
}
template <typename T>
Status ReduceMax::Execute(absl::Span<const T> src_buffer,
absl::Span<const T> init_buffer,
absl::Span<T> dst_buffer, int32_t dimension,
const Shape& src_shape, const Shape& dst_shape) {
return impl::GenericReduce<T, impl::MaxKernel>(
src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
}
} // namespace kernels
} // namespace hal
} // namespace iree
#endif // THIRD_PARTY_MLIR_EDGE_IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_