blob: 164187ee514133dbdd4c70e2934f19f986058a84 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
using namespace PYDM;
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.cpp.inc"
#include "iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.cpp.inc"
#include "iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
//------------------------------------------------------------------------------
// Dialect implementation
//------------------------------------------------------------------------------
using BuiltinIntegerType = mlir::IntegerType;
using PyBoolType = PYDM::BoolType;
using PyConstantOp = PYDM::ConstantOp;
using PyIntegerType = PYDM::IntegerType;
using PyListType = PYDM::ListType;
using PyRealType = PYDM::RealType;
using PyObjectType = PYDM::ObjectType;
using PyUnionType = PYDM::UnionType;
// Registers all TableGen-generated types and operations with the dialect.
// The GET_TYPEDEF_LIST / GET_OP_LIST defines select the portion of the
// generated .cpp.inc files that supplies the template argument lists for
// addTypes<> and addOperations<> below.
void IREEPyDMDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
>();
}
/// Materializes a constant op for the folded attribute `value` with result
/// type `type` at `loc`.
///
/// Builtin integer types are accepted alongside the Python value types so that
/// folders producing builtin-typed results can also be materialized.
///
/// Note: `NoneType` must be spelled `PYDM::NoneType`. This file has
/// using-directives for both `mlir` (which declares the builtin `NoneType`)
/// and `PYDM`, so the unqualified name is ambiguous.
///
/// Returns the created operation, or nullptr (after asserting in debug builds)
/// when the type/attribute combination is not handled.
Operation *IREEPyDMDialect::materializeConstant(OpBuilder &builder,
                                                Attribute value, Type type,
                                                Location loc) {
  // Scalar value types (and builtin integers) are materialized as ConstantOp
  // carrying the attribute directly.
  if (type.isa<PyBoolType, BytesType, PyIntegerType, PyRealType, StrType,
               BuiltinIntegerType>()) {
    return builder.create<PYDM::ConstantOp>(loc, type, value);
  }
  // `None` has a dedicated op; the attribute value is not consulted.
  if (type.isa<PYDM::NoneType>()) {
    return builder.create<PYDM::NoneOp>(loc, type);
  }
  // A successful exception result is represented by a unit attribute.
  if (type.isa<ExceptionResultType>() && value.isa<UnitAttr>()) {
    return builder.create<PYDM::SuccessOp>(loc, type);
  }
  assert(false && "unhandled iree_pydm constant materialization");
  return nullptr;
}
//------------------------------------------------------------------------------
// Python type implementation
//------------------------------------------------------------------------------
// BoolType
/// The builtin type code of `bool` is assembled from its numeric category and
/// numeric subtype code.
BuiltinTypeCode PYDM::BoolType::getTypeCode() const {
  const NumericCategory category = *getNumericCategory();
  const int subTypeCode = *getNumericSubTypeCode();
  return static_cast<BuiltinTypeCode>(
      makeNumericTypeCode(category, subTypeCode));
}

/// Python-visible name of the type.
StringRef PYDM::BoolType::getPythonTypeName() const {
  return "bool";
}

/// `bool` forms its own numeric category.
Optional<NumericCategory> PYDM::BoolType::getNumericCategory() const {
  return NumericCategory::Bool;
}

/// There is only one `bool` representation, so the subtype code is always 0.
Optional<int> PYDM::BoolType::getNumericSubTypeCode() const {
  return 0;
}

/// The promotion order is the numeric type code itself.
Optional<int> PYDM::BoolType::getNumericPromotionOrder() const {
  const int order = static_cast<int>(getTypeCode());
  return order;
}
// BytesType
/// `bytes` maps to a fixed builtin type code.
BuiltinTypeCode PYDM::BytesType::getTypeCode() const {
  return BuiltinTypeCode::Bytes;
}

/// Python-visible name of the type.
StringRef PYDM::BytesType::getPythonTypeName() const {
  return "bytes";
}

// ExceptionResultType

/// Exception results map to a fixed builtin type code.
BuiltinTypeCode PYDM::ExceptionResultType::getTypeCode() const {
  return BuiltinTypeCode::ExceptionResult;
}

/// Python-visible name of the type.
StringRef PYDM::ExceptionResultType::getPythonTypeName() const {
  return "Exception";
}
// IntegerType
/// Verifies the (signed-encoded) bit width of an integer type.
///
/// A missing width (weak integer) is always valid. Otherwise the magnitude of
/// the width must be 0 (arbitrary precision) or one of 8/16/32/64. The sign
/// only encodes signedness.
LogicalResult
PYDM::IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                          Optional<int> bitWidth) {
  if (bitWidth) {
    const int magnitude = abs(*bitWidth);
    switch (magnitude) {
    case 0:
    case 8:
    case 16:
    case 32:
    case 64:
      break;
    default:
      return emitError() << "unsupported python integer bit width: "
                         << magnitude;
    }
  }
  return success();
}
/// Parses the optional parameter list of an integer type. This is the inverse
/// of PyIntegerType::print:
///   (nothing)       -> weak integer (no bit width)
///   `<*>`           -> arbitrary-precision signed integer (bit width 0)
///   `<N>`           -> signed integer of N bits
///   `<unsigned N>`  -> unsigned integer of N bits (stored as -N)
/// Returns a null Type on parse or verification failure.
Type PyIntegerType::parse(mlir::AsmParser &parser) {
  MLIRContext *ctxt = parser.getContext();

  // Weak: no parameter list at all.
  if (failed(parser.parseOptionalLess()))
    return get(ctxt);

  // Anchor verification diagnostics at the start of the parameter list rather
  // than wherever the parser is after consuming `>`.
  auto paramsLoc = parser.getCurrentLocation();
  auto emitError = [&]() -> InFlightDiagnostic {
    return parser.emitError(paramsLoc);
  };

  // Arbitrary precision is encoded as an explicit bit width of 0. That is the
  // width getNumericCategory() maps to APSigned and print() renders as `*`.
  // An absent width (None) would instead denote the weak integer.
  if (succeeded(parser.parseOptionalStar())) {
    if (failed(parser.parseGreater()))
      return Type();
    return get(ctxt, Optional<int>(0));
  }

  // Explicit width, optionally prefixed by `unsigned`. Unsigned widths are
  // stored negated.
  bool isSigned = failed(parser.parseOptionalKeyword("unsigned"));
  int width;
  if (failed(parser.parseInteger(width)))
    return Type();
  if (failed(parser.parseGreater()))
    return Type();
  if (!isSigned)
    width = -width;
  return getChecked(emitError, ctxt, width);
}
/// Prints the parameter list understood by PyIntegerType::parse:
/// - weak integers (no bit width) print nothing;
/// - width 0 prints `<*>`;
/// - positive widths print `<N>`;
/// - negative widths print `<unsigned N>`.
void PyIntegerType::print(mlir::AsmPrinter &printer) const {
  auto bitWidth = getImpl()->bitWidth;
  if (!bitWidth)
    return;

  auto value = *bitWidth;
  printer << "<";
  if (value == 0)
    printer << "*";
  else if (value < 0)
    printer << "unsigned " << (-value);
  else
    printer << value;
  printer << ">";
}
/// The builtin type code of an integer combines its numeric category with its
/// width-dependent subtype code.
BuiltinTypeCode PYDM::IntegerType::getTypeCode() const {
  const NumericCategory category = *getNumericCategory();
  const int subTypeCode = *getNumericSubTypeCode();
  return static_cast<BuiltinTypeCode>(
      makeNumericTypeCode(category, subTypeCode));
}

/// Python-visible name of the type.
StringRef PYDM::IntegerType::getPythonTypeName() const {
  return "int";
}

/// Classifies the integer:
/// - weak when no width is present;
/// - APSigned for width 0;
/// - otherwise Signed or Unsigned according to the sign of the stored width.
Optional<NumericCategory> PYDM::IntegerType::getNumericCategory() const {
  if (isWeak())
    return NumericCategory::WeakInteger;
  if (getBitWidth() == 0)
    return NumericCategory::APSigned;
  return isSigned() ? NumericCategory::Signed : NumericCategory::Unsigned;
}
/// Returns the subtype code used to build this integer's numeric type code.
///
/// Weak integers and arbitrary-precision integers (bit width 0) have no
/// fixed-width representation and report subtype 0. Width 0 is accepted by
/// verify() and reaches here via getTypeCode(), so it must be handled
/// explicitly rather than falling into the unsupported-width assertion.
Optional<int> PYDM::IntegerType::getNumericSubTypeCode() const {
  if (isWeak())
    return 0;

  unsigned bitWidth = getBitWidth();
  // NOTE(review): subtype 0 for the APSigned category mirrors the weak case;
  // confirm this is the intended encoding in makeNumericTypeCode.
  if (bitWidth == 0)
    return 0;

  IntegerSubTypeCode stc;
  switch (bitWidth) {
  case 8:
    stc = IntegerSubTypeCode::Integer8;
    break;
  case 16:
    stc = IntegerSubTypeCode::Integer16;
    break;
  case 32:
    stc = IntegerSubTypeCode::Integer32;
    break;
  case 64:
    stc = IntegerSubTypeCode::Integer64;
    break;
  default:
    // verify() only admits widths 0/8/16/32/64, so verified types never get
    // here.
    stc = IntegerSubTypeCode::Integer8; // Arbitrarily picked value.
    assert(false && "unsupported numeric bitwidth");
    break;
  }
  return static_cast<int>(stc);
}
/// The promotion order is the numeric type code itself.
Optional<int> PYDM::IntegerType::getNumericPromotionOrder() const {
  const int order = static_cast<int>(getTypeCode());
  return order;
}

/// An integer is weak when it carries no bit width.
bool PYDM::IntegerType::isWeak() const {
  const auto &bitWidth = getImpl()->bitWidth;
  return !bitWidth;
}

/// Magnitude of the stored width; the sign only encodes signedness.
/// Precondition: the type is not weak.
unsigned PYDM::IntegerType::getBitWidth() const {
  auto signedWidth = *getImpl()->bitWidth;
  return abs(signedWidth);
}

/// Non-negative stored widths are signed; negative ones are unsigned.
/// Precondition: the type is not weak.
bool PYDM::IntegerType::isSigned() const {
  auto signedWidth = *getImpl()->bitWidth;
  return signedWidth >= 0;
}
/// `list` maps to a fixed builtin type code regardless of storage class.
BuiltinTypeCode PYDM::ListType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::List;
  return code;
}
// ListType
/// Prints `<storage-class[,element-type]>`. The default list (boxed, with no
/// uniform element type) prints nothing.
void PyListType::print(mlir::AsmPrinter &printer) const {
  auto elementType = getImpl()->uniformElementType;
  auto storageClass = getImpl()->storageClass;
  bool isDefaultList =
      !elementType && storageClass == CollectionStorageClass::Boxed;
  if (isDefaultList)
    return;

  printer << "<";
  switch (storageClass) {
  case CollectionStorageClass::Boxed:
    printer << "boxed";
    break;
  case CollectionStorageClass::Empty:
    printer << "empty";
    break;
  case CollectionStorageClass::Unboxed:
    printer << "unboxed";
    break;
  }
  if (elementType)
    printer << "," << elementType;
  printer << ">";
}
/// Parses the optional parameter list of a list type:
///   (nothing)                        -> boxed list, no uniform element type
///   `<storage-class>`                -> given storage class, no element type
///   `<storage-class, element-type>`  -> given storage class and element type
/// where storage-class is one of `boxed`, `empty`, `unboxed`.
///
/// The element type is optional so that every form PyListType::print emits
/// (e.g. `<empty>` for an empty list without an element type) parses back.
/// Returns a null Type on failure.
Type PyListType::parse(mlir::AsmParser &parser) {
  MLIRContext *ctxt = parser.getContext();
  if (parser.parseOptionalLess())
    return get(ctxt, CollectionStorageClass::Boxed, nullptr);

  // Remember where the keyword starts so a bad keyword is reported there.
  auto keywordLoc = parser.getCurrentLocation();
  StringRef storageClassKeyword;
  if (parser.parseKeyword(&storageClassKeyword))
    return Type();

  CollectionStorageClass storageClass;
  if (storageClassKeyword == "boxed") {
    storageClass = CollectionStorageClass::Boxed;
  } else if (storageClassKeyword == "empty") {
    storageClass = CollectionStorageClass::Empty;
  } else if (storageClassKeyword == "unboxed") {
    storageClass = CollectionStorageClass::Unboxed;
  } else {
    parser.emitError(keywordLoc,
                     "expected one of 'boxed', 'empty', 'unboxed'");
    return Type();
  }

  // Optional uniform element type.
  Type t;
  if (succeeded(parser.parseOptionalComma())) {
    if (parser.parseType(t))
      return Type();
  }
  if (parser.parseGreater())
    return Type();

  return get(ctxt, storageClass, t);
}
/// Python-visible name of the type.
StringRef PYDM::ListType::getPythonTypeName() const {
  return "list";
}
/// `None` has its own builtin type code. It must not share List's code:
/// type codes are compared, for example, when verifying the order of union
/// alternatives.
BuiltinTypeCode PYDM::NoneType::getTypeCode() const {
  return BuiltinTypeCode::None;
}
/// Refinability of a list type:
/// - never refinable when its storage class is Empty;
/// - fully refinable when it has no uniform element type;
/// - otherwise delegated to the element type. Non-Python element types are
///   not refinable.
bool PYDM::ListType::isRefinable() const {
  if (getStorageClass() == CollectionStorageClass::Empty)
    return false;

  auto elementType = getUniformElementType();
  if (!elementType)
    return true;

  auto pyType = elementType.dyn_cast<PythonTypeInterface>();
  return pyType && pyType.isRefinable();
}
/// Type used to store each element:
/// - Unboxed lists store their uniform element type directly;
/// - Boxed and Empty lists store generic `object` values.
Type PYDM::ListType::getElementStorageType() const {
  auto storageClass = getStorageClass();
  if (storageClass == CollectionStorageClass::Unboxed) {
    assert(getUniformElementType() &&
           "unboxed list should have uniform element type");
    return getUniformElementType();
  }
  if (storageClass == CollectionStorageClass::Boxed ||
      storageClass == CollectionStorageClass::Empty) {
    return ObjectType::get(getContext());
  }
  assert(false && "unsupported storage class");
  return {};
}
// NoneType
/// Python-visible name of the type.
StringRef PYDM::NoneType::getPythonTypeName() const {
  return "None";
}
// ObjectType
/// Prints `<primitive-type>` when the object is known to box a primitive type;
/// a generic object prints nothing.
void PyObjectType::print(mlir::AsmPrinter &printer) const {
  auto primitiveType = getImpl()->primitiveType;
  if (!primitiveType)
    return;
  printer << "<" << primitiveType << ">";
}
/// Parses an optional `<primitive-type>` parameter. Without one, the generic
/// object type is returned. Returns a null Type on failure, including when
/// the parameter is not a PrimitiveType.
Type PyObjectType::parse(mlir::AsmParser &parser) {
  MLIRContext *ctxt = parser.getContext();
  if (parser.parseOptionalLess())
    return get(ctxt, nullptr);

  Type parsed;
  if (parser.parseType(parsed) || parser.parseGreater())
    return Type();

  auto primitiveType = parsed.dyn_cast<PrimitiveType>();
  if (!primitiveType) {
    parser.emitError(parser.getNameLoc(), "expected a primitive type");
    return Type();
  }
  return get(ctxt, primitiveType);
}
BuiltinTypeCode PYDM::ObjectType::getTypeCode() const {
return BuiltinTypeCode::Object;
}
StringRef PYDM::ObjectType::getPythonTypeName() const { return "object"; }
bool PYDM::ObjectType::isRefinable() const {
if (!getPrimitiveType())
return true;
if (auto pyType = getPrimitiveType().dyn_cast<PythonTypeInterface>())
return pyType.isRefinable();
return false;
}
// RealType
/// Prints `<float-type>` for explicit reals; weak reals print nothing.
void PyRealType::print(mlir::AsmPrinter &printer) const {
  auto floatType = getImpl()->floatType;
  if (!floatType)
    return;
  printer << "<" << floatType << ">";
}
/// Parses an optional `<float-type>` parameter. Without one, the weak real
/// type is returned. The explicit form is verified via getChecked().
/// Returns a null Type on failure.
Type PyRealType::parse(mlir::AsmParser &parser) {
  MLIRContext *ctxt = parser.getContext();

  // Weak: no parameter list.
  if (failed(parser.parseOptionalLess()))
    return get(ctxt);

  // Explicit float type.
  FloatType floatType;
  if (failed(parser.parseType(floatType)) || failed(parser.parseGreater()))
    return Type();

  return getChecked(
      [&]() -> InFlightDiagnostic {
        return parser.emitError(parser.getCurrentLocation());
      },
      ctxt, floatType);
}
/// Verifies the float type backing a real. A null float type (weak real) is
/// accepted; otherwise only bf16, f16, f32 and f64 are supported.
LogicalResult
PYDM::RealType::verify(function_ref<InFlightDiagnostic()> emitError,
                       FloatType floatType) {
  bool supported =
      !floatType ||
      floatType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
  if (supported)
    return success();
  return emitError() << "unsupported Python floating point type: "
                     << floatType;
}
/// The builtin type code of a real combines its numeric category with its
/// float-width subtype code.
BuiltinTypeCode PYDM::RealType::getTypeCode() const {
  const NumericCategory category = *getNumericCategory();
  const int subTypeCode = *getNumericSubTypeCode();
  return static_cast<BuiltinTypeCode>(
      makeNumericTypeCode(category, subTypeCode));
}

/// Python-visible name of the type.
StringRef PYDM::RealType::getPythonTypeName() const {
  return "float";
}

/// Weak reals form their own category; all explicit reals share Real.
Optional<NumericCategory> PYDM::RealType::getNumericCategory() const {
  return isWeak() ? NumericCategory::WeakReal : NumericCategory::Real;
}
/// Maps the backing float type to its RealSubTypeCode. Weak reals report 0.
Optional<int> PYDM::RealType::getNumericSubTypeCode() const {
  if (isWeak())
    return 0;

  Type floatType = getFloatType();
  RealSubTypeCode stc;
  if (floatType.isa<BFloat16Type>()) {
    stc = RealSubTypeCode::BF16;
  } else if (floatType.isa<Float16Type>()) {
    stc = RealSubTypeCode::FP16;
  } else if (floatType.isa<Float32Type>()) {
    stc = RealSubTypeCode::FP32;
  } else if (floatType.isa<Float64Type>()) {
    stc = RealSubTypeCode::FP64;
  } else {
    assert(false && "unsupported float type");
    stc = RealSubTypeCode::FP64;
  }
  return static_cast<int>(stc);
}
/// The promotion order is the numeric type code itself.
Optional<int> PYDM::RealType::getNumericPromotionOrder() const {
  const int order = static_cast<int>(getTypeCode());
  return order;
}

/// A real is weak when it has no backing float type.
bool PYDM::RealType::isWeak() const {
  auto floatType = getImpl()->floatType;
  return !floatType;
}
// StrType
/// `str` maps to a fixed builtin type code.
BuiltinTypeCode PYDM::StrType::getTypeCode() const {
  return BuiltinTypeCode::Str;
}

/// Python-visible name of the type.
StringRef PYDM::StrType::getPythonTypeName() const {
  return "str";
}

// TupleType

/// `tuple` maps to a fixed builtin type code.
BuiltinTypeCode PYDM::TupleType::getTypeCode() const {
  return BuiltinTypeCode::Tuple;
}

/// Python-visible name of the type.
StringRef PYDM::TupleType::getPythonTypeName() const {
  return "tuple";
}

// TypeType

/// `type` maps to a fixed builtin type code.
BuiltinTypeCode PYDM::TypeType::getTypeCode() const {
  return BuiltinTypeCode::Type;
}

/// Python-visible name of the type.
StringRef PYDM::TypeType::getPythonTypeName() const {
  return "type";
}
/// Tuple elements are currently always stored as generic `object` values.
Type PYDM::TupleType::getElementStorageType() const {
  // TODO: When it implements unboxed storage, switch here.
  MLIRContext *context = getContext();
  return ObjectType::get(context);
}
//------------------------------------------------------------------------------
// Union type implementation
//------------------------------------------------------------------------------
/// Prints the alternatives as `<t0, t1, ...>`. An empty union prints nothing,
/// which PyUnionType::parse maps back to the empty union.
void PyUnionType::print(mlir::AsmPrinter &printer) const {
  auto alternatives = getAlternatives();
  if (alternatives.empty())
    return;
  printer << "<";
  llvm::interleaveComma(alternatives, printer);
  printer << ">";
}

/// Parses `<t0, t1, ...>`, or nothing for the empty union. The closing `>` is
/// consumed here, so the printed form round-trips. The alternatives are then
/// checked by UnionType::verify via getChecked().
/// Returns a null Type on failure.
Type PyUnionType::parse(mlir::AsmParser &parser) {
  MLIRContext *ctxt = parser.getContext();
  if (parser.parseOptionalLess())
    return get(ctxt, {});

  SmallVector<::mlir::Type> alternatives;
  do {
    Type type;
    if (parser.parseType(type))
      return Type();
    alternatives.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));

  if (parser.parseGreater())
    return Type();

  return getChecked([&]() { return parser.emitError(parser.getNameLoc()); },
                    ctxt, alternatives);
}
/// Verifies a union's alternatives:
/// - every alternative must implement PythonTypeInterface;
/// - alternatives must appear in strictly increasing builtin type-code order
///   (the normative order).
///
/// Returns success() when both hold; otherwise emits a diagnostic and fails.
LogicalResult
PYDM::UnionType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<Type> alternatives) {
  int lastTypeCode = 0;
  for (Type alternative : alternatives) {
    auto pythonType = alternative.dyn_cast<PYDM::PythonTypeInterface>();
    if (!pythonType) {
      return emitError() << "expected a python type in union. got: "
                         << alternative;
    }

    int thisTypeCode = static_cast<int>(pythonType.getTypeCode());
    // TODO: This doesn't account for parameterized types.
    if (thisTypeCode <= lastTypeCode) {
      return emitError() << "expected total order of union to be normative. "
                            "got out of order: "
                         << alternative;
    }
    // Carry the code forward so each alternative is compared against its
    // predecessor. Without this the check could only reject codes <= 0.
    lastTypeCode = thisTypeCode;
  }
  return success();
}