Health and welfare on the iree_pydm dialect. (#6978)
Dialect changes:
* Simplify file naming to limit repetition.
* Add PythonTypeInterface, implemented by all Python types (currently used for some identity and numeric promotion rules).
* Apply previous recommendation and rework var alloc/load/store into a !free_var_ref type, alloc_free_var, load_var, store_var. Cell variables will need something different but this should generalize (i.e. cell vars need to be resolved symbolically, inter-procedurally).
* Add a UnionType in order to support type refinement (not yet used, and still needs some refinement).
* Forked scf.if into `functional_if` for the specific case where we are emitting conditional Python code of a functional nature (shows up in conditionals and short-circuit evals a lot, but most Python control flow is naturally CFG based). With this change, the pydm dialect is self-complete, not relying on ops from outside of itself. This will help with type inference, etc.
* Implemented OpAsmOpInterface::getDefaultDialect on `func` and `functional_if`, making all pydm ops able to be used prefix-free, cleaning up IR a lot.
* Made `none` ConstantLike. Added `success` and `failure` ops to produce `ExceptionResults`.
* Added `make_tuple` op (not yet used).
* Added `promote_numeric` op.
* Implemented simple folders for `constant`, `none`, `success`, `as_bool`, `bool_to_pred`, `raise_on_failure`, `select`.
* Implemented static numeric promotion with: a folder+canonicalizer on `promote_numeric`, a canonicalizer on `dynamic_binary_promote` which reduces it to primitives for static cases.
* Implemented no-op box/unbox canonicalizations.
With this, the dialect is in good shape to start building out the optimization and lowering pipelines.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
new file mode 100644
index 0000000..245e6be
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
@@ -0,0 +1,525 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+using PyBoolType = mlir::iree_pydm::BoolType;
+using PyConstantOp = mlir::iree_pydm::ConstantOp;
+using PyIntegerType = mlir::iree_pydm::IntegerType;
+using PyRealType = mlir::iree_pydm::RealType;
+using PyCallOp = mlir::iree_pydm::CallOp;
+using PyFuncOp = mlir::iree_pydm::FuncOp;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static Value getNumericZeroConstant(Location loc, Type numericType,
+ OpBuilder &builder) {
+ return TypeSwitch<Type, Value>(numericType)
+ .Case([&](PyBoolType t) -> Value {
+ return builder.create<PyConstantOp>(loc, t, builder.getBoolAttr(false));
+ })
+ .Case([&](PyIntegerType t) -> Value {
+ return builder.create<PyConstantOp>(loc, t,
+ builder.getI64IntegerAttr(0));
+ })
+ .Case([&](PyRealType t) -> Value {
+ return builder.create<PyConstantOp>(loc, t,
+ builder.getF64FloatAttr(0.0));
+ });
+}
+
+static Value getBoolConstant(Location loc, bool pred, OpBuilder &builder) {
+ return builder.create<PyConstantOp>(loc, builder.getType<BoolType>(),
+ builder.getBoolAttr(pred));
+}
+
+//===----------------------------------------------------------------------===//
+// Constants
+//===----------------------------------------------------------------------===//
+
+OpFoldResult PyConstantOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.empty() && "constant has no operands");
+ return getValue();
+}
+
+OpFoldResult NoneOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.empty() && "constant has no operands");
+ return UnitAttr::get(getContext());
+}
+
+OpFoldResult SuccessOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.empty() && "constant has no operands");
+ return UnitAttr::get(getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Variables
+//===----------------------------------------------------------------------===//
+
+void AllocFreeVarOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+ setNameFn(getResult(), name());
+}
+
+//===----------------------------------------------------------------------===//
+// AsBoolOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsBoolOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.value().getType().isa<BoolType>()) {
+ rewriter.replaceOp(op, op.value());
+ return success();
+ }
+ return failure();
+ }
+};
+
+struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsBoolOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ptType = op.value().getType().dyn_cast<PythonTypeInterface>();
+ if (!ptType) return failure();
+ if (!ptType.getNumericPromotionOrder()) return failure();
+
+ auto boolType = rewriter.getType<BoolType>();
+ Value zeroValue =
+ getNumericZeroConstant(loc, op.value().getType(), rewriter);
+ Value trueValue = getBoolConstant(loc, true, rewriter);
+ Value falseValue = getBoolConstant(loc, false, rewriter);
+ Value cmpResult = rewriter.create<ApplyCompareOp>(
+ loc, boolType, rewriter.getStringAttr("eq"), op.value(), zeroValue);
+ rewriter.replaceOpWithNewOp<SelectOp>(op, boolType, cmpResult, falseValue,
+ trueValue);
+ return success();
+ }
+};
+
+} // namespace
+
+void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<FoldAsBoolFromBool, FoldAsBoolFromNumeric>(context);
+}
+
+OpFoldResult AsBoolOp::fold(ArrayRef<Attribute> operands) {
+ Builder b(getContext());
+ // Fold NoneType to False.
+ if (value().getType().isa<NoneType>()) {
+ return b.getBoolAttr(false);
+ }
+
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
+// BoolToPredOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) {
+ if (!operands[0]) return {};
+ // Since both BoolType and I1 share the attribute form (an IntegerAttr of I1),
+ // we can just return it.
+ return operands[0];
+}
+
+//===----------------------------------------------------------------------===//
+// BoxOp and UnboxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult BoxOp::canonicalize(BoxOp op, PatternRewriter &rewriter) {
+ // Sometimes boxes are emitted when the input is an object. Just remove.
+ if (op.primitive().getType().isa<ObjectType>()) {
+ rewriter.replaceOp(op, op.primitive());
+ return success();
+ }
+
+ return failure();
+}
+
+LogicalResult UnboxOp::canonicalize(UnboxOp unboxOp,
+ PatternRewriter &rewriter) {
+ auto loc = unboxOp.getLoc();
+
+ // Handle the case of an immediate BoxOp producer.
+ if (auto boxProducer =
+ dyn_cast_or_null<BoxOp>(unboxOp.object().getDefiningOp())) {
+ // If the producer is boxing to the same type we are unboxing, then
+ // just elide everything.
+ if (boxProducer.primitive().getType() == unboxOp.primitive().getType()) {
+ auto successValue = rewriter.create<SuccessOp>(
+ loc, rewriter.getType<ExceptionResultType>());
+ rewriter.replaceOp(unboxOp, {successValue, boxProducer.primitive()});
+ return success();
+ }
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicBinaryPromoteOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DynamicBinaryPromoteOp::canonicalize(DynamicBinaryPromoteOp op,
+ PatternRewriter &rewriter) {
+ auto loc = op.getLoc();
+ auto leftType = op.left().getType();
+ auto rightType = op.right().getType();
+ auto leftResultType = op.getResultTypes()[0];
+ auto rightResultType = op.getResultTypes()[1];
+ auto leftPt = leftType.dyn_cast<PythonTypeInterface>();
+ auto rightPt = rightType.dyn_cast<PythonTypeInterface>();
+ if (!leftPt || !rightPt) return failure();
+
+ Optional<int> leftOrder = leftPt.getNumericPromotionOrder();
+ Optional<int> rightOrder = rightPt.getNumericPromotionOrder();
+ Value newLeft = op.left();
+ Value newRight = op.right();
+
+ // Simple case: same types pass through.
+ if (leftType == rightType) {
+ // Nothing - pass-through rewrite.
+ } else if (leftOrder && rightOrder) {
+ // Both numeric.
+ if (*leftOrder > *rightOrder) {
+ newRight = rewriter.create<PromoteNumericOp>(loc, leftType, newRight);
+ }
+ if (*rightOrder > *leftOrder) {
+ newLeft = rewriter.create<PromoteNumericOp>(loc, rightType, newLeft);
+ }
+ } else {
+ return failure();
+ }
+
+ // Need to box back to the original type (which will always be a generic
+ // object).
+ newLeft = rewriter.create<BoxOp>(loc, leftResultType, newLeft);
+ newRight = rewriter.create<BoxOp>(loc, rightResultType, newRight);
+
+ rewriter.replaceOp(op, {newLeft, newRight});
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FunctionalIfOp
+//===----------------------------------------------------------------------===//
+
+::llvm::StringRef FunctionalIfOp::getDefaultDialect() { return "iree_pydm"; }
+
+static LogicalResult verify(FunctionalIfOp op) {
+ if (op.getNumResults() != 0 && op.elseRegion().empty())
+ return op.emitOpError("must have an else block if defining values");
+
+ return RegionBranchOpInterface::verifyTypes(op);
+}
+
+static ParseResult parseFunctionalIfOp(OpAsmParser &parser,
+ OperationState &result) {
+ // Create the regions for 'then'.
+ result.regions.reserve(2);
+ Region *thenRegion = result.addRegion();
+ Region *elseRegion = result.addRegion();
+
+ auto &builder = parser.getBuilder();
+ OpAsmParser::OperandType cond;
+ Type conditionType = builder.getType<PyBoolType>();
+ if (parser.parseOperand(cond) ||
+ parser.resolveOperand(cond, conditionType, result.operands))
+ return failure();
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types)) return failure();
+ // Parse the 'then' region.
+ if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+ // IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
+
+ // If we find an 'else' keyword then parse the 'else' region.
+ if (!parser.parseOptionalKeyword("else")) {
+ if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+ // IfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
+ // result.location);
+ }
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ return success();
+}
+
+static void print(OpAsmPrinter &p, FunctionalIfOp op) {
+ bool printBlockTerminators = false;
+
+ p << " " << op.condition();
+ if (!op.results().empty()) {
+ p << " -> (" << op.getResultTypes() << ")";
+ // Print yield explicitly if the op defines values.
+ printBlockTerminators = true;
+ }
+ p.printRegion(op.thenRegion(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/printBlockTerminators);
+
+ // Print the 'else' regions if it exists and has a block.
+ auto &elseRegion = op.elseRegion();
+ if (!elseRegion.empty()) {
+ p << " else";
+ p.printRegion(elseRegion,
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/printBlockTerminators);
+ }
+
+ p.printOptionalAttrDict(op->getAttrs());
+}
+
+/// Given the region at `index`, or the parent operation if `index` is None,
+/// return the successor regions. These are the regions that may be selected
+/// during the flow of control. `operands` is a set of optional attributes that
+/// correspond to a constant value for each operand, or null if that operand is
+/// not a constant.
+void FunctionalIfOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ // The `then` and the `else` region branch back to the parent operation.
+ if (index.hasValue()) {
+ regions.push_back(RegionSuccessor(getResults()));
+ return;
+ }
+
+ // Don't consider the else region if it is empty.
+ Region *elseRegion = &this->elseRegion();
+ if (elseRegion->empty()) elseRegion = nullptr;
+
+ // Otherwise, the successor is dependent on the condition.
+ if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
+ bool condition = condAttr.getValue();
+ // Add the successor regions using the condition.
+ regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
+ } else {
+ // If the condition isn't constant, both regions may be executed.
+ regions.push_back(RegionSuccessor(&thenRegion()));
+ // If the else region does not exist, it is not a viable successor.
+ if (elseRegion) regions.push_back(RegionSuccessor(elseRegion));
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+::llvm::StringRef PyFuncOp::getDefaultDialect() { return "iree_pydm"; }
+
+LogicalResult PyFuncOp::verifyType() {
+ // TODO: Enforce arg/result invariants.
+ return success();
+}
+
+static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
+ auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
+ ArrayRef<Type> results,
+ function_like_impl::VariadicFlag, std::string &) {
+ return builder.getFunctionType(argTypes, results);
+ };
+
+ return function_like_impl::parseFunctionLikeOp(
+ parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+static void print(PyFuncOp op, OpAsmPrinter &p) {
+ FunctionType fnType = op.getType();
+ function_like_impl::printFunctionLikeOp(
+ p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
+}
+
+static LogicalResult verify(PyFuncOp op) {
+ // TODO: Enforce invariants.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatchCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PatternMatchCallOp::verifySymbolUses(
+ SymbolTableCollection &symbolTable) {
+ auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
+ for (auto symbolAttr : symbols) {
+ auto symbol = symbolAttr.cast<FlatSymbolRefAttr>();
+ PyFuncOp fn =
+ symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol);
+ if (!fn)
+ return emitOpError() << "'" << symbol.getValue()
+ << "' does not reference a valid function";
+ }
+ return success();
+ };
+ auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match");
+ if (!genericsAttr)
+ return emitOpError(
+ "requires a 'generic_match' array of symbol reference attributes");
+ if (failed(verifySymbols(genericsAttr))) return failure();
+
+ auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
+ if (!specificsAttr)
+ return emitOpError(
+ "requires a 'specific_match' array of symbol reference attributes");
+ if (failed(verifySymbols(specificsAttr))) return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PromoteNumericOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) {
+ if (!operands[0]) return {};
+
+ Builder b(getContext());
+ Attribute fromAttr = operands[0];
+ return TypeSwitch<Type, OpFoldResult>(getResult().getType())
+ .Case([&](PyIntegerType toType) -> OpFoldResult {
+ return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
+ .Case([&](BoolAttr fromBool) -> OpFoldResult {
+ return b.getI64IntegerAttr(fromBool.getValue() ? 1 : 0);
+ })
+ .Default([](Attribute) -> OpFoldResult { return {}; });
+ })
+ .Case([&](PyRealType toType) -> OpFoldResult {
+ return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
+ .Case([&](BoolAttr fromBool) -> OpFoldResult {
+ return b.getF64FloatAttr(fromBool.getValue() ? 1.0 : 0.0);
+ })
+ .Case([&](IntegerAttr fromInteger) -> OpFoldResult {
+ APInt value = fromInteger.getValue();
+ return b.getF64FloatAttr(value.getSExtValue());
+ })
+ .Default([](Attribute) -> OpFoldResult { return {}; });
+ })
+ .Default([](Type) -> OpFoldResult { return {}; });
+}
+
+LogicalResult PromoteNumericOp::canonicalize(PromoteNumericOp op,
+ PatternRewriter &rewriter) {
+ if (op.input().getType() == op.getResult().getType()) {
+ rewriter.replaceOp(op, op.input());
+ return success();
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// RaiseOnFailureOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult iree_pydm::RaiseOnFailureOp::fold(
+ ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
+ assert(operands.size() == 1 && "expected one fold operand");
+ // Unit exception result is success. Just elide.
+ if (operands[0] && operands[0].isa<UnitAttr>()) {
+ erase();
+ return success();
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+ if (!operands[0]) return {};
+
+ BoolAttr boolAttr = operands[0].cast<BoolAttr>();
+ if (boolAttr.getValue())
+ return true_value();
+ else
+ return false_value();
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ // Check that the callee attribute was specified.
+ auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+ if (!fnAttr)
+ return emitOpError("requires a 'callee' symbol reference attribute");
+ PyFuncOp fn = symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, fnAttr);
+ if (!fn)
+ return emitOpError() << "'" << fnAttr.getValue()
+ << "' does not reference a valid function";
+
+ // Verify that the operand and result types match the callee.
+ auto fnType = fn.getType();
+ if (fnType.getNumInputs() != getNumOperands())
+ return emitOpError("incorrect number of operands for callee");
+
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+ if (getOperand(i).getType() != fnType.getInput(i)) {
+ return emitOpError("operand type mismatch: expected operand type ")
+ << fnType.getInput(i) << ", but provided "
+ << getOperand(i).getType() << " for operand number " << i;
+ }
+ }
+
+ if (fnType.getNumResults() != getNumResults())
+ return emitOpError("incorrect number of results for callee");
+
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+ if (getResult(i).getType() != fnType.getResult(i)) {
+ auto diag = emitOpError("result type mismatch at index ") << i;
+ diag.attachNote() << " op result types: " << getResultTypes();
+ diag.attachNote() << "function result types: " << fnType.getResults();
+ return diag;
+ }
+ }
+
+ return success();
+}
+
+FunctionType PyCallOp::getCalleeType() {
+ return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DynamicCallOp::verifySymbolUses(
+ SymbolTableCollection &symbolTable) {
+ // Check that the callee attribute was specified.
+ auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+ if (!fnAttr)
+ return emitOpError("requires a 'callee' symbol reference attribute");
+ Operation *fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr);
+ if (!fn || !isa<PyFuncOp>(fn))
+ return emitOpError() << "'" << fnAttr.getValue()
+ << "' does not reference a valid function";
+ return success();
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc"