blob: ca77eace83dc96ff9a2424eb66cbef63ccd6b165 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
using namespace PYDM;
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.cpp.inc"
#include "iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.cpp.inc"
#include "iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
//------------------------------------------------------------------------------
// Dialect implementation
//------------------------------------------------------------------------------
// Short, unambiguous names for types/ops whose plain names could collide
// between the `mlir` and `PYDM` namespaces (both are brought into scope by the
// using-directives above), e.g. mlir::IntegerType vs PYDM::IntegerType.
using BuiltinIntegerType = mlir::IntegerType;
using PyBoolType = PYDM::BoolType;
using PyConstantOp = PYDM::ConstantOp;
using PyIntegerType = PYDM::IntegerType;
using PyRealType = PYDM::RealType;
// Registers all TableGen-generated types and operations with the dialect.
// Each GET_*_LIST macro is defined immediately before its #include so the
// generated .inc file expands to a comma-separated list of class names that
// becomes the template argument list of addTypes<> / addOperations<>.
void IREEPyDMDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
>();
}
// Materializes a single constant op of `type` holding `value`, as requested by
// the folding infrastructure.
//
// Supported cases:
//   - PyDM bool/bytes/int/float/str types and builtin integer types
//     -> PYDM::ConstantOp
//   - PYDM::NoneType -> PYDM::NoneOp
//   - ExceptionResultType with a UnitAttr value -> PYDM::SuccessOp
//
// Returns nullptr for any other (value, type) combination. This follows the
// Dialect::materializeConstant contract, which lets the caller abandon the
// fold instead of aborting the whole compiler.
Operation *IREEPyDMDialect::materializeConstant(OpBuilder &builder,
                                                Attribute value, Type type,
                                                Location loc) {
  // Since we support materialization of builtin types too, explicitly
  // allow these.
  if (type.isa<PyBoolType, BytesType, PyIntegerType, PyRealType, StrType,
               BuiltinIntegerType>()) {
    return builder.create<PYDM::ConstantOp>(loc, type, value);
  }
  // Qualify explicitly: `NoneType` also names the builtin mlir::NoneType,
  // which is visible through `using namespace mlir`.
  if (type.isa<PYDM::NoneType>()) {
    return builder.create<PYDM::NoneOp>(loc, type);
  }
  // A unit value of exception-result type is materialized as a SuccessOp.
  if (type.isa<ExceptionResultType>() && value.isa<UnitAttr>()) {
    return builder.create<PYDM::SuccessOp>(loc, type);
  }
  // Unsupported combination: report failure to the folder rather than crash.
  return nullptr;
}
//------------------------------------------------------------------------------
// Python type implementation
//------------------------------------------------------------------------------
// BoolType
// `bool` is modelled as a numeric type: its builtin type code is built from
// its numeric category and sub-type code, and that same code is reused as its
// numeric promotion order.
BuiltinTypeCode PYDM::BoolType::getTypeCode() const {
  auto code =
      makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode());
  return static_cast<BuiltinTypeCode>(code);
}
StringRef PYDM::BoolType::getPythonTypeName() const {
  return "bool";
}
Optional<NumericCategory> PYDM::BoolType::getNumericCategory() const {
  Optional<NumericCategory> category = NumericCategory::Bool;
  return category;
}
Optional<int> PYDM::BoolType::getNumericSubTypeCode() const {
  // bool has exactly one representation, hence sub-type 0.
  Optional<int> subType = 0;
  return subType;
}
Optional<int> PYDM::BoolType::getNumericPromotionOrder() const {
  BuiltinTypeCode code = getTypeCode();
  return static_cast<int>(code);
}
// BytesType: Python `bytes`, identified by a fixed builtin type code.
BuiltinTypeCode PYDM::BytesType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::Bytes;
  return code;
}
StringRef PYDM::BytesType::getPythonTypeName() const {
  return StringRef("bytes");
}
// ExceptionResultType: reported to Python as `Exception`, with a fixed
// builtin type code.
BuiltinTypeCode PYDM::ExceptionResultType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::ExceptionResult;
  return code;
}
StringRef PYDM::ExceptionResultType::getPythonTypeName() const {
  return StringRef("Exception");
}
// IntegerType
// Validates the optional signed bit width. An absent width (a weak integer)
// is always accepted. Otherwise the magnitude must be 0 (the arbitrary
// precision case), 8, 16, 32 or 64. The sign encodes signedness and is ignored
// here.
LogicalResult PYDM::IntegerType::verify(
    function_ref<InFlightDiagnostic()> emitError, Optional<int> bitWidth) {
  if (bitWidth) {
    int magnitude = abs(*bitWidth);
    switch (magnitude) {
      case 0:
      case 8:
      case 16:
      case 32:
      case 64:
        break;
      default:
        return emitError() << "unsupported python integer bit width: "
                           << magnitude;
    }
  }
  return success();
}
// The builtin type code of an int is derived from its numeric category and
// sub-type code.
BuiltinTypeCode PYDM::IntegerType::getTypeCode() const {
  auto code =
      makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode());
  return static_cast<BuiltinTypeCode>(code);
}
StringRef PYDM::IntegerType::getPythonTypeName() const {
  return StringRef("int");
}
// Classifies the integer as one of:
//   - weak (no bit width)
//   - arbitrary-precision signed (bit width 0)
//   - fixed-width signed or unsigned
Optional<NumericCategory> PYDM::IntegerType::getNumericCategory() const {
  if (isWeak()) return NumericCategory::WeakInteger;
  if (getBitWidth() == 0) return NumericCategory::APSigned;
  return isSigned() ? NumericCategory::Signed : NumericCategory::Unsigned;
}
// Returns the integer sub-type code (the fixed bit-width bucket) that is
// combined with the numeric category to form the type code.
//
// Weak integers and arbitrary-precision integers (bit width 0) have no fixed
// width and report 0; their numeric category alone distinguishes them.
Optional<int> PYDM::IntegerType::getNumericSubTypeCode() const {
  // Width 0 passes verify() and maps to NumericCategory::APSigned in
  // getNumericCategory(). Without this guard, getTypeCode() on such a type
  // would fall into the llvm_unreachable below.
  // isWeak() is checked first because getBitWidth() requires a bit width.
  if (isWeak() || getBitWidth() == 0) return 0;
  IntegerSubTypeCode stc;
  switch (getBitWidth()) {
    case 8:
      stc = IntegerSubTypeCode::Integer8;
      break;
    case 16:
      stc = IntegerSubTypeCode::Integer16;
      break;
    case 32:
      stc = IntegerSubTypeCode::Integer32;
      break;
    case 64:
      stc = IntegerSubTypeCode::Integer64;
      break;
    default:
      llvm_unreachable("unsupported numeric bitwidth");
  }
  return static_cast<int>(stc);
}
// Promotion order for ints is simply their builtin type code.
Optional<int> PYDM::IntegerType::getNumericPromotionOrder() const {
  BuiltinTypeCode code = getTypeCode();
  return static_cast<int>(code);
}
// The storage keeps an optional signed bit width. "No width" means weak, the
// magnitude is the bit width, and a non-negative value means signed.
// getBitWidth() and isSigned() require a width to be present.
bool PYDM::IntegerType::isWeak() const {
  const auto &storedWidth = getImpl()->bitWidth;
  return !storedWidth;
}
unsigned PYDM::IntegerType::getBitWidth() const {
  int signedWidth = *getImpl()->bitWidth;
  return abs(signedWidth);
}
bool PYDM::IntegerType::isSigned() const {
  int signedWidth = *getImpl()->bitWidth;
  return !(signedWidth < 0);
}
// ListType
// Python `list`, identified by a fixed builtin type code.
BuiltinTypeCode PYDM::ListType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::List;
  return code;
}
StringRef PYDM::ListType::getPythonTypeName() const {
  return StringRef("list");
}
// NoneType
// Python `None` has its own builtin type code. It must not reuse List's code:
// type codes are compared by consumers such as UnionType::verify, where
// aliasing would make None and list indistinguishable.
BuiltinTypeCode PYDM::NoneType::getTypeCode() const {
  return BuiltinTypeCode::None;
}
bool PYDM::ListType::isRefinable() const {
if (getStorageClass() == CollectionStorageClass::Empty) return false;
if (!getUniformElementType()) return true;
if (auto pyType = getUniformElementType().dyn_cast<PythonTypeInterface>())
return pyType.isRefinable();
return false;
}
// Returns the type actually used to store list elements:
//   - Boxed and empty lists store generic objects.
//   - Unboxed lists store their uniform element type directly.
Type PYDM::ListType::getElementStorageType() const {
  auto storageClass = getStorageClass();
  if (storageClass == CollectionStorageClass::Boxed ||
      storageClass == CollectionStorageClass::Empty) {
    return ObjectType::get(getContext());
  }
  if (storageClass == CollectionStorageClass::Unboxed) {
    assert(getUniformElementType() &&
           "unboxed list should have uniform element type");
    Type elementType = getUniformElementType();
    return elementType;
  }
  llvm_unreachable("unsupported storage class");
}
// NoneType: the Python spelling of the type is `None`.
StringRef PYDM::NoneType::getPythonTypeName() const {
  return StringRef("None");
}
// ObjectType: Python `object`, identified by a fixed builtin type code.
BuiltinTypeCode PYDM::ObjectType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::Object;
  return code;
}
StringRef PYDM::ObjectType::getPythonTypeName() const {
  return StringRef("object");
}
bool PYDM::ObjectType::isRefinable() const {
if (!getPrimitiveType()) return true;
if (auto pyType = getPrimitiveType().dyn_cast<PythonTypeInterface>())
return pyType.isRefinable();
return false;
}
// RealType
// A null float type (a weak real) is always accepted. Otherwise the backing
// type must be one of bf16, f16, f32 or f64.
LogicalResult PYDM::RealType::verify(
    function_ref<InFlightDiagnostic()> emitError, FloatType floatType) {
  bool supported =
      !floatType ||
      floatType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
  if (supported) return success();
  return emitError() << "unsupported Python floating point type: "
                     << floatType;
}
// The builtin type code of a float is derived from its numeric category and
// sub-type code.
BuiltinTypeCode PYDM::RealType::getTypeCode() const {
  auto code =
      makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode());
  return static_cast<BuiltinTypeCode>(code);
}
StringRef PYDM::RealType::getPythonTypeName() const {
  return StringRef("float");
}
// A float is WeakReal when it has no backing float type, Real otherwise.
Optional<NumericCategory> PYDM::RealType::getNumericCategory() const {
  return isWeak() ? NumericCategory::WeakReal : NumericCategory::Real;
}
// Maps the backing float type to its RealSubTypeCode. Weak reals report 0.
Optional<int> PYDM::RealType::getNumericSubTypeCode() const {
  if (isWeak()) return 0;
  Type fpType = getFloatType();
  RealSubTypeCode code = RealSubTypeCode::FP64;
  if (fpType.isa<BFloat16Type>()) {
    code = RealSubTypeCode::BF16;
  } else if (fpType.isa<Float16Type>()) {
    code = RealSubTypeCode::FP16;
  } else if (fpType.isa<Float32Type>()) {
    code = RealSubTypeCode::FP32;
  } else if (fpType.isa<Float64Type>()) {
    code = RealSubTypeCode::FP64;
  } else {
    llvm_unreachable("unsupported float type");
  }
  return static_cast<int>(code);
}
// Promotion order for floats is simply their builtin type code.
Optional<int> PYDM::RealType::getNumericPromotionOrder() const {
  BuiltinTypeCode code = getTypeCode();
  return static_cast<int>(code);
}
// A real is weak when it carries no concrete float type.
bool PYDM::RealType::isWeak() const {
  const auto &storedFloatType = getImpl()->floatType;
  return !storedFloatType;
}
// StrType: Python `str`, identified by a fixed builtin type code.
BuiltinTypeCode PYDM::StrType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::Str;
  return code;
}
StringRef PYDM::StrType::getPythonTypeName() const {
  return StringRef("str");
}
// TupleType: Python `tuple`, identified by a fixed builtin type code.
BuiltinTypeCode PYDM::TupleType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::Tuple;
  return code;
}
StringRef PYDM::TupleType::getPythonTypeName() const {
  return StringRef("tuple");
}
// TypeType: Python `type`, identified by a fixed builtin type code.
BuiltinTypeCode PYDM::TypeType::getTypeCode() const {
  const BuiltinTypeCode code = BuiltinTypeCode::Type;
  return code;
}
StringRef PYDM::TypeType::getPythonTypeName() const {
  return StringRef("type");
}
// Tuple elements are currently always stored boxed, as generic objects.
// TODO: When it implements unboxed storage, switch here.
Type PYDM::TupleType::getElementStorageType() const {
  Type storageType = ObjectType::get(getContext());
  return storageType;
}
//------------------------------------------------------------------------------
// Union type implementation
//------------------------------------------------------------------------------
// Verifies a union's alternatives. Every alternative must be a Python type,
// and the alternatives must appear in strictly increasing builtin type-code
// order, so each union has a single normative spelling.
//
// Returns success() when both conditions hold. Otherwise emits a diagnostic
// and returns failure.
LogicalResult PYDM::UnionType::verify(
    llvm::function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<Type> alternatives) {
  int lastTypeCode = 0;
  for (Type alternative : alternatives) {
    auto pythonType = alternative.dyn_cast<PYDM::PythonTypeInterface>();
    if (!pythonType) {
      return emitError() << "expected a python type in union. got: "
                         << alternative;
    }
    int thisTypeCode = static_cast<int>(pythonType.getTypeCode());
    // TODO: This doesn't account for parameterized types.
    if (thisTypeCode <= lastTypeCode) {
      return emitError() << "expected total order of union to be normative. "
                            "got out of order: "
                         << alternative;
    }
    // Remember this alternative so the next one is compared against it;
    // without this the ordering check never looks at adjacent pairs.
    lastTypeCode = thisTypeCode;
  }
  // All alternatives are valid and ordered. (Previously this returned
  // failure() unconditionally, rejecting every union without a diagnostic.)
  return success();
}