blob: 76b61aaff4388df53cadd232043ec52ede1e96bc [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/Input/InputOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::Input;
#include "iree-dialects/Dialect/Input/InputDialect.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "iree-dialects/Dialect/Input/InputAttrs.cpp.inc"
#include "iree-dialects/Dialect/Input/InputEnums.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "iree-dialects/Dialect/Input/InputTypes.cpp.inc"
// Registers every attribute, type, and operation of the iree_input dialect
// with the dialect instance. Each template argument list below is produced by
// a TableGen-generated .cpp.inc file; the GET_*_LIST macro defined right
// before each #include selects the section of that file that expands to the
// list of generated classes.
void IREEInputDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "iree-dialects/Dialect/Input/InputAttrs.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "iree-dialects/Dialect/Input/InputTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "iree-dialects/Dialect/Input/InputOps.cpp.inc"
>();
}
namespace mlir::iree_compiler::IREE::Input {
//===----------------------------------------------------------------------===//
// IREE ABI helpers for constructing buffer views
//===----------------------------------------------------------------------===//
// Keep these in sync with iree/hal/api.h
namespace {
// Numerical type tags packed into the upper byte (bits 24..31) of an element
// type value by makeElementTypeValue. The high nibble selects the family
// (0x1_ integer, 0x2_ float) and the low nibble selects the variant within it.
// The values must match the runtime definitions exactly (see the sync note
// above); do not renumber.
enum class NumericalType : uint32_t {
kUnknown = 0x00,
// Integer with no signedness information; used for MLIR signless integers.
kInteger = 0x10,
kIntegerSigned = kInteger | 0x01,
kIntegerUnsigned = kInteger | 0x02,
// Used for i1 values (reported with an 8-bit width by getElementTypeValue).
kBoolean = kInteger | 0x03,
kFloat = 0x20,
// IEEE-754 half/single/double/quad.
kFloatIEEE = kFloat | 0x01,
// bfloat16.
kFloatBrain = kFloat | 0x02,
// Complex number; the bit width covers both components.
kFloatComplex = kFloat | 0x03,
};
} // namespace
/// Packs a numerical type tag and a bit width into one element type value.
/// The tag occupies bits [24, 32); the bit count is ORed into the low bits
/// without masking, so callers must pass widths below 2^24 (not checked).
static constexpr int32_t makeElementTypeValue(NumericalType numericalType,
                                              int32_t bitCount) {
  const uint32_t tagBits = static_cast<uint32_t>(numericalType) << 24;
  const uint32_t widthBits = static_cast<uint32_t>(bitCount);
  return static_cast<int32_t>(tagBits | widthBits);
}
/// Maps an MLIR element type to the packed (numerical type, bit width) element
/// type value used by the IREE HAL ABI.
///
/// Returns std::nullopt for null types and for types with no HAL encoding:
/// floats other than IEEE half/single/double/quad and bfloat16, and complex
/// types whose element type is not a float.
std::optional<int32_t> getElementTypeValue(Type type) {
  if (auto intType = llvm::dyn_cast_if_present<IntegerType>(type)) {
    // i1 (of any signedness) is reported as an 8-bit boolean.
    if (intType.isInteger(1))
      return makeElementTypeValue(NumericalType::kBoolean, 8);

    // There's no such thing as a signless integer in machine types but we
    // need to be able to round-trip the format through the ABI. Exact
    // numerical type equality comparisons may fail if the frontend assumes
    // signed/unsigned but the compiler is propagating signless.
    NumericalType numericalType = NumericalType::kInteger;
    if (intType.isSigned())
      numericalType = NumericalType::kIntegerSigned;
    else if (intType.isUnsigned())
      numericalType = NumericalType::kIntegerUnsigned;
    return makeElementTypeValue(numericalType, intType.getWidth());
  }

  if (auto floatType = llvm::dyn_cast_if_present<FloatType>(type)) {
    switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) {
    case APFloat::S_IEEEhalf:
    case APFloat::S_IEEEsingle:
    case APFloat::S_IEEEdouble:
    case APFloat::S_IEEEquad:
      return makeElementTypeValue(NumericalType::kFloatIEEE,
                                  floatType.getWidth());
    case APFloat::S_BFloat:
      return makeElementTypeValue(NumericalType::kFloatBrain,
                                  floatType.getWidth());
    default:
      return std::nullopt;
    }
  }

  if (auto complexType = llvm::dyn_cast_if_present<ComplexType>(type)) {
    // Only float complex has a numerical type tag; complex<iN> would otherwise
    // be mislabeled as a float complex of the same width.
    auto elementType = llvm::dyn_cast<FloatType>(complexType.getElementType());
    if (!elementType)
      return std::nullopt;
    // The width covers both the real and imaginary components.
    return makeElementTypeValue(NumericalType::kFloatComplex,
                                elementType.getWidth() * 2);
  }

  return std::nullopt;
}
/// Maps a tensor encoding attribute to a HAL encoding type enum value.
///
/// Only the default (absent) encoding is supported: a null `attr` yields 1
/// (IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR). Any explicit encoding returns
/// std::nullopt instead of being silently mapped to the default. The previous
/// assert-only guard compiled out in release builds and allowed that silent
/// mapping.
std::optional<int32_t> getEncodingTypeValue(Attribute attr) {
  // TODO(#6762): encoding attribute handling/mapping to enums.
  if (attr)
    return std::nullopt;
  // IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR.
  return 1;
}
//===----------------------------------------------------------------------===//
// ListType
//===----------------------------------------------------------------------===//
/// Parses the body of `!iree_input.list<element-type>`.
/// Returns a null Type if any piece of the `<` type `>` syntax is missing; the
/// parser has already reported the error in that case.
Type ListType::parse(AsmParser &parser) {
  Type elementType;
  if (failed(parser.parseLess()))
    return Type();
  if (failed(parser.parseType(elementType)))
    return Type();
  if (failed(parser.parseGreater()))
    return Type();
  return get(parser.getContext(), elementType);
}
/// Prints the body of the list type as `<element-type>`, mirroring parse().
void ListType::print(AsmPrinter &printer) const {
  raw_ostream &os = printer.getStream();
  os << '<';
  printer.printType(getElementType());
  os << '>';
}
//===----------------------------------------------------------------------===//
// PtrType
//===----------------------------------------------------------------------===//
/// Parses the body of `!iree_input.ptr<target-type>`.
/// Returns a null Type if any piece of the `<` type `>` syntax is missing; the
/// parser has already reported the error in that case.
Type PtrType::parse(AsmParser &parser) {
  Type targetType;
  if (failed(parser.parseLess()))
    return Type();
  if (failed(parser.parseType(targetType)))
    return Type();
  if (failed(parser.parseGreater()))
    return Type();
  return get(parser.getContext(), targetType);
}
/// Prints the body of the pointer type as `<target-type>`, mirroring parse().
void PtrType::print(AsmPrinter &printer) const {
  raw_ostream &os = printer.getStream();
  os << '<';
  printer.printType(getTargetType());
  os << '>';
}
//===----------------------------------------------------------------------===//
// #iree_input.device.target
//===----------------------------------------------------------------------===//
/// Convenience builder: wraps `deviceID` in a StringAttr and attaches an
/// empty configuration dictionary.
DeviceTargetAttr DeviceTargetAttr::get(MLIRContext *context,
                                       StringRef deviceID) {
  StringAttr deviceIDAttr = StringAttr::get(context, deviceID);
  DictionaryAttr emptyConfigAttr = DictionaryAttr::get(context);
  return get(context, deviceIDAttr, emptyConfigAttr);
}
/// Parses `<"device-id" (, {config})?>`. When the configuration dictionary is
/// omitted, a null DictionaryAttr is passed to the builder.
Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) {
  StringAttr deviceIDAttr;
  if (failed(p.parseLess()))
    return {};
  if (failed(p.parseAttribute(deviceIDAttr)))
    return {};

  DictionaryAttr configAttr;
  const bool hasConfig = succeeded(p.parseOptionalComma());
  if (hasConfig && failed(p.parseAttribute(configAttr)))
    return {};

  if (failed(p.parseGreater()))
    return {};
  return get(p.getContext(), deviceIDAttr, configAttr);
}
/// Prints `<"device-id">`, appending `, {config}` only when the configuration
/// is present and non-empty.
void DeviceTargetAttr::print(AsmPrinter &p) const {
  auto config = getConfiguration();
  const bool printConfig = config && !config.empty();

  auto &stream = p.getStream();
  stream << '<';
  p.printAttribute(getDeviceID());
  if (printConfig) {
    stream << ", ";
    p.printAttribute(config);
  }
  stream << '>';
}
//===----------------------------------------------------------------------===//
// #iree_input.executable.target
//===----------------------------------------------------------------------===//
/// Convenience builder: wraps `backend` and `format` in StringAttrs and
/// attaches an empty configuration dictionary.
ExecutableTargetAttr ExecutableTargetAttr::get(MLIRContext *context,
                                               StringRef backend,
                                               StringRef format) {
  StringAttr backendAttr = StringAttr::get(context, backend);
  StringAttr formatAttr = StringAttr::get(context, format);
  DictionaryAttr emptyConfigAttr = DictionaryAttr::get(context);
  return get(context, backendAttr, formatAttr, emptyConfigAttr);
}
/// Parses `<"backend", "format" (, {config})?>`. When the configuration
/// dictionary is omitted, a null DictionaryAttr is passed to the builder.
Attribute ExecutableTargetAttr::parse(AsmParser &p, Type type) {
  StringAttr backendAttr;
  StringAttr formatAttr;
  if (failed(p.parseLess()))
    return {};
  if (failed(p.parseAttribute(backendAttr)))
    return {};
  if (failed(p.parseComma()))
    return {};
  if (failed(p.parseAttribute(formatAttr)))
    return {};

  DictionaryAttr configurationAttr;
  const bool hasConfiguration = succeeded(p.parseOptionalComma());
  if (hasConfiguration && failed(p.parseAttribute(configurationAttr)))
    return {};

  if (failed(p.parseGreater()))
    return {};
  return get(p.getContext(), backendAttr, formatAttr, configurationAttr);
}
/// Prints `<"backend", "format">`, appending `, {config}` only when the
/// configuration is present and non-empty.
void ExecutableTargetAttr::print(AsmPrinter &p) const {
  auto config = getConfiguration();
  const bool printConfig = config && !config.empty();

  auto &stream = p.getStream();
  stream << '<';
  p.printAttribute(getBackend());
  stream << ", ";
  p.printAttribute(getFormat());
  if (printConfig) {
    stream << ", ";
    p.printAttribute(config);
  }
  stream << '>';
}
//===----------------------------------------------------------------------===//
// #iree_input.executable.object
//===----------------------------------------------------------------------===//
/// Parses `<{path = "...", data = dense<...>}>`; both entries are optional.
/// An entry that is missing or not of the expected attribute kind is passed
/// to the builder as null.
Attribute ExecutableObjectAttr::parse(AsmParser &p, Type type) {
  NamedAttrList fields;
  if (failed(p.parseLess()))
    return {};
  if (failed(p.parseOptionalAttrDict(fields)))
    return {};
  if (failed(p.parseGreater()))
    return {};

  Attribute rawPath = fields.get("path");
  Attribute rawData = fields.get("data");
  return get(p.getContext(), llvm::dyn_cast_if_present<StringAttr>(rawPath),
             llvm::dyn_cast_if_present<DenseIntElementsAttr>(rawData));
}
/// Prints `<{path = "...", data = dense<...>}>`, omitting absent entries.
/// The separator is only emitted between two present entries so the output
/// stays parseable by parse() (e.g. a data-only object prints as
/// `<{data = ...}>`, not `<{, data = ...}>`).
void ExecutableObjectAttr::print(AsmPrinter &p) const {
  auto &os = p.getStream();
  os << "<{";
  bool needsSeparator = false;
  if (auto pathAttr = getPath()) {
    os << "path = ";
    p.printAttribute(pathAttr);
    needsSeparator = true;
  }
  if (auto dataAttr = getData()) {
    if (needsSeparator)
      os << ", ";
    os << "data = ";
    p.printAttribute(dataAttr);
  }
  os << "}>";
}
//===----------------------------------------------------------------------===//
// #iree_input.executable.objects
//===----------------------------------------------------------------------===//
/// Verifies that `targetsAttr` and `targetObjectsAttr` are parallel arrays:
/// every target is an #iree_input.executable.target, and every entry of
/// `targetObjectsAttr` is an array whose elements are all
/// #iree_input.executable.object attributes.
LogicalResult ExecutableObjectsAttr::verify(
    function_ref<mlir::InFlightDiagnostic()> emitError, ArrayAttr targetsAttr,
    ArrayAttr targetObjectsAttr) {
  if (targetsAttr.size() != targetObjectsAttr.size()) {
    return emitError() << "targets and objects must be 1:1";
  }
  for (auto targetAttr : targetsAttr) {
    if (!llvm::isa<IREE::Input::ExecutableTargetAttr>(targetAttr)) {
      return emitError()
             << "target keys must be #iree_input.executable.target attributes";
    }
  }
  for (auto objectsAttr : targetObjectsAttr) {
    auto objectsArrayAttr = llvm::dyn_cast<ArrayAttr>(objectsAttr);
    // Check the element kinds too, not just the container, so the diagnostic
    // below actually describes what is enforced.
    bool allObjects =
        objectsArrayAttr &&
        llvm::all_of(objectsArrayAttr.getValue(), [](Attribute objectAttr) {
          return llvm::isa_and_present<IREE::Input::ExecutableObjectAttr>(
              objectAttr);
        });
    if (!allObjects) {
      return emitError() << "target objects must be an array of "
                            "#iree_input.executable.object attributes";
    }
  }
  return success();
}
/// Parses `<{ target = [objects, ...], ... }>`; the braces are mandatory and
/// may be empty (`<{}>`). Returns a null attribute on any syntax error. The
/// parser has already emitted a diagnostic in that case, so a missing `{` is
/// no longer treated as an empty body.
Attribute ExecutableObjectsAttr::parse(AsmParser &p, Type type) {
  SmallVector<Attribute> targetAttrs;
  SmallVector<Attribute> objectsAttrs;
  if (failed(p.parseLess()) || failed(p.parseLBrace()))
    return {};
  if (failed(p.parseOptionalRBrace())) {
    do {
      Attribute targetAttr;
      ArrayAttr objectsAttr;
      if (failed(p.parseAttribute(targetAttr)) || failed(p.parseEqual()) ||
          failed(p.parseAttribute(objectsAttr))) {
        return {};
      }
      targetAttrs.push_back(targetAttr);
      objectsAttrs.push_back(objectsAttr);
    } while (succeeded(p.parseOptionalComma()));
    if (failed(p.parseRBrace()))
      return {};
  }
  if (failed(p.parseGreater()))
    return {};
  return get(p.getContext(), ArrayAttr::get(p.getContext(), targetAttrs),
             ArrayAttr::get(p.getContext(), objectsAttrs));
}
/// Prints `<{target = [objects, ...], ...}>`, pairing each target with its
/// object array positionally (zip_equal asserts the arrays have equal length).
void ExecutableObjectsAttr::print(AsmPrinter &p) const {
  auto &stream = p.getStream();
  stream << "<{";
  bool needsSeparator = false;
  for (auto [targetAttr, objectsAttr] :
       llvm::zip_equal(getTargets(), getTargetObjects())) {
    if (needsSeparator)
      stream << ", ";
    needsSeparator = true;
    p.printAttribute(targetAttr);
    stream << " = ";
    p.printAttributeWithoutType(objectsAttr);
  }
  stream << "}>";
}
} // namespace mlir::iree_compiler::IREE::Input