Health and welfare on the iree_pydm dialect. (#6978)
Dialect changes:
* Simplify file naming to limit repetition.
* Add PythonTypeInterface, implemented by all Python types (currently used for some identity and numeric promotion rules).
* Applied the previous recommendation and reworked var alloc/load/store into a `!free_var_ref` type with `alloc_free_var`, `load_var` and `store_var` ops. Cell variables will need something different, but this should generalize (i.e. cell vars need to be resolved symbolically and inter-procedurally).
* Add a UnionType in order to support type refinement (not yet used, and still needs some refinement).
* Forked scf.if into `functional_if` for the specific case where we are emitting conditional Python code of a functional nature (this shows up often in conditional expressions and short-circuit evaluation, although most Python control flow is naturally CFG-based). With this change, the pydm dialect is self-contained and no longer relies on ops from other dialects. This will help with type inference, etc.
* Implemented OpAsmOpInterface::getDefaultDialect on `func` and `functional_if`, making all pydm ops able to be used prefix-free, cleaning up IR a lot.
* Made `none` ConstantLike. Added `success` and `failure` ops to produce `ExceptionResults`.
* Added `make_tuple` op (not yet used).
* Added `promote_numeric` op.
* Implemented simple folders for `constant`, `none`, `success`, `as_bool`, `bool_to_pred`, `raise_on_failure`, `select`.
* Implemented static numeric promotion with a folder and canonicalizer on `promote_numeric`, and a canonicalizer on `dynamic_binary_promote` that reduces it to primitives in statically-typed cases.
* Implemented no-op box/unbox canonicalizations.
With this, the dialect is in good shape to start building out the optimization and lowering pipelines.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
index cb991b8..b6293fe 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_library(IREEDialectsIREEPyDMDialect
- IREEPyDMDialect.cpp
- IREEPyDMOps.cpp
+ Dialect.cpp
+ Ops.cpp
ADDITIONAL_HEADER_DIRS
${IREE_DIALECTS_SOURCE_DIR}/include
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
new file mode 100644
index 0000000..4b6870d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
@@ -0,0 +1,200 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h"
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "iree-dialects/Dialect/IREEPyDM/IR/TypeInterfaces.cpp.inc"
+#include "iree-dialects/Dialect/IREEPyDM/IR/Types.cpp.inc"
+
+//------------------------------------------------------------------------------
+// Dialect implementation
+//------------------------------------------------------------------------------
+
+// Some pydm types share names with MLIR builtin types (e.g. IntegerType), and
+// this file brings in both `mlir` and `mlir::iree_pydm` via `using namespace`.
+// These aliases make the intended entity explicit at each use site.
+using BuiltinIntegerType = mlir::IntegerType;
+
+using PyBoolType = mlir::iree_pydm::BoolType;
+using PyConstantOp = mlir::iree_pydm::ConstantOp;
+using PyIntegerType = mlir::iree_pydm::IntegerType;
+using PyRealType = mlir::iree_pydm::RealType;
+
+/// Registers all dialect types and operations. Both lists are expanded from
+/// the tablegen-generated Types.cpp.inc and Ops.cpp.inc includes below.
+void IREEPyDMDialect::initialize() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "iree-dialects/Dialect/IREEPyDM/IR/Types.cpp.inc"
+ >();
+ addOperations<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc"
+ >();
+}
+
+/// Materializes a constant operation for the folded attribute `value` with
+/// the requested result `type`. Builtin integer types are accepted alongside
+/// the pydm value types.
+///
+/// Returns nullptr when no pydm constant op can represent the (type, value)
+/// pair. Per the Dialect::materializeConstant contract this lets the folding
+/// infrastructure fail gracefully instead of aborting the process.
+Operation *IREEPyDMDialect::materializeConstant(OpBuilder &builder,
+                                                Attribute value, Type type,
+                                                Location loc) {
+  // Since we support materialization of builtin types too, explicitly
+  // allow these.
+  if (type.isa<PyBoolType, BytesType, PyIntegerType, PyRealType, StrType,
+               BuiltinIntegerType>()) {
+    return builder.create<iree_pydm::ConstantOp>(loc, type, value);
+  }
+
+  // `none` is materialized by its own op; the attribute carries no data.
+  if (type.isa<iree_pydm::NoneType>()) {
+    return builder.create<iree_pydm::NoneOp>(loc, type);
+  }
+
+  // A unit attribute on an exception result represents success (it is what
+  // the `success` op folds to).
+  if (type.isa<ExceptionResultType>() && value.isa<UnitAttr>()) {
+    return builder.create<iree_pydm::SuccessOp>(loc, type);
+  }
+
+  // Unsupported combination: signal failure rather than crashing.
+  return nullptr;
+}
+
+/// Parses a dialect type by reading its leading keyword (mnemonic) and
+/// dispatching it to the tablegen-generated type parser.
+///
+/// Emits "unknown dialect type" and returns a null Type when no keyword could
+/// be read or the keyword is not a known type tag. Returns a null Type without
+/// an extra diagnostic when the tag was recognized but its body failed to
+/// parse.
+Type IREEPyDMDialect::parseType(DialectAsmParser &parser) const {
+  StringRef mnemonic;
+  if (failed(parser.parseKeyword(&mnemonic))) {
+    parser.emitError(parser.getNameLoc(), "unknown dialect type");
+    return Type();
+  }
+
+  Type parsed;
+  auto outcome = generatedTypeParser(getContext(), parser, mnemonic, parsed);
+  if (!outcome.hasValue()) {
+    parser.emitError(parser.getNameLoc(), "unknown dialect type");
+    return Type();
+  }
+  return failed(*outcome) ? Type() : parsed;
+}
+
+/// Prints a dialect type through the tablegen-generated printer; its status
+/// result is intentionally discarded.
+void IREEPyDMDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  auto printStatus = generatedTypePrinter(type, printer);
+  (void)printStatus;
+}
+
+//------------------------------------------------------------------------------
+// Python type implementation
+//------------------------------------------------------------------------------
+
+// Each pydm Python type reports the BuiltinTypeCode that identifies it.
+
+BuiltinTypeCode iree_pydm::BoolType::getTypeCode() const {
+  return BuiltinTypeCode::Bool;
+}
+
+BuiltinTypeCode iree_pydm::BytesType::getTypeCode() const {
+  return BuiltinTypeCode::Bytes;
+}
+
+BuiltinTypeCode iree_pydm::ExceptionResultType::getTypeCode() const {
+  return BuiltinTypeCode::ExceptionResult;
+}
+
+BuiltinTypeCode iree_pydm::IntegerType::getTypeCode() const {
+  return BuiltinTypeCode::Integer;
+}
+
+BuiltinTypeCode iree_pydm::ListType::getTypeCode() const {
+  return BuiltinTypeCode::List;
+}
+
+BuiltinTypeCode iree_pydm::NoneType::getTypeCode() const {
+  return BuiltinTypeCode::None;
+}
+
+BuiltinTypeCode iree_pydm::ObjectType::getTypeCode() const {
+  return BuiltinTypeCode::Object;
+}
+
+BuiltinTypeCode iree_pydm::RealType::getTypeCode() const {
+  return BuiltinTypeCode::Real;
+}
+
+BuiltinTypeCode iree_pydm::StrType::getTypeCode() const {
+  return BuiltinTypeCode::Str;
+}
+
+BuiltinTypeCode iree_pydm::TupleType::getTypeCode() const {
+  return BuiltinTypeCode::Tuple;
+}
+
+BuiltinTypeCode iree_pydm::TypeType::getTypeCode() const {
+  return BuiltinTypeCode::Type;
+}
+
+// The Python-visible name reported for each type.
+
+StringRef iree_pydm::BoolType::getPythonTypeName() const { return "bool"; }
+StringRef iree_pydm::BytesType::getPythonTypeName() const { return "bytes"; }
+StringRef iree_pydm::ExceptionResultType::getPythonTypeName() const {
+  return "Exception";
+}
+StringRef iree_pydm::IntegerType::getPythonTypeName() const { return "int"; }
+StringRef iree_pydm::ListType::getPythonTypeName() const { return "list"; }
+// NOTE(review): CPython names this type "NoneType" (type(None).__name__);
+// confirm that "None" is the intended spelling here.
+StringRef iree_pydm::NoneType::getPythonTypeName() const { return "None"; }
+StringRef iree_pydm::ObjectType::getPythonTypeName() const { return "object"; }
+StringRef iree_pydm::RealType::getPythonTypeName() const { return "float"; }
+StringRef iree_pydm::StrType::getPythonTypeName() const { return "str"; }
+StringRef iree_pydm::TupleType::getPythonTypeName() const { return "tuple"; }
+StringRef iree_pydm::TypeType::getPythonTypeName() const { return "type"; }
+
+// Numeric promotion order: bool (1) < int (2) < float (3). Only these numeric
+// types define an order here; when two numeric operands differ, the
+// lower-ordered one is promoted to the higher-ordered type.
+
+Optional<int> iree_pydm::BoolType::getNumericPromotionOrder() const {
+  return 1;
+}
+
+Optional<int> iree_pydm::IntegerType::getNumericPromotionOrder() const {
+  return 2;
+}
+
+Optional<int> iree_pydm::RealType::getNumericPromotionOrder() const {
+  return 3;
+}
+
+//------------------------------------------------------------------------------
+// Union type implementation
+//------------------------------------------------------------------------------
+
+/// Verifies the alternatives of a union type.
+///
+/// Every alternative must implement PythonTypeInterface. Alternatives must be
+/// listed in strictly increasing BuiltinTypeCode order, so each union has a
+/// single normative spelling and no type code appears twice.
+///
+/// Returns success() for a well-formed union; otherwise emits a diagnostic
+/// through `emitError` and returns failure.
+LogicalResult iree_pydm::UnionType::verify(
+    llvm::function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<Type> alternatives) {
+  // Type code of the previous alternative. It is empty before the first one,
+  // so any code (including 0) is accepted in the first position.
+  Optional<int> lastTypeCode;
+  for (Type alternative : alternatives) {
+    auto pythonType = alternative.dyn_cast<iree_pydm::PythonTypeInterface>();
+    if (!pythonType) {
+      return emitError() << "expected a python type in union. got: "
+                         << alternative;
+    }
+    int thisTypeCode = static_cast<int>(pythonType.getTypeCode());
+    // TODO: This doesn't account for parameterized types.
+    if (lastTypeCode && thisTypeCode <= *lastTypeCode) {
+      return emitError() << "expected total order of union to be normative. "
+                            "got out of order: "
+                         << alternative;
+    }
+    // Track the previous code so the ordering check actually compares
+    // neighbouring alternatives.
+    lastTypeCode = thisTypeCode;
+  }
+
+  // All alternatives are valid and correctly ordered.
+  return success();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMDialect.cpp
deleted file mode 100644
index 50043ae..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMDialect.cpp
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
-
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/Support/LLVM.h"
-
-using namespace mlir;
-using namespace mlir::iree_pydm;
-
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsDialect.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsTypes.cpp.inc"
-
-void IREEPyDMDialect::initialize() {
- addTypes<
-#define GET_TYPEDEF_LIST
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsTypes.cpp.inc"
- >();
- addOperations<
-#define GET_OP_LIST
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp.inc"
- >();
-}
-
-Type IREEPyDMDialect::parseType(DialectAsmParser &parser) const {
- StringRef typeTag;
- Type genType;
- if (succeeded(parser.parseKeyword(&typeTag)))
- generatedTypeParser(getContext(), parser, typeTag, genType);
- return genType;
-}
-
-void IREEPyDMDialect::printType(Type type, DialectAsmPrinter &printer) const {
- (void)generatedTypePrinter(type, printer);
-}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp
deleted file mode 100644
index 1bb1fa4..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp
+++ /dev/null
@@ -1,149 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h"
-
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/FunctionImplementation.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/TypeUtilities.h"
-
-using namespace mlir;
-using namespace mlir::iree_pydm;
-
-using PyCallOp = mlir::iree_pydm::CallOp;
-using PyFuncOp = mlir::iree_pydm::FuncOp;
-
-//===----------------------------------------------------------------------===//
-// FuncOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult PyFuncOp::verifyType() {
- // TODO: Enforce arg/result invariants.
- return success();
-}
-
-static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
- auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
- ArrayRef<Type> results,
- function_like_impl::VariadicFlag, std::string &) {
- return builder.getFunctionType(argTypes, results);
- };
-
- return function_like_impl::parseFunctionLikeOp(
- parser, result, /*allowVariadic=*/false, buildFuncType);
-}
-
-static void print(PyFuncOp op, OpAsmPrinter &p) {
- FunctionType fnType = op.getType();
- function_like_impl::printFunctionLikeOp(
- p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
-}
-
-static LogicalResult verify(PyFuncOp op) {
- // TODO: Enforce invariants.
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// PatternMatchCallOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult PatternMatchCallOp::verifySymbolUses(
- SymbolTableCollection &symbolTable) {
- auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
- for (auto symbolAttr : symbols) {
- auto symbol = symbolAttr.cast<FlatSymbolRefAttr>();
- PyFuncOp fn =
- symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol);
- if (!fn)
- return emitOpError() << "'" << symbol.getValue()
- << "' does not reference a valid function";
- }
- return success();
- };
- auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match");
- if (!genericsAttr)
- return emitOpError(
- "requires a 'generic_match' array of symbol reference attributes");
- if (failed(verifySymbols(genericsAttr))) return failure();
-
- auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
- if (!specificsAttr)
- return emitOpError(
- "requires a 'specific_match' array of symbol reference attributes");
- if (failed(verifySymbols(specificsAttr))) return failure();
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// CallOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // Check that the callee attribute was specified.
- auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
- if (!fnAttr)
- return emitOpError("requires a 'callee' symbol reference attribute");
- PyFuncOp fn = symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, fnAttr);
- if (!fn)
- return emitOpError() << "'" << fnAttr.getValue()
- << "' does not reference a valid function";
-
- // Verify that the operand and result types match the callee.
- auto fnType = fn.getType();
- if (fnType.getNumInputs() != getNumOperands())
- return emitOpError("incorrect number of operands for callee");
-
- for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
- if (getOperand(i).getType() != fnType.getInput(i)) {
- return emitOpError("operand type mismatch: expected operand type ")
- << fnType.getInput(i) << ", but provided "
- << getOperand(i).getType() << " for operand number " << i;
- }
- }
-
- if (fnType.getNumResults() != getNumResults())
- return emitOpError("incorrect number of results for callee");
-
- for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
- if (getResult(i).getType() != fnType.getResult(i)) {
- auto diag = emitOpError("result type mismatch at index ") << i;
- diag.attachNote() << " op result types: " << getResultTypes();
- diag.attachNote() << "function result types: " << fnType.getResults();
- return diag;
- }
- }
-
- return success();
-}
-
-FunctionType PyCallOp::getCalleeType() {
- return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
-}
-
-//===----------------------------------------------------------------------===//
-// DynamicCallOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult DynamicCallOp::verifySymbolUses(
- SymbolTableCollection &symbolTable) {
- // Check that the callee attribute was specified.
- auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
- if (!fnAttr)
- return emitOpError("requires a 'callee' symbol reference attribute");
- Operation *fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr);
- if (!fn || !isa<PyFuncOp>(fn))
- return emitOpError() << "'" << fnAttr.getValue()
- << "' does not reference a valid function";
- return success();
-}
-
-#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
new file mode 100644
index 0000000..245e6be
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
@@ -0,0 +1,525 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+// Explicit aliases for pydm types and ops. Some of their names are shared
+// with MLIR core entities (e.g. mlir::IntegerType), and this file brings in
+// both `mlir` and `mlir::iree_pydm` via `using namespace`.
+using PyBoolType = mlir::iree_pydm::BoolType;
+using PyConstantOp = mlir::iree_pydm::ConstantOp;
+using PyIntegerType = mlir::iree_pydm::IntegerType;
+using PyRealType = mlir::iree_pydm::RealType;
+using PyCallOp = mlir::iree_pydm::CallOp;
+using PyFuncOp = mlir::iree_pydm::FuncOp;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+/// Creates a pydm `constant` holding numeric zero of `numericType`:
+///   - pydm bool  -> False (BoolAttr)
+///   - pydm int   -> 0     (i64 IntegerAttr)
+///   - pydm float -> 0.0   (f64 FloatAttr)
+/// For any other type no op is created and a null Value is returned, so
+/// callers can bail out instead of asserting inside the TypeSwitch.
+static Value getNumericZeroConstant(Location loc, Type numericType,
+                                    OpBuilder &builder) {
+  return TypeSwitch<Type, Value>(numericType)
+      .Case([&](PyBoolType t) -> Value {
+        return builder.create<PyConstantOp>(loc, t, builder.getBoolAttr(false));
+      })
+      .Case([&](PyIntegerType t) -> Value {
+        return builder.create<PyConstantOp>(loc, t,
+                                            builder.getI64IntegerAttr(0));
+      })
+      .Case([&](PyRealType t) -> Value {
+        return builder.create<PyConstantOp>(loc, t,
+                                            builder.getF64FloatAttr(0.0));
+      })
+      .Default([](Type) -> Value { return Value(); });
+}
+
+/// Creates a pydm `bool` constant whose value is `pred`.
+static Value getBoolConstant(Location loc, bool pred, OpBuilder &builder) {
+  Type boolType = builder.getType<PyBoolType>();
+  BoolAttr predAttr = builder.getBoolAttr(pred);
+  return builder.create<PyConstantOp>(loc, boolType, predAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// Constants
+//===----------------------------------------------------------------------===//
+
+/// `constant` folds to the attribute it holds.
+OpFoldResult PyConstantOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  auto held = getValue();
+  return held;
+}
+
+/// `none` is constant-like and folds to a UnitAttr.
+OpFoldResult NoneOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return Builder(getContext()).getUnitAttr();
+}
+
+/// `success` is constant-like and folds to a UnitAttr.
+OpFoldResult SuccessOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return Builder(getContext()).getUnitAttr();
+}
+
+//===----------------------------------------------------------------------===//
+// Variables
+//===----------------------------------------------------------------------===//
+
+/// Suggests the op's `name` attribute as the SSA name of its result.
+void AllocFreeVarOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+  Value varRef = getResult();
+  setNameFn(varRef, name());
+}
+
+//===----------------------------------------------------------------------===//
+// AsBoolOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// as_bool(x) where x is already a pydm bool is just x.
+struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AsBoolOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.value().getType().isa<BoolType>()) return failure();
+    rewriter.replaceOp(op, op.value());
+    return success();
+  }
+};
+
+/// as_bool(x) for a non-bool numeric x (any type reporting a numeric
+/// promotion order) becomes select(x == 0, False, True).
+struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AsBoolOp op,
+                                PatternRewriter &rewriter) const override {
+    Type valueType = op.value().getType();
+    // Bool inputs are elided by FoldAsBoolFromBool without creating any ops.
+    // Don't let this pattern win for them and expand into compare + select.
+    if (valueType.isa<BoolType>()) return failure();
+    auto ptType = valueType.dyn_cast<PythonTypeInterface>();
+    if (!ptType) return failure();
+    if (!ptType.getNumericPromotionOrder()) return failure();
+
+    auto loc = op.getLoc();
+    // Build the zero first so we can bail out before creating anything else
+    // if the numeric type has no known zero constant.
+    Value zeroValue = getNumericZeroConstant(loc, valueType, rewriter);
+    if (!zeroValue) return failure();
+
+    auto boolType = rewriter.getType<BoolType>();
+    Value trueValue = getBoolConstant(loc, true, rewriter);
+    Value falseValue = getBoolConstant(loc, false, rewriter);
+    Value cmpResult = rewriter.create<ApplyCompareOp>(
+        loc, boolType, rewriter.getStringAttr("eq"), op.value(), zeroValue);
+    // Equal to zero -> False, otherwise True.
+    rewriter.replaceOpWithNewOp<SelectOp>(op, boolType, cmpResult, falseValue,
+                                          trueValue);
+    return success();
+  }
+};
+
+}  // namespace
+
+/// Registers the `as_bool` canonicalizations: identity on bool inputs and the
+/// compare-against-zero expansion for numeric inputs.
+void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<FoldAsBoolFromBool>(context);
+  patterns.add<FoldAsBoolFromNumeric>(context);
+}
+
+/// Folds `as_bool` of a `none`-typed value to the constant False. Other
+/// inputs are left to the canonicalization patterns.
+OpFoldResult AsBoolOp::fold(ArrayRef<Attribute> operands) {
+  if (!value().getType().isa<NoneType>()) return {};
+  return Builder(getContext()).getBoolAttr(false);
+}
+
+//===----------------------------------------------------------------------===//
+// BoolToPredOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `bool_to_pred` of a constant operand. BoolType and I1 share the same
+/// attribute form (an IntegerAttr of I1), so the operand attribute is
+/// forwarded unchanged. A non-constant operand does not fold.
+OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) {
+  Attribute constantBool = operands[0];
+  if (constantBool) return constantBool;
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// BoxOp and UnboxOp
+//===----------------------------------------------------------------------===//
+
+/// Removes a `box` whose input is already an object. Such boxes are sometimes
+/// emitted, and the input can be used directly.
+LogicalResult BoxOp::canonicalize(BoxOp op, PatternRewriter &rewriter) {
+  Value input = op.primitive();
+  if (!input.getType().isa<ObjectType>()) return failure();
+  rewriter.replaceOp(op, input);
+  return success();
+}
+
+/// Folds unbox(box(x)) when x already has the type being unboxed to. The
+/// unbox's results become (success, x), and the box may then become dead.
+LogicalResult UnboxOp::canonicalize(UnboxOp unboxOp,
+                                    PatternRewriter &rewriter) {
+  // Only an immediate BoxOp producer is handled.
+  auto boxProducer =
+      dyn_cast_or_null<BoxOp>(unboxOp.object().getDefiningOp());
+  if (!boxProducer) return failure();
+
+  Value boxedPrimitive = boxProducer.primitive();
+  if (boxedPrimitive.getType() != unboxOp.primitive().getType())
+    return failure();
+
+  Value successValue = rewriter.create<SuccessOp>(
+      unboxOp.getLoc(), rewriter.getType<ExceptionResultType>());
+  rewriter.replaceOp(unboxOp, {successValue, boxedPrimitive});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicBinaryPromoteOp
+//===----------------------------------------------------------------------===//
+
+/// Statically resolves `dynamic_binary_promote` when both operand types are
+/// statically known Python types:
+///   - identical (non-object) types pass through unchanged;
+///   - two different numeric types are reconciled with `promote_numeric`,
+///     promoting the operand with the lower numeric promotion order to the
+///     other operand's type.
+/// The resulting values are boxed back to the op's result types. Generic
+/// `object` operands are left alone: their runtime types are unknown, so the
+/// promotion must remain dynamic.
+LogicalResult DynamicBinaryPromoteOp::canonicalize(DynamicBinaryPromoteOp op,
+                                                   PatternRewriter &rewriter) {
+  auto loc = op.getLoc();
+  auto leftType = op.left().getType();
+  auto rightType = op.right().getType();
+  auto leftResultType = op.getResultTypes()[0];
+  auto rightResultType = op.getResultTypes()[1];
+  auto leftPt = leftType.dyn_cast<PythonTypeInterface>();
+  auto rightPt = rightType.dyn_cast<PythonTypeInterface>();
+  if (!leftPt || !rightPt) return failure();
+
+  // Two `object` operands have equal static types, but may hold different
+  // runtime types (e.g. an int and a float) that still need promotion.
+  // Treating them as a pass-through would drop that promotion.
+  if (leftType.isa<ObjectType>() || rightType.isa<ObjectType>())
+    return failure();
+
+  Optional<int> leftOrder = leftPt.getNumericPromotionOrder();
+  Optional<int> rightOrder = rightPt.getNumericPromotionOrder();
+  Value newLeft = op.left();
+  Value newRight = op.right();
+
+  if (leftType == rightType) {
+    // Same static type on both sides: nothing to promote.
+  } else if (leftOrder && rightOrder) {
+    // Both numeric: promote the lower-ordered operand to the other's type.
+    if (*leftOrder > *rightOrder) {
+      newRight = rewriter.create<PromoteNumericOp>(loc, leftType, newRight);
+    } else if (*rightOrder > *leftOrder) {
+      newLeft = rewriter.create<PromoteNumericOp>(loc, rightType, newLeft);
+    }
+  } else {
+    return failure();
+  }
+
+  // Need to box back to the original type (which will always be a generic
+  // object).
+  newLeft = rewriter.create<BoxOp>(loc, leftResultType, newLeft);
+  newRight = rewriter.create<BoxOp>(loc, rightResultType, newRight);
+
+  rewriter.replaceOp(op, {newLeft, newRight});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FunctionalIfOp
+//===----------------------------------------------------------------------===//
+
+/// Ops nested inside a `functional_if` are parsed/printed relative to the
+/// `iree_pydm` dialect, so they can omit the dialect prefix.
+::llvm::StringRef FunctionalIfOp::getDefaultDialect() {
+  return "iree_pydm";
+}
+
+/// A `functional_if` that produces results must have an else region.
+/// Region/result type agreement is delegated to
+/// RegionBranchOpInterface::verifyTypes.
+static LogicalResult verify(FunctionalIfOp op) {
+  bool definesValues = op.getNumResults() != 0;
+  if (definesValues && op.elseRegion().empty())
+    return op.emitOpError("must have an else block if defining values");
+  return RegionBranchOpInterface::verifyTypes(op);
+}
+
+/// Custom parser for `functional_if`. As implemented below, it accepts:
+///   <condition-operand> (`->` result-types)? <then-region>
+///   (`else` <else-region>)? attr-dict?
+/// The condition operand is always resolved as a pydm `bool`. The else
+/// region is created up front but only populated if `else` is present.
+static ParseResult parseFunctionalIfOp(OpAsmParser &parser,
+ OperationState &result) {
+ // Create both the 'then' and 'else' regions; 'else' may remain empty.
+ result.regions.reserve(2);
+ Region *thenRegion = result.addRegion();
+ Region *elseRegion = result.addRegion();
+
+ auto &builder = parser.getBuilder();
+ OpAsmParser::OperandType cond;
+ Type conditionType = builder.getType<PyBoolType>();
+ if (parser.parseOperand(cond) ||
+ parser.resolveOperand(cond, conditionType, result.operands))
+ return failure();
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types)) return failure();
+ // Parse the 'then' region.
+ if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+ // NOTE(review): the custom printer omits block terminators when the op has
+ // no results, but no terminator is re-inserted here (the scf.if-derived
+ // ensureTerminator calls are disabled). Confirm that printed no-result
+ // functional_ifs round-trip, e.g. via an implicit-terminator trait in ODS.
+ // IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
+
+ // If we find an 'else' keyword then parse the 'else' region.
+ if (!parser.parseOptionalKeyword("else")) {
+ if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+ // IfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
+ // result.location);
+ }
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ return success();
+}
+
+/// Custom printer for `functional_if`. It prints, in order:
+///   1. the condition;
+///   2. the result types, if any;
+///   3. the 'then' region;
+///   4. the 'else' region, when non-empty;
+///   5. the attribute dictionary.
+/// Block terminators are printed only when the op defines results.
+static void print(OpAsmPrinter &p, FunctionalIfOp op) {
+  bool hasResults = !op.results().empty();
+
+  p << " " << op.condition();
+  if (hasResults) p << " -> (" << op.getResultTypes() << ")";
+
+  p.printRegion(op.thenRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/hasResults);
+
+  Region &elseBody = op.elseRegion();
+  if (!elseBody.empty()) {
+    p << " else";
+    p.printRegion(elseBody,
+                  /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/hasResults);
+  }
+
+  p.printOptionalAttrDict(op->getAttrs());
+}
+
+/// Given the region at `index`, or the parent operation if `index` is None,
+/// return the successor regions. These are the regions that may be selected
+/// during the flow of control. `operands` is a set of optional attributes that
+/// correspond to a constant value for each operand, or null if that operand is
+/// not a constant.
+void FunctionalIfOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `then` and the `else` region branch back to the parent operation.
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // Don't consider the else region if it is empty.
+  Region *elseRegion = &this->elseRegion();
+  if (elseRegion->empty()) elseRegion = nullptr;
+
+  // Otherwise, the successor is dependent on the condition.
+  if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
+    bool condition = condAttr.getValue();
+    // Add the successor region selected by the constant condition. A constant
+    // false with no else region pushes a null region, which RegionSuccessor
+    // treats as branching back to the parent op.
+    regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
+  } else {
+    // If the condition isn't constant, both regions may be executed.
+    regions.push_back(RegionSuccessor(&thenRegion()));
+    // If the else region does not exist, it is not a viable successor.
+    if (elseRegion) regions.push_back(RegionSuccessor(elseRegion));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+/// Ops nested inside a pydm `func` are parsed/printed relative to the
+/// `iree_pydm` dialect, so they can omit the dialect prefix.
+::llvm::StringRef PyFuncOp::getDefaultDialect() {
+  return "iree_pydm";
+}
+
+/// Function-type verification hook. No argument/result constraints are
+/// enforced yet.
+LogicalResult PyFuncOp::verifyType() {
+  // TODO: Enforce arg/result invariants.
+  return LogicalResult::success();
+}
+
+/// Builds the FunctionType for a parsed pydm function signature. Variadic
+/// signatures are rejected by the caller, so the variadic flag and the error
+/// message out-parameter are unused.
+static Type buildPyFuncType(Builder &builder, ArrayRef<Type> argTypes,
+                            ArrayRef<Type> results,
+                            function_like_impl::VariadicFlag, std::string &) {
+  return builder.getFunctionType(argTypes, results);
+}
+
+/// Parses a pydm `func` using the generic function-like syntax without
+/// variadic arguments.
+static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
+  return function_like_impl::parseFunctionLikeOp(
+      parser, result, /*allowVariadic=*/false, buildPyFuncType);
+}
+
+/// Prints a pydm `func` using the generic function-like syntax, taking the
+/// argument and result types from its FunctionType.
+static void print(PyFuncOp op, OpAsmPrinter &p) {
+  FunctionType signature = op.getType();
+  ArrayRef<Type> inputs = signature.getInputs();
+  ArrayRef<Type> outputs = signature.getResults();
+  function_like_impl::printFunctionLikeOp(p, op, inputs,
+                                          /*isVariadic=*/false, outputs);
+}
+
+/// Op verifier for pydm `func`; no invariants are checked yet.
+static LogicalResult verify(PyFuncOp op) {
+  // TODO: Enforce invariants.
+  return LogicalResult::success();
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatchCallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the required `generic_match` and `specific_match` array
+/// attributes are present. Every entry must be a flat symbol reference that
+/// resolves, from this op, to a pydm `func`.
+LogicalResult PatternMatchCallOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
+    for (Attribute symbolAttr : symbols) {
+      // dyn_cast, not cast: a malformed entry must produce a diagnostic
+      // instead of an assertion failure inside the verifier.
+      auto symbol = symbolAttr.dyn_cast<FlatSymbolRefAttr>();
+      if (!symbol)
+        return emitOpError() << "expected a flat symbol reference, got "
+                             << symbolAttr;
+      PyFuncOp fn =
+          symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol);
+      if (!fn)
+        return emitOpError() << "'" << symbol.getValue()
+                             << "' does not reference a valid function";
+    }
+    return success();
+  };
+
+  auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match");
+  if (!genericsAttr)
+    return emitOpError(
+        "requires a 'generic_match' array of symbol reference attributes");
+  if (failed(verifySymbols(genericsAttr))) return failure();
+
+  auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
+  if (!specificsAttr)
+    return emitOpError(
+        "requires a 'specific_match' array of symbol reference attributes");
+  if (failed(verifySymbols(specificsAttr))) return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PromoteNumericOp
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds `promote_numeric` for the supported promotions:
+///   bool -> int   : False/True become 0/1 (i64)
+///   bool -> float : False/True become 0.0/1.0 (f64)
+///   int  -> float : the signed integer value rounded to the nearest double
+/// Any other (source attribute, result type) pair does not fold.
+OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0]) return {};
+
+  Builder b(getContext());
+  Attribute fromAttr = operands[0];
+  return TypeSwitch<Type, OpFoldResult>(getResult().getType())
+      .Case([&](PyIntegerType) -> OpFoldResult {
+        return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
+            .Case([&](BoolAttr fromBool) -> OpFoldResult {
+              return b.getI64IntegerAttr(fromBool.getValue() ? 1 : 0);
+            })
+            .Default([](Attribute) -> OpFoldResult { return {}; });
+      })
+      .Case([&](PyRealType) -> OpFoldResult {
+        return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
+            // BoolAttr is an i1 IntegerAttr, so it must be matched before the
+            // generic IntegerAttr case (signed i1 `true` would read as -1).
+            .Case([&](BoolAttr fromBool) -> OpFoldResult {
+              return b.getF64FloatAttr(fromBool.getValue() ? 1.0 : 0.0);
+            })
+            .Case([&](IntegerAttr fromInteger) -> OpFoldResult {
+              // signedRoundToDouble() handles any bit width, whereas
+              // getSExtValue() asserts on values wider than 64 bits.
+              return b.getF64FloatAttr(
+                  fromInteger.getValue().signedRoundToDouble());
+            })
+            .Default([](Attribute) -> OpFoldResult { return {}; });
+      })
+      .Default([](Type) -> OpFoldResult { return {}; });
+}
+
+/// Removes a `promote_numeric` whose input already has the result type.
+LogicalResult PromoteNumericOp::canonicalize(PromoteNumericOp op,
+                                             PatternRewriter &rewriter) {
+  Value input = op.input();
+  if (input.getType() != op.getResult().getType()) return failure();
+  rewriter.replaceOp(op, input);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RaiseOnFailureOp
+//===----------------------------------------------------------------------===//
+
+/// Folds away `raise_on_failure` when its exception-result operand is a
+/// constant UnitAttr, i.e. a known success (what the `success` op folds to).
+/// Otherwise the op is left unchanged.
+/// NOTE(review): this folder erases the op itself. MLIR fold hooks are
+/// generally expected to update the op only in place, not erase it. Consider
+/// moving this into a canonicalization pattern (requires enabling a
+/// canonicalizer for the op in ODS) — TODO confirm against MLIR docs.
+LogicalResult iree_pydm::RaiseOnFailureOp::fold(
+ ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
+ assert(operands.size() == 1 && "expected one fold operand");
+ // Unit exception result is success. Just elide.
+ if (operands[0] && operands[0].isa<UnitAttr>()) {
+ erase();
+ return success();
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `select(cond, t, f)`:
+///   - to the shared value when both arms are the same SSA value;
+///   - to `t` or `f` when `cond` is a constant BoolAttr.
+/// A non-constant condition, or a constant of another attribute kind, does
+/// not fold.
+OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+  // Identical arms: the condition is irrelevant.
+  if (true_value() == false_value()) return true_value();
+
+  // dyn_cast rather than cast, so an unexpected constant attribute kind
+  // simply does not fold instead of asserting.
+  auto condAttr = operands[0].dyn_cast_or_null<BoolAttr>();
+  if (!condAttr) return {};
+  return condAttr.getValue() ? true_value() : false_value();
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a direct `call`. The `callee` attribute must resolve to a pydm
+/// `func`, and the call's operand and result types must match the callee's
+/// signature exactly.
+LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto calleeAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!calleeAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  PyFuncOp callee =
+      symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, calleeAttr);
+  if (!callee)
+    return emitOpError() << "'" << calleeAttr.getValue()
+                         << "' does not reference a valid function";
+
+  FunctionType calleeType = callee.getType();
+
+  // Operands: same count and pairwise-equal types.
+  unsigned numInputs = calleeType.getNumInputs();
+  if (numInputs != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
+  for (unsigned idx = 0; idx < numInputs; ++idx) {
+    Type expected = calleeType.getInput(idx);
+    Type provided = getOperand(idx).getType();
+    if (provided == expected) continue;
+    return emitOpError("operand type mismatch: expected operand type ")
+           << expected << ", but provided " << provided
+           << " for operand number " << idx;
+  }
+
+  // Results: same count and pairwise-equal types.
+  unsigned numResults = calleeType.getNumResults();
+  if (numResults != getNumResults())
+    return emitOpError("incorrect number of results for callee");
+  for (unsigned idx = 0; idx < numResults; ++idx) {
+    if (getResult(idx).getType() == calleeType.getResult(idx)) continue;
+    auto diag = emitOpError("result type mismatch at index ") << idx;
+    diag.attachNote() << " op result types: " << getResultTypes();
+    diag.attachNote() << "function result types: " << calleeType.getResults();
+    return diag;
+  }
+
+  return success();
+}
+
+/// Returns the function type implied by this call's own operand and result
+/// types. The callee symbol is not consulted.
+FunctionType PyCallOp::getCalleeType() {
+  MLIRContext *ctx = getContext();
+  return FunctionType::get(ctx, getOperandTypes(), getResultTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicCallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that a `dynamic_call`'s `callee` attribute resolves to a pydm
+/// `func`. Unlike `call`, operand/result types are not compared against the
+/// callee's signature here.
+LogicalResult DynamicCallOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto calleeAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!calleeAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  PyFuncOp callee =
+      symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, calleeAttr);
+  if (callee) return success();
+  return emitOpError() << "'" << calleeAttr.getValue()
+                       << "' does not reference a valid function";
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc"