blob: 8cedd751c087bc2f10d0f9f691cef90efd6385a1 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Conversion helper tables: dispatch from (source, destination) builtin type
// pairs to typed conversion kernels.
#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
#include <type_traits>

#include "base/status.h"
#include "hal/buffer_view.h"
#include "hal/interpreter/bytecode_dispatch_util.h"
#include "schemas/bytecode/interpreter_bytecode_v0.h"
#include "vm/type.h"
namespace iree {
namespace hal {
template <typename KERNEL, bool src_signed, bool dst_signed, typename... ARGS>
struct ApplyConversionOp {
static Status Apply(const vm::Type& src_type, BufferView* src_local,
const vm::Type& dst_type, BufferView* dst_local,
ARGS... args) {
// Validate ranges so that we cannot go out of bounds on thunk table.
int src_type_index = src_type.type_index();
int dst_type_index = dst_type.type_index();
if (src_type_index < 0 || src_type_index >= kBuiltinTypeCount) {
return InvalidArgumentErrorBuilder(IREE_LOC)
<< "Conversion from invalid source builtin type "
<< src_type_index;
} else if (dst_type_index < 0 || dst_type_index >= kBuiltinTypeCount) {
return InvalidArgumentErrorBuilder(IREE_LOC)
<< "Conversion to invalid dest builtin type " << dst_type_index;
}
// All possible combinations of conversions.
using KernelFn = Status (*)(BufferView * src_local, BufferView * dst_local,
ARGS... args);
KernelFn fn = nullptr;
if (src_signed && dst_signed) {
// Signed -> signed.
static const KernelFn
kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
// src_type = kI8:
/* kI8 */ Thunk<int8_t, int8_t>::Apply,
/* kI16 */ Thunk<int8_t, int16_t>::Apply,
/* kI32 */ Thunk<int8_t, int32_t>::Apply,
/* kI64 */ Thunk<int8_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<int8_t, float>::Apply,
/* kF64 */ Thunk<int8_t, double>::Apply,
// src_type = kI16:
/* kI8 */ Thunk<int16_t, int8_t>::Apply,
/* kI16 */ Thunk<int16_t, int16_t>::Apply,
/* kI32 */ Thunk<int16_t, int32_t>::Apply,
/* kI64 */ Thunk<int16_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<int16_t, float>::Apply,
/* kF64 */ Thunk<int16_t, double>::Apply,
// src_type = kI32:
/* kI8 */ Thunk<int32_t, int8_t>::Apply,
/* kI16 */ Thunk<int32_t, int16_t>::Apply,
/* kI32 */ Thunk<int32_t, int32_t>::Apply,
/* kI64 */ Thunk<int32_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<int32_t, float>::Apply,
/* kF64 */ Thunk<int32_t, double>::Apply,
// src_type = kI64:
/* kI8 */ Thunk<int64_t, int8_t>::Apply,
/* kI16 */ Thunk<int64_t, int16_t>::Apply,
/* kI32 */ Thunk<int64_t, int32_t>::Apply,
/* kI64 */ Thunk<int64_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<int64_t, float>::Apply,
/* kF64 */ Thunk<int64_t, double>::Apply,
// src_type = kF16:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ Thunk<uint16_t, uint16_t>::Apply,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF32:
/* kI8 */ Thunk<float, int8_t>::Apply,
/* kI16 */ Thunk<float, int16_t>::Apply,
/* kI32 */ Thunk<float, int32_t>::Apply,
/* kI64 */ Thunk<float, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<float, float>::Apply,
/* kF64 */ Thunk<float, double>::Apply,
// src_type = kF64:
/* kI8 */ Thunk<double, int8_t>::Apply,
/* kI16 */ Thunk<double, int16_t>::Apply,
/* kI32 */ Thunk<double, int32_t>::Apply,
/* kI64 */ Thunk<double, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<double, float>::Apply,
/* kF64 */ Thunk<double, double>::Apply,
};
fn =
kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
} else if (src_signed && !dst_signed) {
// Signed -> unsigned.
static const KernelFn
kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
// src_type = kI8:
/* kI8 */ Thunk<int8_t, uint8_t>::Apply,
/* kI16 */ Thunk<int8_t, uint16_t>::Apply,
/* kI32 */ Thunk<int8_t, uint32_t>::Apply,
/* kI64 */ Thunk<int8_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kI16:
/* kI8 */ Thunk<int16_t, uint8_t>::Apply,
/* kI16 */ Thunk<int16_t, uint16_t>::Apply,
/* kI32 */ Thunk<int16_t, uint32_t>::Apply,
/* kI64 */ Thunk<int16_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kI32:
/* kI8 */ Thunk<int32_t, uint8_t>::Apply,
/* kI16 */ Thunk<int32_t, uint16_t>::Apply,
/* kI32 */ Thunk<int32_t, uint32_t>::Apply,
/* kI64 */ Thunk<int32_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kI64:
/* kI8 */ Thunk<int64_t, uint8_t>::Apply,
/* kI16 */ Thunk<int64_t, uint16_t>::Apply,
/* kI32 */ Thunk<int64_t, uint32_t>::Apply,
/* kI64 */ Thunk<int64_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF16:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF32:
/* kI8 */ Thunk<float, uint8_t>::Apply,
/* kI16 */ Thunk<float, uint16_t>::Apply,
/* kI32 */ Thunk<float, uint32_t>::Apply,
/* kI64 */ Thunk<float, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF64:
/* kI8 */ Thunk<double, uint8_t>::Apply,
/* kI16 */ Thunk<double, uint16_t>::Apply,
/* kI32 */ Thunk<double, uint32_t>::Apply,
/* kI64 */ Thunk<double, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
};
fn =
kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
} else if (!src_signed && dst_signed) {
// Unsigned -> signed.
static const KernelFn
kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
// src_type = kI8:
/* kI8 */ Thunk<uint8_t, int8_t>::Apply,
/* kI16 */ Thunk<uint8_t, int16_t>::Apply,
/* kI32 */ Thunk<uint8_t, int32_t>::Apply,
/* kI64 */ Thunk<uint8_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<uint8_t, float>::Apply,
/* kF64 */ Thunk<uint8_t, double>::Apply,
// src_type = kI16:
/* kI8 */ Thunk<uint16_t, int8_t>::Apply,
/* kI16 */ Thunk<uint16_t, int16_t>::Apply,
/* kI32 */ Thunk<uint16_t, int32_t>::Apply,
/* kI64 */ Thunk<uint16_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<uint16_t, float>::Apply,
/* kF64 */ Thunk<uint16_t, double>::Apply,
// src_type = kI32:
/* kI8 */ Thunk<uint32_t, int8_t>::Apply,
/* kI16 */ Thunk<uint32_t, int16_t>::Apply,
/* kI32 */ Thunk<uint32_t, int32_t>::Apply,
/* kI64 */ Thunk<uint32_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<uint32_t, float>::Apply,
/* kF64 */ Thunk<uint32_t, double>::Apply,
// src_type = kI64:
/* kI8 */ Thunk<uint64_t, int8_t>::Apply,
/* kI16 */ Thunk<uint64_t, int16_t>::Apply,
/* kI32 */ Thunk<uint64_t, int32_t>::Apply,
/* kI64 */ Thunk<uint64_t, int64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ Thunk<uint64_t, float>::Apply,
/* kF64 */ Thunk<uint64_t, double>::Apply,
// src_type = kF16:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF32:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF64:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
};
fn =
kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
} else if (!src_signed && !dst_signed) {
// Unsigned -> unsigned.
static const KernelFn
kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = {
// src_type = kI8:
/* kI8 */ Thunk<uint8_t, uint8_t>::Apply,
/* kI16 */ Thunk<uint8_t, uint16_t>::Apply,
/* kI32 */ Thunk<uint8_t, uint32_t>::Apply,
/* kI64 */ Thunk<uint8_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kI16:
/* kI8 */ Thunk<uint16_t, uint8_t>::Apply,
/* kI16 */ Thunk<uint16_t, uint16_t>::Apply,
/* kI32 */ Thunk<uint16_t, uint32_t>::Apply,
/* kI64 */ Thunk<uint16_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kI32:
/* kI8 */ Thunk<uint32_t, uint8_t>::Apply,
/* kI16 */ Thunk<uint32_t, uint16_t>::Apply,
/* kI32 */ Thunk<uint32_t, uint32_t>::Apply,
/* kI64 */ Thunk<uint32_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kI64:
/* kI8 */ Thunk<uint64_t, uint8_t>::Apply,
/* kI16 */ Thunk<uint64_t, uint16_t>::Apply,
/* kI32 */ Thunk<uint64_t, uint32_t>::Apply,
/* kI64 */ Thunk<uint64_t, uint64_t>::Apply,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF16:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF32:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
// src_type = kF64:
/* kI8 */ nullptr,
/* kI16 */ nullptr,
/* kI32 */ nullptr,
/* kI64 */ nullptr,
/* kF16 */ nullptr,
/* kF32 */ nullptr,
/* kF64 */ nullptr,
};
fn =
kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index];
}
if (!fn) {
return InvalidArgumentErrorBuilder(IREE_LOC)
<< "Unsupported conversion from " << src_type_index << " to "
<< dst_type_index;
}
return fn(src_local, dst_local, args...);
}
template <typename SRC, typename DST>
struct Thunk {
static Status Apply(BufferView* src_local, BufferView* dst_local,
ARGS... args) {
ASSIGN_OR_RETURN(auto src_buffer,
src_local->buffer->MapMemory<SRC>(MemoryAccess::kRead));
ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<DST>(
MemoryAccess::kDiscardWrite));
return KERNEL::Execute(src_buffer.contents(),
dst_buffer.mutable_contents(), args...);
}
};
// Disable F32/F64 conversions if they are not supported.
#if !defined(IREE_SUPPORT_F32)
template <typename DST>
struct Thunk<float, DST> {
static Status Apply(BufferView* src_local, BufferView* dst_local,
ARGS... args) {
return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported";
}
};
template <typename SRC>
struct Thunk<SRC, float> {
static Status Apply(BufferView* src_local, BufferView* dst_local,
ARGS... args) {
return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported";
}
};
#endif // !IREE_SUPPORT_F32
#if !defined(IREE_SUPPORT_F64)
template <typename DST>
struct Thunk<double, DST> {
static Status Apply(BufferView* src_local, BufferView* dst_local,
ARGS... args) {
return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported";
}
};
template <typename SRC>
struct Thunk<SRC, double> {
static Status Apply(BufferView* src_local, BufferView* dst_local,
ARGS... args) {
return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported";
}
};
#endif // !IREE_SUPPORT_F64
};
// kernels::Convert dispatchers, one per signedness combination. Naming is
// ApplyConvert<src><dst>, where S = signed integers and U = unsigned integers.
using ApplyConvertSS = ApplyConversionOp<kernels::Convert,
                                         /*src_signed=*/true,
                                         /*dst_signed=*/true>;
using ApplyConvertSU = ApplyConversionOp<kernels::Convert,
                                         /*src_signed=*/true,
                                         /*dst_signed=*/false>;
using ApplyConvertUS = ApplyConversionOp<kernels::Convert,
                                         /*src_signed=*/false,
                                         /*dst_signed=*/true>;
using ApplyConvertUU = ApplyConversionOp<kernels::Convert,
                                         /*src_signed=*/false,
                                         /*dst_signed=*/false>;
} // namespace hal
} // namespace iree
#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_