[pydm] Defines the structure for the full numeric hierarchy. (#7274)
* [pydm] Defines the structure for the full numeric hierarchy.
* Full support modeled for signed/unsigned 8/16/32/64 bit integers, fp16/bf16/fp32/fp64, complex64/complex128, bool, weak integer, weak real, arbitrary precision integer.
* Actual support for everything is more limited. Using a frontend pass to squash all weak types to i32/f32 for now (type inference/analysis needs to come into play here before making such decisions).
* Numeric promotion is in flux at the moment, but the goal is a combination of the Numba/Cython/JAX approaches. The key point is that weak integer/real types exist and bind to the hierarchy in different ways. See: https://jax.readthedocs.io/en/latest/type_promotion.html
* This makes the generic runtime support considerably more complicated and required quite a few more lowerings and canonicalizations (e.g. the runtime library decodes the bit patterns in the type code to make numeric type decisions).
* The generated code is still a joke and not something we would ever use, but it does run: https://gist.github.com/stellaraccident/e9f41a09a3834465d7576312fc63c278
* Still holding off on any real optimizations beyond canonicalizations since generality is helpful at this stage. Most of what is there should melt away with some simple variable load/store analysis.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
index f52fdd5..e978ef0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
@@ -92,13 +92,20 @@
//------------------------------------------------------------------------------
+/// Bool's type code is built from its numeric category (Bool) and its
+/// sub-type code (always 0).
BuiltinTypeCode iree_pydm::BoolType::getTypeCode() const {
- return BuiltinTypeCode::Bool;
+ return static_cast<BuiltinTypeCode>(
+ makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode()));
}
StringRef iree_pydm::BoolType::getPythonTypeName() const { return "bool"; }
+/// Bool forms its own numeric category.
+Optional<NumericCategory> iree_pydm::BoolType::getNumericCategory() const {
+ return NumericCategory::Bool;
+}
+
+/// Bool has a single representation, so its sub-type code is always 0.
+Optional<int> iree_pydm::BoolType::getNumericSubTypeCode() const { return 0; }
+
Optional<int> iree_pydm::BoolType::getNumericPromotionOrder() const {
- return 1;
+ // Promotion order is taken directly from the numeric type code.
+ return static_cast<int>(getTypeCode());
}
BuiltinTypeCode iree_pydm::BytesType::getTypeCode() const {
@@ -115,14 +122,63 @@
return "Exception";
}
+/// Accepts weak integers (no bit width) and widths whose magnitude is 0
+/// (arbitrary precision), 8, 16, 32 or 64.
+LogicalResult iree_pydm::IntegerType::verify(
+    function_ref<InFlightDiagnostic()> emitError, Optional<int> bitWidth) {
+  // |bitWidth| is the storage width; a negative value encodes an unsigned
+  // type (see isSigned()).
+  if (!bitWidth) return success();
+  int w = abs(*bitWidth);
+  if (w == 0 || w == 8 || w == 16 || w == 32 || w == 64) return success();
+  return emitError() << "unsupported python integer bit width: " << w;
+}
+
BuiltinTypeCode iree_pydm::IntegerType::getTypeCode() const {
- return BuiltinTypeCode::Integer;
+  return static_cast<BuiltinTypeCode>(
+      makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode()));
}
StringRef iree_pydm::IntegerType::getPythonTypeName() const { return "int"; }
+/// Weak integers have their own category; width 0 selects the
+/// arbitrary-precision (signed) category; otherwise the stored sign selects
+/// Signed vs Unsigned.
+Optional<NumericCategory> iree_pydm::IntegerType::getNumericCategory() const {
+  if (isWeak()) return NumericCategory::WeakInteger;
+  if (getBitWidth() == 0) return NumericCategory::APSigned;
+  if (isSigned()) return NumericCategory::Signed;
+  return NumericCategory::Unsigned;
+}
+
+/// Returns the width-specific sub-type code within the numeric category.
+Optional<int> iree_pydm::IntegerType::getNumericSubTypeCode() const {
+  // Weak and arbitrary-precision (width 0) integers are the only members of
+  // their respective categories and use sub-type 0. verify() accepts width 0,
+  // so it must be handled here rather than falling into the unreachable
+  // default below.
+  if (isWeak() || getBitWidth() == 0) return 0;
+  IntegerSubTypeCode stc;
+  switch (getBitWidth()) {
+    case 8:
+      stc = IntegerSubTypeCode::Integer8;
+      break;
+    case 16:
+      stc = IntegerSubTypeCode::Integer16;
+      break;
+    case 32:
+      stc = IntegerSubTypeCode::Integer32;
+      break;
+    case 64:
+      stc = IntegerSubTypeCode::Integer64;
+      break;
+    default: {
+      llvm_unreachable("unsupported numeric bitwidth");
+    }
+  }
+  return static_cast<int>(stc);
+}
+
Optional<int> iree_pydm::IntegerType::getNumericPromotionOrder() const {
- return 2;
+  // Promotion order is taken directly from the numeric type code.
+  return static_cast<int>(getTypeCode());
+}
+
+/// An IntegerType without a bit width is "weak" (width not yet fixed).
+bool iree_pydm::IntegerType::isWeak() const { return !getImpl()->bitWidth; }
+
+/// Magnitude of the stored bit width (0 == arbitrary precision).
+/// Must not be called on weak types: the optional width is dereferenced.
+unsigned iree_pydm::IntegerType::getBitWidth() const {
+  return abs(*getImpl()->bitWidth);
+}
+
+/// Non-negative stored widths are signed; negative widths encode unsigned.
+/// Must not be called on weak types: the optional width is dereferenced.
+bool iree_pydm::IntegerType::isSigned() const {
+  return *getImpl()->bitWidth >= 0;
}
BuiltinTypeCode iree_pydm::ListType::getTypeCode() const {
@@ -143,16 +199,49 @@
StringRef iree_pydm::ObjectType::getPythonTypeName() const { return "object"; }
+/// Accepts weak reals (no float type) and the bf16/f16/f32/f64 storage types;
+/// any other float type is rejected with a diagnostic.
+LogicalResult iree_pydm::RealType::verify(
+ function_ref<InFlightDiagnostic()> emitError, FloatType floatType) {
+ if (!floatType) return success();
+ if (!floatType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>()) {
+ return emitError() << "unsupported Python floating point type: "
+ << floatType;
+ }
+ return success();
+}
+
+/// Real's type code is built from its numeric category and sub-type code.
BuiltinTypeCode iree_pydm::RealType::getTypeCode() const {
- return BuiltinTypeCode::Real;
+ return static_cast<BuiltinTypeCode>(
+ makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode()));
}
StringRef iree_pydm::RealType::getPythonTypeName() const { return "float"; }
-Optional<int> iree_pydm::RealType::getNumericPromotionOrder() const {
- return 3;
+/// Weak reals (no concrete float type) form the WeakReal category; all
+/// concrete float types belong to Real.
+Optional<NumericCategory> iree_pydm::RealType::getNumericCategory() const {
+ if (isWeak()) return NumericCategory::WeakReal;
+ return NumericCategory::Real;
}
+/// Sub-type code selects the concrete float type; weak reals use 0.
+Optional<int> iree_pydm::RealType::getNumericSubTypeCode() const {
+ if (isWeak()) return 0;
+ RealSubTypeCode stc =
+ TypeSwitch<Type, RealSubTypeCode>(getFloatType())
+ .Case([](BFloat16Type t) { return RealSubTypeCode::BF16; })
+ .Case([](Float16Type t) { return RealSubTypeCode::FP16; })
+ .Case([](Float32Type t) { return RealSubTypeCode::FP32; })
+ .Case([](Float64Type t) { return RealSubTypeCode::FP64; })
+ // verify() only admits the four float types matched above.
+ .Default([](Type t) {
+ llvm_unreachable("unsupported float type");
+ return RealSubTypeCode::FP64;
+ });
+ return static_cast<int>(stc);
+}
+
+/// Promotion order is taken directly from the numeric type code.
+Optional<int> iree_pydm::RealType::getNumericPromotionOrder() const {
+ return static_cast<int>(getTypeCode());
+}
+
+/// A RealType without a float type is "weak" (width not yet fixed).
+bool iree_pydm::RealType::isWeak() const { return !getImpl()->floatType; }
+
BuiltinTypeCode iree_pydm::StrType::getTypeCode() const {
return BuiltinTypeCode::Str;
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
index 245e6be..3f93482 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
@@ -77,6 +77,39 @@
}
//===----------------------------------------------------------------------===//
+// ApplyCompareOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Folds boxing out of comparisons: when both operands of an `apply_compare`
+/// op are produced by `box` ops whose primitive operands share a type, the
+/// compare is rebuilt directly on those primitive values.
+struct UnboxApplyCompareOperands : public OpRewritePattern<ApplyCompareOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ApplyCompareOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhsBox = op.left().getDefiningOp<BoxOp>();
+    if (!lhsBox) return failure();
+    auto rhsBox = op.right().getDefiningOp<BoxOp>();
+    if (!rhsBox) return failure();
+
+    Value lhs = lhsBox.primitive();
+    Value rhs = rhsBox.primitive();
+    // Only fold when both sides unbox to the same primitive type.
+    if (lhs.getType() != rhs.getType()) return failure();
+
+    rewriter.replaceOpWithNewOp<ApplyCompareOp>(
+        op, rewriter.getType<BoolType>(), op.dunder_nameAttr(), lhs, rhs);
+    return success();
+  }
+};
+
+}  // namespace
+
+void ApplyCompareOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<UnboxApplyCompareOperands>(context);
+}
+
+//===----------------------------------------------------------------------===//
// AsBoolOp
//===----------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
index 44fe3b2..db6f84f 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Optimize)
add_subdirectory(RTL)
add_subdirectory(ToIREE)
@@ -8,8 +9,10 @@
MLIRIREEPyDMTransformsPassesIncGen
LINK_LIBS PUBLIC
+ IREEDialectsIREEPyDMOptimizePasses
IREEDialectsIREEPyDMRTLPasses
IREEDialectsIREEPyDMToIREEPasses
+ MLIRTransforms
)
iree_dialects_target_includes(IREEDialectsIREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
new file mode 100644
index 0000000..4bd33e1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Optimization passes for the IREEPyDM dialect; currently just the
+# FixateWeakNumeric pass.
+add_mlir_library(IREEDialectsIREEPyDMOptimizePasses
+ FixateWeakNumeric.cpp
+
+ DEPENDS
+ MLIRIREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEDialectsIREEPyDMDialect
+ MLIRIR
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEDialectsIREEPyDMOptimizePasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
new file mode 100644
index 0000000..3bc936f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
@@ -0,0 +1,113 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+namespace {
+
+/// Replaces weak (width-less) pydm numeric types with concrete ones: weak
+/// integers become 32-bit pydm integers and weak reals become f32-backed pydm
+/// reals. Rewrites, in place, op results, block arguments, the primitive type
+/// of object types, and `iree_pydm::FuncOp` signatures.
+struct FixateWeakNumericPass
+    : public FixateWeakNumericBase<FixateWeakNumericPass> {
+  void runOnOperation() override {
+    // Visit every op nested under (and including) the root. Types are only
+    // mutated in place, so no IR structure changes during the walk.
+    getOperation()->walk([&](Operation *op) { convertOperation(op); });
+  }
+
+  /// Retypes the block arguments of all of `op`'s regions and all of its
+  /// results, and updates the signature of function ops.
+  void convertOperation(Operation *op) {
+    // Process all regions/blocks to rewrite block arguments.
+    for (auto &region : op->getRegions()) {
+      for (auto &block : region) {
+        for (BlockArgument blockArg : block.getArguments()) {
+          convertValue(blockArg);
+        }
+      }
+    }
+
+    // And all results.
+    for (Value result : op->getResults()) {
+      convertValue(result);
+    }
+
+    // Retyping the entry block arguments does not update a function's
+    // declared signature, so handle it explicitly.
+    if (auto funcOp = llvm::dyn_cast<iree_pydm::FuncOp>(op)) {
+      FunctionType existingFt = funcOp.getType();
+      FunctionType newFt = convertFunctionType(existingFt);
+      if (newFt != existingFt) {
+        funcOp.setType(newFt);
+      }
+    }
+  }
+
+  void convertValue(Value value) {
+    value.setType(convertType(value.getType()));
+  }
+
+  /// Returns the concrete replacement for `type`, or `type` itself when it
+  /// contains no weak numeric type.
+  Type convertType(Type type) {
+    // TODO: The specific types we promote to need to be configured by the
+    // lowering options.
+    if (auto integerType = type.dyn_cast<iree_pydm::IntegerType>()) {
+      if (integerType.isWeak()) {
+        return iree_pydm::IntegerType::get(type.getContext(), 32);
+      }
+    } else if (auto realType = type.dyn_cast<iree_pydm::RealType>()) {
+      if (realType.isWeak()) {
+        return iree_pydm::RealType::get(
+            type.getContext(), mlir::Float32Type::get(type.getContext()));
+      }
+    } else if (auto objectType = type.dyn_cast<iree_pydm::ObjectType>()) {
+      Type primitiveType = objectType.getPrimitiveType();
+      if (primitiveType) {
+        Type newPrimitiveType = convertType(primitiveType);
+        if (newPrimitiveType != primitiveType) {
+          return iree_pydm::ObjectType::get(
+              type.getContext(),
+              newPrimitiveType.cast<iree_pydm::PrimitiveType>());
+        }
+      }
+    }
+
+    return type;
+  }
+
+  /// Converts every input and result type of `ft`. Returns `ft` itself when
+  /// nothing needed converting.
+  FunctionType convertFunctionType(FunctionType ft) {
+    bool modified = false;
+    auto convertAll = [&](ArrayRef<Type> types) {
+      SmallVector<Type> converted;
+      converted.reserve(types.size());
+      for (Type type : types) {
+        Type newType = convertType(type);
+        modified |= (newType != type);
+        converted.push_back(newType);
+      }
+      return converted;
+    };
+    SmallVector<Type> inputs = convertAll(ft.getInputs());
+    SmallVector<Type> results = convertAll(ft.getResults());
+    if (!modified) return ft;
+
+    return FunctionType::get(ft.getContext(), inputs, results);
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<>>
+mlir::iree_pydm::createFixateWeakNumericPass() {
+  return std::make_unique<FixateWeakNumericPass>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
index caaaf14..20aeb14 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
@@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::iree_pydm;
@@ -18,5 +19,14 @@
if (options.linkRtlSource) {
passManager.addPass(createLinkIREEPyDMRTLPass(options.linkRtlSource));
}
+ // TODO: Optimization passes need to be their own pipeline.
+ passManager.addPass(createFixateWeakNumericPass());
+ passManager.addPass(createCanonicalizerPass());
+
+ // Lowering passes.
passManager.addPass(createConvertIREEPyDMToIREEPass());
+
+ // Cleanup.
+ passManager.addPass(createCanonicalizerPass());
+ passManager.addPass(createCSEPass());
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
index 2d4d88d..4ac31e8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -43,12 +43,19 @@
// Some CFG ops can be present in the original pydm program. Need to
// verify legality based on types.
- target.addDynamicallyLegalOp<BranchOp>([&](BranchOp op) -> bool {
+ target.addDynamicallyLegalOp<BranchOp>([&](mlir::BranchOp op) -> bool {
return typeConverter.areTypesLegal(op.getOperandTypes());
});
- target.addDynamicallyLegalOp<CondBranchOp>([&](CondBranchOp op) -> bool {
- return typeConverter.areTypesLegal(op.getOperandTypes());
- });
+ target.addDynamicallyLegalOp<CondBranchOp>(
+ [&](mlir::CondBranchOp op) -> bool {
+ return typeConverter.areTypesLegal(op.getOperandTypes());
+ });
+
+ // Standard select can be emitted as part of CFG canonicalization.
+ target.addDynamicallyLegalOp<mlir::SelectOp>(
+ [&](mlir::SelectOp op) -> bool {
+ return typeConverter.areTypesLegal(op.getOperandTypes());
+ });
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
return signalPassFailure();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 9b3e1fd..e537125 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -175,6 +175,65 @@
}
};
+/// Lowers `apply_binary` whose operands and converted result share a single
+/// builtin integer type to the matching `arith` op. Float operands and
+/// unsupported dunder names are reported as match failures.
+class ApplyBinaryNumericConversion
+    : public OpConversionPattern<pydm_d::ApplyBinaryOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::ApplyBinaryOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type leftType = adaptor.left().getType();
+    Type rightType = adaptor.right().getType();
+    Type resultType = typeConverter->convertType(srcOp.result().getType());
+    if (!resultType || leftType != rightType || leftType != resultType) {
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "not same type operands/results");
+    }
+    if (leftType.isa<builtin_d::IntegerType>()) {
+      bool isSigned = isSignedIntegerResult(srcOp);
+      Value converted =
+          convertIntegerOp(srcOp.getLoc(), adaptor.dunder_name().getValue(),
+                           adaptor.left(), adaptor.right(), isSigned, rewriter);
+      if (!converted)
+        return rewriter.notifyMatchFailure(srcOp, "unsupported operation");
+      rewriter.replaceOp(srcOp, converted);
+      return success();
+    } else if (leftType.isa<builtin_d::FloatType>()) {
+      // TODO: Implement float binary
+      return rewriter.notifyMatchFailure(srcOp, "unsupported operation");
+    }
+
+    return rewriter.notifyMatchFailure(srcOp, "non numeric type");
+  }
+
+  /// Builtin integers are signless after type conversion, so signedness is
+  /// recovered from the original pydm result type. Weak integers and
+  /// non-pydm-integer results are treated as signed; the isWeak() check must
+  /// come first because isSigned() dereferences the optional bit width.
+  static bool isSignedIntegerResult(pydm_d::ApplyBinaryOp srcOp) {
+    auto pyIntType = srcOp.result().getType().dyn_cast<pydm_d::IntegerType>();
+    if (!pyIntType || pyIntType.isWeak()) return true;
+    return pyIntType.isSigned();
+  }
+
+  /// Emits the `arith` op for `dunderName`, or returns null if unsupported.
+  /// `isSigned` only affects `rshift` (arithmetic vs logical shift).
+  Value convertIntegerOp(Location loc, StringRef dunderName, Value left,
+                         Value right, bool isSigned,
+                         ConversionPatternRewriter &rewriter) const {
+    // TODO: matmul, truediv, floordiv, mod, divmod, pow
+    if (dunderName == "add") {
+      return rewriter.create<arith_d::AddIOp>(loc, left, right);
+    } else if (dunderName == "and") {
+      return rewriter.create<arith_d::AndOp>(loc, left, right);
+    } else if (dunderName == "mul") {
+      return rewriter.create<arith_d::MulIOp>(loc, left, right);
+    } else if (dunderName == "lshift") {
+      return rewriter.create<arith_d::ShiftLeftOp>(loc, left, right);
+    } else if (dunderName == "or") {
+      return rewriter.create<arith_d::OrOp>(loc, left, right);
+    } else if (dunderName == "rshift") {
+      if (isSigned)
+        return rewriter.create<arith_d::SignedShiftRightOp>(loc, left, right);
+      return rewriter.create<arith_d::UnsignedShiftRightOp>(loc, left, right);
+    } else if (dunderName == "sub") {
+      return rewriter.create<arith_d::SubIOp>(loc, left, right);
+    } else if (dunderName == "xor") {
+      return rewriter.create<arith_d::XOrOp>(loc, left, right);
+    }
+    return nullptr;
+  }
+};
+
+
class ApplyCompareNumericConversion
: public OpConversionPattern<pydm_d::ApplyCompareOp> {
using OpConversionPattern::OpConversionPattern;
@@ -242,17 +301,6 @@
}
};
-class BranchConversion : public OpConversionPattern<std_d::BranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- std_d::BranchOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<std_d::BranchOp>(srcOp, srcOp.dest(),
- adaptor.destOperands());
- return success();
- }
-};
-
class CallOpConversion : public OpConversionPattern<pydm_d::CallOp> {
using OpConversionPattern::OpConversionPattern;
@@ -272,19 +320,6 @@
}
};
-class CondBranchConversion : public OpConversionPattern<std_d::CondBranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- std_d::CondBranchOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<std_d::CondBranchOp>(
- srcOp, adaptor.condition(), srcOp.trueDest(),
- adaptor.trueDestOperands(), srcOp.falseDest(),
- adaptor.falseDestOperands());
- return success();
- }
-};
-
class ConstantOpConversion : public OpConversionPattern<pydm_d::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
@@ -331,6 +366,22 @@
}
};
+/// Lowers a `failure` op to the constant integer exception code it denotes.
+/// This is just temporary to allow some libraries to signal exceptions.
+class FailureOpConversion : public OpConversionPattern<pydm_d::FailureOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::FailureOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // The exception code is an i32 constant; -3 denotes RuntimeError.
+    Type codeType = rewriter.getI32Type();
+    IntegerAttr runtimeErrorCode = rewriter.getIntegerAttr(codeType, -3);
+    rewriter.replaceOpWithNewOp<std_d::ConstantOp>(srcOp, codeType,
+                                                   runtimeErrorCode);
+    return success();
+  }
+};
+
+
class FuncOpConversion : public OpConversionPattern<pydm_d::FuncOp> {
using OpConversionPattern::OpConversionPattern;
@@ -603,19 +654,66 @@
}
};
+//------------------------------------------------------------------------------
+// Outside pydm op conversions
+// These are largely identity conversions for CFG related standard ops, and
+// those that can be emitted as part of canonicalizations.
+//------------------------------------------------------------------------------
+
+/// Re-creates `br` with the same successor and type-converted operands.
+class BuiltinBranchConversion : public OpConversionPattern<std_d::BranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      std_d::BranchOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Block *successor = srcOp.dest();
+    auto forwarded = adaptor.destOperands();
+    rewriter.replaceOpWithNewOp<std_d::BranchOp>(srcOp, successor, forwarded);
+    return success();
+  }
+};
+
+/// Re-creates `cond_br` with the same successors and type-converted
+/// condition and successor operands.
+class BuiltinCondBranchConversion
+    : public OpConversionPattern<std_d::CondBranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      std_d::CondBranchOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value predicate = adaptor.condition();
+    Block *onTrue = srcOp.trueDest();
+    Block *onFalse = srcOp.falseDest();
+    auto onTrueArgs = adaptor.trueDestOperands();
+    auto onFalseArgs = adaptor.falseDestOperands();
+    rewriter.replaceOpWithNewOp<std_d::CondBranchOp>(
+        srcOp, predicate, onTrue, onTrueArgs, onFalse, onFalseArgs);
+    return success();
+  }
+};
+
+/// Re-creates `select` on the type-converted condition and values.
+class BuiltinSelectConversion : public OpConversionPattern<std_d::SelectOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      std_d::SelectOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value predicate = adaptor.condition();
+    Value whenTrue = adaptor.true_value();
+    Value whenFalse = adaptor.false_value();
+    rewriter.replaceOpWithNewOp<std_d::SelectOp>(srcOp, predicate, whenTrue,
+                                                 whenFalse);
+    return success();
+  }
+};
+
} // namespace
void mlir::iree_pydm::populatePyDMToIREELoweringPatterns(
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- // Structural.
- patterns.insert<AllocFreeVarOpConversion, ApplyCompareNumericConversion,
- BoolToPredConversion, BoxOpConversion, BranchConversion,
- CallOpConversion, CondBranchConversion, ConstantOpConversion,
- FuncOpConversion, GetTypeCodeConversion, LoadVarOpConversion,
- RaiseOnFailureOpConversion, ReturnOpConversion,
- StoreVarOpConversion, UnboxOpConversion>(typeConverter,
- context);
+ // PyDM conversions.
+ patterns.insert<AllocFreeVarOpConversion, ApplyBinaryNumericConversion,
+ ApplyCompareNumericConversion, BoolToPredConversion,
+ BoxOpConversion, CallOpConversion, ConstantOpConversion,
+ FailureOpConversion, FuncOpConversion, GetTypeCodeConversion,
+ LoadVarOpConversion, RaiseOnFailureOpConversion,
+ ReturnOpConversion, StoreVarOpConversion, UnboxOpConversion>(
+ typeConverter, context);
+
+ // External CFG ops.
+ patterns.insert<BuiltinBranchConversion, BuiltinCondBranchConversion,
+ BuiltinSelectConversion>(typeConverter, context);
// Constants and constructors.
patterns.insert<NoneOpConversion>(typeConverter, context);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
index 0ee4fb7..eac25d4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -45,13 +45,19 @@
// Integer type hierarchy.
addConversion([&](pydm_d::IntegerType t) -> Optional<Type> {
Builder b(t.getContext());
- return getWeakIntegerType(b);
+ if (t.isWeak()) {
+ return getWeakIntegerType(b);
+ }
+ return b.getIntegerType(t.getBitWidth());
});
// Real type hierarchy.
addConversion([&](pydm_d::RealType t) -> Optional<Type> {
Builder b(t.getContext());
- return getWeakFloatType(b);
+ if (t.isWeak()) {
+ return getWeakFloatType(b);
+ }
+ return t.getFloatType();
});
// Variable references.