[pydm] Defines the structure for the full numeric hierarchy. (#7274)
* [pydm] Defines the structure for the full numeric hierarchy.
* Full support modeled for signed/unsigned 8/16/32/64 bit integers, fp16/bf16/fp32/fp64, complex64/complex128, bool, weak integer, weak real, arbitrary precision integer.
* Actual support for everything is more limited. Using a frontend pass to squash all weak types to i32/f32 for now (type inference/analysis needs to come into play here before making such decisions).
* Numeric promotion is in flux at the moment, but we are shooting for a combination of Numba/Cython/JAX reasoning about this. The key is that weak integer/real types exist and bind to the hierarchy in different ways. See: https://jax.readthedocs.io/en/latest/type_promotion.html
* This makes the generic runtime support a lot more complicated and required quite a few more lowerings and canonicalizations to achieve (i.e. the runtime library decodes the bit patterns in the type code to make numeric type decisions).
* The generated code is still a joke and not something we would ever use, but it does run: https://gist.github.com/stellaraccident/e9f41a09a3834465d7576312fc63c278
* Still holding off on any real optimizations beyond canonicalizations since generality is helpful at this stage. Most of what is there should melt away with some simple variable load/store analysis.
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 90340be..a96da53 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -296,6 +296,7 @@
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
index da8701a..1c5f81a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
@@ -57,6 +57,7 @@
IREEPYDM_DECLARE_NULLARY_TYPE(Bool)
IREEPYDM_DECLARE_NULLARY_TYPE(Bytes)
+// Note: Also has a non-nullary constructor
IREEPYDM_DECLARE_NULLARY_TYPE(Integer)
IREEPYDM_DECLARE_NULLARY_TYPE(ExceptionResult)
IREEPYDM_DECLARE_NULLARY_TYPE(FreeVarRef)
@@ -69,6 +70,12 @@
#undef IREEPYDM_DECLARE_NULLARY_TYPE
+// Non-nullary Type constructors from the above.
+MLIR_CAPI_EXPORTED MlirType mlirIREEPyDMIntegerTypeGetExplicit(MlirContext ctx,
+ int bitWidth,
+ bool isSigned);
+
+// ObjectType.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDMObject(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirIREEPyDMObjectTypeGet(MlirContext context,
MlirType primitive);
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Constants.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Constants.h
new file mode 100644
index 0000000..2e54a60
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Constants.h
@@ -0,0 +1,133 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_CONSTANTS_H
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_CONSTANTS_H
+
+namespace mlir {
+namespace iree_pydm {
+
+/// Category of the numeric type. These are arranged such that during promotion,
+/// the type with the largest category value determines the category of
+/// promotion.
+enum class NumericCategory : int {
+ Bool = 0,
+ WeakInteger = 1,
+ Unsigned = 2,
+ Signed = 3,
+ APSigned = 4,
+ WeakReal = 5,
+ Real = 6,
+ WeakComplex = 7,
+ Complex = 8,
+};
+
+/// For integer types (Unsigned, Signed, and Bool category), this is the type
+/// specific sub type code for sizes that we support.
+/// Only POT bit sizes up to 64bits are supported. They sort into promotion
+/// order within a category.
+enum class IntegerSubTypeCode : int {
+ Integer8 = 0,
+ Integer16 = 1,
+ Integer32 = 2,
+ Integer64 = 3,
+};
+
+/// As with integer types, this is the type specific code for supported
+/// floating point types within the Real category. They sort into promotion
+/// order with the special case that combining an FP16 and BF16 promotes to
+/// FP32.
+enum class RealSubTypeCode : int {
+ FP16 = 0,
+ BF16 = 1,
+ FP32 = 2,
+ FP64 = 3,
+};
+
+/// Sub type code for complex types, which consist of two floating point
+/// values (either FP32 or FP64). Space is retained in the enumeration for
+/// 16bit elements.
+enum class ComplexSubTypeCode : int {
+ UNUSED0 = 0,
+ UNUSED1 = 1,
+ COMPLEX64 = 2,
+ COMPLEX128 = 3,
+};
+
+/// Makes a numeric category code with bit pattern:
+/// 1 C C C C S S
+/// Where 'C' is category code and 'S' is sub type code.
+/// These range from 0x40 - 0x7f
+template <typename SubTypeCode>
+constexpr int makeNumericTypeCode(const NumericCategory cat,
+ const SubTypeCode subType) {
+ return 0x40 | (static_cast<int>(cat) << 2) | (static_cast<int>(subType));
+}
+
+// Each built-in (to the compiler) type has a unique code, enumerated here.
+// Generally, the closed part of the type system will have type codes <
+// FirstCustom.
+// If editing, also update the constants in rtl/modules/constants.py.
+enum class BuiltinTypeCode : int {
+ // Built-in types, ordered by rough "core-ness" so that lower numbers
+ // are easier to spot for common cases.
+ None = 0x1,
+ Tuple = 0x2,
+ List = 0x3,
+ Str = 0x4,
+ Bytes = 0x5,
+ ExceptionResult = 0x6,
+ Type = 0x7,
+
+ // Start of the encoded numeric type codes. The lower 6 bits represent a bit
+ // packed encoding of the numeric category (4 bits) and sub type
+ // code (2 bits):
+ NumericStart = 0x20,
+ NumericBool = makeNumericTypeCode(NumericCategory::Bool, 0),
+ WeakInteger = makeNumericTypeCode(NumericCategory::WeakInteger, 0),
+ NumericUnsigned8Bit = makeNumericTypeCode(NumericCategory::Unsigned,
+ IntegerSubTypeCode::Integer8),
+ NumericUnsigned16Bit = makeNumericTypeCode(NumericCategory::Unsigned,
+ IntegerSubTypeCode::Integer16),
+ NumericUnsigned32Bit = makeNumericTypeCode(NumericCategory::Unsigned,
+ IntegerSubTypeCode::Integer32),
+ NumericUnsigned64Bit = makeNumericTypeCode(NumericCategory::Unsigned,
+ IntegerSubTypeCode::Integer64),
+ NumericSigned8Bit = makeNumericTypeCode(NumericCategory::Signed,
+ IntegerSubTypeCode::Integer8),
+ NumericSigned16Bit = makeNumericTypeCode(NumericCategory::Signed,
+ IntegerSubTypeCode::Integer16),
+ NumericSigned32Bit = makeNumericTypeCode(NumericCategory::Signed,
+ IntegerSubTypeCode::Integer32),
+ NumericSigned64Bit = makeNumericTypeCode(NumericCategory::Signed,
+ IntegerSubTypeCode::Integer64),
+ NumericAPSigned = makeNumericTypeCode(NumericCategory::APSigned, 0),
+ WeakReal = makeNumericTypeCode(NumericCategory::WeakReal, 0),
+ NumericRealFP16 =
+ makeNumericTypeCode(NumericCategory::Real, RealSubTypeCode::FP16),
+ NumericRealBF16 =
+ makeNumericTypeCode(NumericCategory::Real, RealSubTypeCode::BF16),
+ NumericRealFP32 =
+ makeNumericTypeCode(NumericCategory::Real, RealSubTypeCode::FP32),
+ NumericRealFP64 =
+ makeNumericTypeCode(NumericCategory::Real, RealSubTypeCode::FP64),
+ WeakComplex = makeNumericTypeCode(NumericCategory::WeakComplex, 0),
+ NumericComplex64 = makeNumericTypeCode(NumericCategory::Complex,
+ ComplexSubTypeCode::COMPLEX64),
+ NumericComplex128 = makeNumericTypeCode(NumericCategory::Complex,
+ ComplexSubTypeCode::COMPLEX128),
+ NumericEnd = 0x7f,
+
+ // Objects start at 0x100, with 0x100 being the generic "object" type
+ // and then all following corresponding to user-defined types.
+ Object = 0x100,
+ FirstCustom = 0x101,
+};
+
+} // namespace iree_pydm
+} // namespace mlir
+
+#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_CONSTANTS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h
index 48346d4..47e2c2a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h
@@ -7,64 +7,15 @@
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
+#include "iree-dialects/Dialect/IREEPyDM/IR/Constants.h"
#include "iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
namespace mlir {
namespace iree_pydm {
-// Each built-in (to the compiler) type has a unique code, enumerated here.
-// Generally, the closed part of the type system will have type codes <
-// FirstCustom.
-// If editing, also update the constants in rtl/modules/constants.py.
-enum class BuiltinTypeCode : int {
- // Built-in types, ordered by rough "core-ness" so that lower numbers
- // are easier to spot for common cases.
- None = 0x1,
- Tuple = 0x2,
- List = 0x3,
- Str = 0x4,
- Bytes = 0x5,
- ExceptionResult = 0x6,
- Type = 0x7,
-
- // Weak-sized numeric types are of implementation defined size and are
- // always considered lower in the promotion order than a discrete
- // sized type of the same class.
- Bool = 0x8,
- Integer = 0x9,
- Real = 0xa,
- Complex = 0xb,
-
- // Discrete sized integer types.
- // TODO: Fiddle with all of these values so that promotion can be
- // done cleverly with bit twiddling of some kind.
- Integer1 = 0x10,
- Integer2 = 0x11,
- Integer4 = 0x12,
- Integer8 = 0x13,
- UInteger1 = 0x14,
- UInteger2 = 0x15,
- UInteger4 = 0x16,
- UInteger8 = 0x17,
-
- // Discrete sized FP types.
- Float2 = 0x18,
- Float4 = 0x19,
- Float8 = 0x1a,
- BFloat2 = 0x1b,
-
- // Complex.
- Complex4 = 0x1c,
- Complex8 = 0x1d,
-
- // Objects start at 0x100, with 0x100 being the generic "object" type
- // and then all following corresponding to user-defined types.
- Object = 0x100,
- FirstCustom = 0x101,
-};
-
/// Base class for all unboxed primitive types.
class PrimitiveType : public mlir::Type {
public:
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.td
index 7cab375..2af65b7 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.td
@@ -48,7 +48,8 @@
"unboxed primitive type",
"::mlir::iree_pydm::PrimitiveType">;
-def IREEPyDM_BoolType : IREEPyDM_PrimitiveTypeDef<"Bool", ["getNumericPromotionOrder"]> {
+def IREEPyDM_BoolType : IREEPyDM_PrimitiveTypeDef<"Bool", [
+ "getNumericPromotionOrder", "getNumericCategory", "getNumericSubTypeCode"]> {
let mnemonic = "bool";
let summary = "Type of bool values";
@@ -84,16 +85,116 @@
}];
}
-def IREEPyDM_IntegerType : IREEPyDM_PrimitiveTypeDef<"Integer", ["getNumericPromotionOrder"]> {
+def IREEPyDM_IntegerType : IREEPyDM_PrimitiveTypeDef<"Integer", [
+ "getNumericPromotionOrder", "getNumericCategory", "getNumericSubTypeCode"]> {
let mnemonic = "integer";
let summary = "Type of integer values";
let description = [{
- Represents the `numbers.Integral` type in the data model. At this abstract
- level, it should be considered conformant with the data model (i.e.
- unbounded). However, compiler flags will generally be used to interpret
- this type in a more bounded fashion (i32, i64, etc).
+ Represents the `numbers.Integral` type in the data model. Without further
+ qualification, the type is considered weak and open to further inference.
+ It can be further qualified to an explicit type:
+ Signed of a given bitwidth
+ Unsigned of a given bitwidth
+ Arbitrary precision
+
+ Prints as:
+ integer : Weak integer
+ integer<32> : Signed integer of specific bit width
+ integer<unsigned 32> : Unsigned integer of specific bit width
+ integer<*> : Arbitrary precision integer
+ }];
+
+ let parameters = (ins
+ // Encodes:
+ // None: Weak integer
+ // 0 : Arbitrary precision integer
+ // >0 : Signed
+ // <0 : Unsigned
+ "Optional<int>":$bitWidth
+ );
+
+ let skipDefaultBuilders = 1;
+ let genVerifyDecl = 1;
+
+ let builders = [
+ // Builds a weak integer.
+ TypeBuilder<(ins), [{
+ return Base::get($_ctxt, None);
+ }]>,
+ // Builds:
+ // (None): Arbitrary precision integer
+ // (32 [, true]): Signed integer of explicit size
+ // (32, false): Unsigned integer of explicit size
+ TypeBuilder<(ins CArg<"Optional<unsigned>">:$bitWidth,
+ CArg<"bool", "true">:$isSigned), [{
+ if (bitWidth) {
+ // Explicit size
+ int w = *bitWidth;
+ if (!isSigned) w = -w;
+ return Base::get($_ctxt, w);
+ } else {
+ // AP
+ return Base::get($_ctxt, 0);
+ }
+ }]>,
+ ];
+
+ let genAccessors = 0;
+
+ let extraClassDeclaration = [{
+ bool isWeak() const;
+ bool isExplicit() const { return !isWeak(); }
+ unsigned getBitWidth() const;
+ bool isSigned() const;
+ }];
+
+ let printer = [{
+ auto w = getImpl()->bitWidth;
+ $_printer << "integer";
+ if (w) {
+ $_printer << "<";
+ if (*w == 0) {
+ $_printer << "*";
+ } else if (*w > 0) {
+ $_printer << *w;
+ } else {
+ $_printer << "unsigned " << (-*w);
+ }
+ $_printer << ">";
+ }
+ }];
+
+ let parser = [{
+ auto emitError = [&]() -> InFlightDiagnostic{
+ return $_parser.emitError($_parser.getCurrentLocation());
+ };
+ // Weak
+ if (failed($_parser.parseOptionalLess()))
+ return get($_ctxt);
+ // AP
+ if (succeeded($_parser.parseOptionalStar())) {
+ if (failed($_parser.parseGreater()))
+ return Type();
+ return get($_ctxt, None);
+ }
+
+ // Explicit
+ bool isSigned;
+ if (succeeded($_parser.parseOptionalKeyword("unsigned"))) {
+ isSigned = false;
+ } else {
+ isSigned = true;
+ }
+
+ int width;
+ if (failed($_parser.parseInteger(width)))
+ return Type();
+ if (failed($_parser.parseGreater()))
+ return Type();
+ if (!isSigned) width = -width;
+ return getChecked(emitError, $_ctxt, width);
}];
}
@@ -117,16 +218,70 @@
}];
}
-def IREEPyDM_RealType : IREEPyDM_PrimitiveTypeDef<"Real", ["getNumericPromotionOrder"]> {
+def IREEPyDM_RealType : IREEPyDM_PrimitiveTypeDef<"Real",
+ ["getNumericPromotionOrder", "getNumericCategory", "getNumericSubTypeCode"]> {
let mnemonic = "real";
let summary = "Type of floating point values";
let description = [{
- Represents the `numbers.Real` type in the data model. At this abstract
- level, it should be considered conformant with the data model (i.e.
- double precision). However, compiler flags will generally be used to
- interpret this type in a more bounded fashion (f32).
+ Represents the `numbers.Real` type in the data model. Without qualification,
+ the type is considered "weak" and left open to further inference and/or
+ lowering defaults (which may differ from Python norms of treating this as
+ an f64 if so configured).
+
+ Prints as:
+ `real` : Weak real type
+ `real<f32>` : Explicit real type
+ }];
+
+ let parameters = (ins
+ // Encodes:
+ // nullptr: Weak real
+ // FloatType: Explicit real of given floating point type
+ "FloatType":$floatType
+ );
+
+ let skipDefaultBuilders = 1;
+ let genVerifyDecl = 1;
+
+ let builders = [
+ // Builds a weak RealType
+ TypeBuilder<(ins), [{
+ return Base::get($_ctxt, nullptr);
+ }]>,
+ // Builds an explicit RealType
+ TypeBuilder<(ins CArg<"FloatType">:$floatType), [{
+ return Base::get($_ctxt, floatType);
+ }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ bool isWeak() const;
+ bool isExplicit() const { return !isWeak(); }
+ }];
+
+ let printer = [{
+ auto ft = getImpl()->floatType;
+ $_printer << "real";
+ if (ft)
+ $_printer << "<" << ft << ">";
+ }];
+
+ let parser = [{
+ auto emitError = [&]() -> InFlightDiagnostic{
+ return $_parser.emitError($_parser.getCurrentLocation());
+ };
+ // Weak
+ if (failed($_parser.parseOptionalLess()))
+ return get($_ctxt);
+ // Explicit
+ FloatType subType;
+ if (failed($_parser.parseType(subType)))
+ return Type();
+ if (failed($_parser.parseGreater()))
+ return Type();
+ return getChecked(emitError, $_ctxt, subType);
}];
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h
index 0dd249b..0a81dc8 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h
@@ -7,6 +7,7 @@
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_H
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_H
+#include "iree-dialects/Dialect/IREEPyDM/IR/Constants.h"
#include "mlir/IR/Types.h"
namespace mlir {
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td
index 804ded8..db1d55f 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td
@@ -29,6 +29,22 @@
}], "llvm::StringRef", "getPythonTypeName", (ins)>,
InterfaceMethod<[{
+ For numeric types, returns the NumericCategory.
+ }], "llvm::Optional<NumericCategory>", "getNumericCategory", (ins),
+ /*methodBody=*/[{}], /*defaultImplementation=*/[{
+ return {};
+ }]>,
+
+ InterfaceMethod<[{
+ For numeric types, returns an appropriate subtype code, which is an
+ integer from 0-3 representing the specific type within the NumericCategory.
+ Weak types return None even though getNumericCategory() returns a value.
+ }], "llvm::Optional<int>", "getNumericSubTypeCode", (ins),
+ /*methodBody=*/[{}], /*defaultImplementation=*/[{
+ return {};
+ }]>,
+
+ InterfaceMethod<[{
For numeric types, returns the promotion order.
Types with a lower promotion order will be promoted to the higher order
for most binary functions.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td
index b4c7713..a9d3070 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td
@@ -477,7 +477,7 @@
let arguments = (ins IREEPyDM_AnyValueType:$value);
let results = (outs IREEPyDM_IntegerType);
let assemblyFormat = [{
- $value `:` type($value) attr-dict
+ $value `:` type($value) `->` type(results) attr-dict
}];
}
@@ -496,7 +496,7 @@
let arguments = (ins IREEPyDM_AnyValueType:$value);
let results = (outs IREEPyDM_IntegerType);
let assemblyFormat = [{
- $value `:` type($value) attr-dict
+ $value `:` type($value) `->` type(results) attr-dict
}];
}
@@ -650,6 +650,7 @@
let assemblyFormat = [{
$dunder_name `,` $left `,` $right `:` type(operands) attr-dict
}];
+ let hasCanonicalizer = 1;
}
#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_OPS_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
index a169051..cc9c619 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
@@ -34,9 +34,10 @@
};
std::unique_ptr<OperationPass<ModuleOp>> createConvertIREEPyDMToIREEPass();
-std::unique_ptr<OperationPass<ModuleOp>> createLowerIREEPyDMToRTLPass();
+std::unique_ptr<OperationPass<>> createFixateWeakNumericPass();
std::unique_ptr<OperationPass<ModuleOp>> createLinkIREEPyDMRTLPass(
Optional<SourceBundle> linkRtlSourceBundle = None);
+std::unique_ptr<OperationPass<ModuleOp>> createLowerIREEPyDMToRTLPass();
void buildLowerToIREEPassPipeline(OpPassManager& passManager,
const LowerToIREEOptions& options);
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td
index e98f4cc..8b76c38 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td
@@ -9,6 +9,17 @@
include "mlir/Pass/PassBase.td"
+def FixateWeakNumeric : Pass<"fixate-weak-numeric", ""> {
+ let summary = "Fixates weak numeric (integer/real/complex) to specific types";
+ let description = [{
+ After all type inference is complete, it is necessary to fixate weak numeric
+ types to a concrete type (either an explicit size or arbitrary-precision).
+ This can also be a useful thing to do early in development in order to
+ eliminate these weak types.
+ }];
+ let constructor = "mlir::iree_pydm::createFixateWeakNumericPass()";
+}
+
def LowerIREEPyDMToRTL : Pass<"lower-iree-pydm-to-rtl", "ModuleOp"> {
let summary = "Lowers PyDM ops with runtime implementations to calls";
let description = [{
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index d316fbe..f50d864 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -60,6 +60,14 @@
IREEPYDM_DEFINE_NULLARY_TYPE(Tuple)
IREEPYDM_DEFINE_NULLARY_TYPE(Type)
+// Non-nullary Type constructors from the above.
+MlirType mlirIREEPyDMIntegerTypeGetExplicit(MlirContext ctx, int bitWidth,
+ bool isSigned) {
+ return wrap(
+ mlir::iree_pydm::IntegerType::get(unwrap(ctx), bitWidth, isSigned));
+}
+
+// ObjectType.
bool mlirTypeIsAIREEPyDMObject(MlirType type) {
return unwrap(type).isa<mlir::iree_pydm::ObjectType>();
}
@@ -73,6 +81,7 @@
return wrap(mlir::iree_pydm::ObjectType::get(unwrap(ctx), cppType));
}
+// LowerToIREE Pass Pipeline.
void mlirIREEPyDMBuildLowerToIREEPassPipeline(MlirOpPassManager passManager,
IREEPyDMLoweringOptions options) {
auto *passManagerCpp = unwrap(passManager);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
index f52fdd5..e978ef0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
@@ -92,13 +92,20 @@
//------------------------------------------------------------------------------
BuiltinTypeCode iree_pydm::BoolType::getTypeCode() const {
- return BuiltinTypeCode::Bool;
+ return static_cast<BuiltinTypeCode>(
+ makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode()));
}
StringRef iree_pydm::BoolType::getPythonTypeName() const { return "bool"; }
+Optional<NumericCategory> iree_pydm::BoolType::getNumericCategory() const {
+ return NumericCategory::Bool;
+}
+
+Optional<int> iree_pydm::BoolType::getNumericSubTypeCode() const { return 0; }
+
Optional<int> iree_pydm::BoolType::getNumericPromotionOrder() const {
- return 1;
+ return static_cast<int>(getTypeCode());
}
BuiltinTypeCode iree_pydm::BytesType::getTypeCode() const {
@@ -115,14 +122,63 @@
return "Exception";
}
+LogicalResult iree_pydm::IntegerType::verify(
+ function_ref<InFlightDiagnostic()> emitError, Optional<int> bitWidth) {
+ if (!bitWidth) return success();
+ int w = abs(*bitWidth);
+ if (w == 0 || w == 8 || w == 16 || w == 32 || w == 64) return success();
+ return emitError() << "unsupported python integer bit width: " << w;
+}
+
BuiltinTypeCode iree_pydm::IntegerType::getTypeCode() const {
- return BuiltinTypeCode::Integer;
+ return static_cast<BuiltinTypeCode>(
+ makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode()));
}
StringRef iree_pydm::IntegerType::getPythonTypeName() const { return "int"; }
+Optional<NumericCategory> iree_pydm::IntegerType::getNumericCategory() const {
+  // Weak is tested first: the width/sign accessors dereference the bit width.
+  if (isWeak()) return NumericCategory::WeakInteger;
+  if (getBitWidth() == 0) return NumericCategory::APSigned;
+  return isSigned() ? NumericCategory::Signed : NumericCategory::Unsigned;
+}
+
+Optional<int> iree_pydm::IntegerType::getNumericSubTypeCode() const {
+  // Weak integers carry no bit width, and arbitrary precision integers are
+  // stored with a bit width of 0 (which verify() accepts). Neither has a sized
+  // sub type, so both report 0: the sub type that the WeakInteger and
+  // NumericAPSigned type codes are built with. Without the width check,
+  // getTypeCode() on an arbitrary precision integer hits the default case.
+  if (isWeak() || getBitWidth() == 0) return 0;
+  switch (getBitWidth()) {
+    case 8:
+      return static_cast<int>(IntegerSubTypeCode::Integer8);
+    case 16:
+      return static_cast<int>(IntegerSubTypeCode::Integer16);
+    case 32:
+      return static_cast<int>(IntegerSubTypeCode::Integer32);
+    case 64:
+      return static_cast<int>(IntegerSubTypeCode::Integer64);
+    default:
+      // verify() only admits widths 0, 8, 16, 32 and 64, so nothing else is
+      // expected here.
+      llvm_unreachable("unsupported numeric bitwidth");
+  }
+}
+
Optional<int> iree_pydm::IntegerType::getNumericPromotionOrder() const {
- return 2;
+ return static_cast<int>(getTypeCode());
+}
+
+bool iree_pydm::IntegerType::isWeak() const { return !getImpl()->bitWidth; }
+
+unsigned iree_pydm::IntegerType::getBitWidth() const {
+ return abs(*getImpl()->bitWidth);
+}
+
+bool iree_pydm::IntegerType::isSigned() const {
+ return *getImpl()->bitWidth >= 0;
}
BuiltinTypeCode iree_pydm::ListType::getTypeCode() const {
@@ -143,16 +199,49 @@
StringRef iree_pydm::ObjectType::getPythonTypeName() const { return "object"; }
+LogicalResult iree_pydm::RealType::verify(
+    function_ref<InFlightDiagnostic()> emitError, FloatType floatType) {
+  // A null float type denotes the weak real, which is always valid. Explicit
+  // reals are limited to the four float types that have a RealSubTypeCode.
+  if (!floatType ||
+      floatType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>())
+    return success();
+  return emitError() << "unsupported Python floating point type: " << floatType;
+}
+
BuiltinTypeCode iree_pydm::RealType::getTypeCode() const {
- return BuiltinTypeCode::Real;
+ return static_cast<BuiltinTypeCode>(
+ makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode()));
}
StringRef iree_pydm::RealType::getPythonTypeName() const { return "float"; }
-Optional<int> iree_pydm::RealType::getNumericPromotionOrder() const {
- return 3;
+Optional<NumericCategory> iree_pydm::RealType::getNumericCategory() const {
+ if (isWeak()) return NumericCategory::WeakReal;
+ return NumericCategory::Real;
}
+Optional<int> iree_pydm::RealType::getNumericSubTypeCode() const {
+  // Weak reals have no explicit float type; they report sub type 0.
+  if (isWeak()) return 0;
+  const Type floatType = getFloatType();
+  if (floatType.isa<BFloat16Type>())
+    return static_cast<int>(RealSubTypeCode::BF16);
+  if (floatType.isa<Float16Type>())
+    return static_cast<int>(RealSubTypeCode::FP16);
+  if (floatType.isa<Float32Type>())
+    return static_cast<int>(RealSubTypeCode::FP32);
+  if (floatType.isa<Float64Type>())
+    return static_cast<int>(RealSubTypeCode::FP64);
+  llvm_unreachable("unsupported float type");
+}
+
+Optional<int> iree_pydm::RealType::getNumericPromotionOrder() const {
+ return static_cast<int>(getTypeCode());
+}
+
+bool iree_pydm::RealType::isWeak() const { return !getImpl()->floatType; }
+
BuiltinTypeCode iree_pydm::StrType::getTypeCode() const {
return BuiltinTypeCode::Str;
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
index 245e6be..3f93482 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
@@ -77,6 +77,39 @@
}
//===----------------------------------------------------------------------===//
+// ApplyCompareOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Canonicalizes an `apply_compare` whose two operands both come from `box`
+/// ops over the same primitive type: the comparison is re-issued directly on
+/// the unboxed primitives.
+struct UnboxApplyCompareOperands : public OpRewritePattern<ApplyCompareOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ApplyCompareOp op,
+                                PatternRewriter &rewriter) const override {
+    BoxOp lhsBox = op.left().getDefiningOp<BoxOp>();
+    BoxOp rhsBox = op.right().getDefiningOp<BoxOp>();
+    if (!lhsBox || !rhsBox) return failure();
+    Value lhs = lhsBox.primitive();
+    Value rhs = rhsBox.primitive();
+    if (lhs.getType() != rhs.getType()) return failure();
+    rewriter.replaceOpWithNewOp<ApplyCompareOp>(
+        op, rewriter.getType<BoolType>(), op.dunder_nameAttr(), lhs, rhs);
+    return success();
+  }
+};
+
+} // namespace
+
+void ApplyCompareOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<UnboxApplyCompareOperands>(context);
+}
+
+//===----------------------------------------------------------------------===//
// AsBoolOp
//===----------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
index 44fe3b2..db6f84f 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Optimize)
add_subdirectory(RTL)
add_subdirectory(ToIREE)
@@ -8,8 +9,10 @@
MLIRIREEPyDMTransformsPassesIncGen
LINK_LIBS PUBLIC
+ IREEDialectsIREEPyDMOptimizePasses
IREEDialectsIREEPyDMRTLPasses
IREEDialectsIREEPyDMToIREEPasses
+ MLIRTransforms
)
iree_dialects_target_includes(IREEDialectsIREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
new file mode 100644
index 0000000..4bd33e1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_library(IREEDialectsIREEPyDMOptimizePasses
+ FixateWeakNumeric.cpp
+
+ DEPENDS
+ MLIRIREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEDialectsIREEPyDMDialect
+ MLIRIR
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEDialectsIREEPyDMOptimizePasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
new file mode 100644
index 0000000..3bc936f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
@@ -0,0 +1,113 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+namespace {
+
+/// Pass that pins "weak" numeric types to concrete fixed-width types.
+///
+/// Every operation reachable from the root (the root included) is visited and
+/// the types of its results and of the block arguments in its regions are
+/// rewritten in place using convertType(). iree_pydm.func ops additionally
+/// have their function type rewritten with the same rule.
+struct FixateWeakNumericPass
+    : public FixateWeakNumericBase<FixateWeakNumericPass> {
+  void runOnOperation() override {
+    // Types are mutated in place and no op is created or erased, so a plain
+    // walk is enough. Nothing is printed here: the pass runs inside the
+    // regular lowering pipeline and must not write IR dumps to stderr.
+    getOperation()->walk([&](Operation *op) { convertOperation(op); });
+  }
+
+  /// Rewrites the types of `op`'s region block arguments and results, and the
+  /// function type when `op` is an iree_pydm.func.
+  void convertOperation(Operation *op) {
+    // Process all regions/blocks to rewrite block arguments.
+    for (auto &region : op->getRegions()) {
+      for (auto &block : region) {
+        for (BlockArgument blockArg : block.getArguments()) {
+          convertValue(blockArg);
+        }
+      }
+    }
+
+    // And all results.
+    for (Value result : op->getResults()) {
+      convertValue(result);
+    }
+
+    // Special cases for operations.
+    if (auto funcOp = llvm::dyn_cast<iree_pydm::FuncOp>(op)) {
+      FunctionType existingFt = funcOp.getType();
+      FunctionType newFt = convertFunctionType(existingFt);
+      if (newFt != existingFt) {
+        funcOp.setType(newFt);
+      }
+    }
+  }
+
+  /// Replaces the type of `value` with its fixated equivalent. Values whose
+  /// type is unaffected by convertType() are left untouched.
+  void convertValue(Value value) {
+    Type oldType = value.getType();
+    Type newType = convertType(oldType);
+    if (newType != oldType) value.setType(newType);
+  }
+
+  /// Maps a type to its fixated form:
+  ///   * weak iree_pydm integer -> iree_pydm integer of width 32
+  ///   * weak iree_pydm real    -> iree_pydm real wrapping f32
+  ///   * iree_pydm object whose primitive type changes under this mapping
+  ///     -> object of the converted primitive type
+  /// Any other type is returned unchanged.
+  Type convertType(Type type) {
+    // TODO: The specific types we promote to need to be configured by the
+    // lowering options.
+    if (auto integerType = type.dyn_cast<iree_pydm::IntegerType>()) {
+      if (integerType.isWeak()) {
+        return iree_pydm::IntegerType::get(type.getContext(), 32);
+      }
+    } else if (auto realType = type.dyn_cast<iree_pydm::RealType>()) {
+      if (realType.isWeak()) {
+        return iree_pydm::RealType::get(
+            type.getContext(), mlir::Float32Type::get(type.getContext()));
+      }
+    } else if (auto objectType = type.dyn_cast<iree_pydm::ObjectType>()) {
+      // Objects may carry a primitive type; only rebuild the object type when
+      // that primitive actually changes.
+      Type primitiveType = objectType.getPrimitiveType();
+      if (primitiveType) {
+        Type newPrimitiveType = convertType(primitiveType);
+        if (newPrimitiveType != primitiveType) {
+          return iree_pydm::ObjectType::get(
+              type.getContext(),
+              newPrimitiveType.cast<iree_pydm::PrimitiveType>());
+        }
+      }
+    }
+
+    return type;
+  }
+
+  /// Applies convertType() to every input and result of `ft`. Returns `ft`
+  /// itself when nothing changed, otherwise a new FunctionType.
+  FunctionType convertFunctionType(FunctionType ft) {
+    SmallVector<Type> inputs(ft.getInputs().begin(), ft.getInputs().end());
+    SmallVector<Type> results(ft.getResults().begin(), ft.getResults().end());
+    bool modified = false;
+
+    // Converts each entry in place and records whether any entry changed.
+    auto convertInPlace = [&](SmallVector<Type> &types) {
+      for (Type &type : types) {
+        Type newType = convertType(type);
+        if (type != newType) {
+          type = newType;
+          modified = true;
+        }
+      }
+    };
+    convertInPlace(inputs);
+    convertInPlace(results);
+
+    if (!modified) return ft;
+
+    return FunctionType::get(ft.getContext(), inputs, results);
+  }
+};
+
+} // namespace
+
+/// Creates a new instance of FixateWeakNumericPass, returned as a generic
+/// (any-op) operation pass.
+std::unique_ptr<OperationPass<>>
+mlir::iree_pydm::createFixateWeakNumericPass() {
+ return std::make_unique<FixateWeakNumericPass>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
index caaaf14..20aeb14 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
@@ -7,6 +7,7 @@
#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::iree_pydm;
@@ -18,5 +19,14 @@
if (options.linkRtlSource) {
passManager.addPass(createLinkIREEPyDMRTLPass(options.linkRtlSource));
}
+ // TODO: Optimization passes need to be their own pipeline.
+ passManager.addPass(createFixateWeakNumericPass());
+ passManager.addPass(createCanonicalizerPass());
+
+ // Lowering passes.
passManager.addPass(createConvertIREEPyDMToIREEPass());
+
+ // Cleanup.
+ passManager.addPass(createCanonicalizerPass());
+ passManager.addPass(createCSEPass());
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
index 2d4d88d..4ac31e8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -43,12 +43,19 @@
// Some CFG ops can be present in the original pydm program. Need to
// verify legality based on types.
- target.addDynamicallyLegalOp<BranchOp>([&](BranchOp op) -> bool {
+ target.addDynamicallyLegalOp<BranchOp>([&](mlir::BranchOp op) -> bool {
return typeConverter.areTypesLegal(op.getOperandTypes());
});
- target.addDynamicallyLegalOp<CondBranchOp>([&](CondBranchOp op) -> bool {
- return typeConverter.areTypesLegal(op.getOperandTypes());
- });
+ target.addDynamicallyLegalOp<CondBranchOp>(
+ [&](mlir::CondBranchOp op) -> bool {
+ return typeConverter.areTypesLegal(op.getOperandTypes());
+ });
+
+ // Standard select can be emitted as part of CFG canonicalization.
+ target.addDynamicallyLegalOp<mlir::SelectOp>(
+ [&](mlir::SelectOp op) -> bool {
+ return typeConverter.areTypesLegal(op.getOperandTypes());
+ });
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
return signalPassFailure();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 9b3e1fd..e537125 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -175,6 +175,65 @@
}
};
+/// Lowers pydm apply_binary when both adapted operands and the converted
+/// result share one builtin numeric type. Only integer operands are lowered
+/// today; every other case reports a match failure.
+class ApplyBinaryNumericConversion
+    : public OpConversionPattern<pydm_d::ApplyBinaryOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::ApplyBinaryOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.left();
+    Value rhs = adaptor.right();
+    Type operandType = lhs.getType();
+    Type resultType = typeConverter->convertType(srcOp.result().getType());
+
+    // Operands and result must all agree on a single (convertible) type.
+    bool uniformTypes = resultType && operandType == rhs.getType() &&
+                        operandType == resultType;
+    if (!uniformTypes) {
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "not same type operands/results");
+    }
+
+    if (operandType.isa<builtin_d::FloatType>()) {
+      // TODO: Implement float binary
+      return rewriter.notifyMatchFailure(srcOp, "unsupported operation");
+    }
+    if (!operandType.isa<builtin_d::IntegerType>()) {
+      return rewriter.notifyMatchFailure(srcOp, "non numeric type");
+    }
+
+    // Integer path.
+    bool isSigned = true;  // TODO: Unsigned.
+    Value lowered = convertIntegerOp(
+        srcOp.getLoc(), adaptor.dunder_name().getValue(), lhs, rhs, isSigned,
+        rewriter);
+    if (!lowered) {
+      return rewriter.notifyMatchFailure(srcOp, "unsupported operation");
+    }
+    rewriter.replaceOp(srcOp, lowered);
+    return success();
+  }
+
+  /// Emits the integer op named by `dunderName` on (`left`, `right`).
+  /// Returns a null Value when the name is not handled.
+  Value convertIntegerOp(Location loc, StringRef dunderName, Value left,
+                         Value right, bool isSigned,
+                         ConversionPatternRewriter &rewriter) const {
+    // TODO: matmul, truediv, floordiv, mod, divmod, pow
+    if (dunderName == "add")
+      return rewriter.create<arith_d::AddIOp>(loc, left, right);
+    if (dunderName == "and")
+      return rewriter.create<arith_d::AndOp>(loc, left, right);
+    if (dunderName == "mul")
+      return rewriter.create<arith_d::MulIOp>(loc, left, right);
+    if (dunderName == "lshift")
+      return rewriter.create<arith_d::ShiftLeftOp>(loc, left, right);
+    if (dunderName == "or")
+      return rewriter.create<arith_d::OrOp>(loc, left, right);
+    if (dunderName == "rshift") {
+      // Right shift is the only op here whose lowering depends on signedness.
+      if (!isSigned)
+        return rewriter.create<arith_d::UnsignedShiftRightOp>(loc, left, right);
+      return rewriter.create<arith_d::SignedShiftRightOp>(loc, left, right);
+    }
+    if (dunderName == "sub")
+      return rewriter.create<arith_d::SubIOp>(loc, left, right);
+    if (dunderName == "xor")
+      return rewriter.create<arith_d::XOrOp>(loc, left, right);
+    return Value();
+  }
+};
+
class ApplyCompareNumericConversion
: public OpConversionPattern<pydm_d::ApplyCompareOp> {
using OpConversionPattern::OpConversionPattern;
@@ -242,17 +301,6 @@
}
};
-class BranchConversion : public OpConversionPattern<std_d::BranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- std_d::BranchOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<std_d::BranchOp>(srcOp, srcOp.dest(),
- adaptor.destOperands());
- return success();
- }
-};
-
class CallOpConversion : public OpConversionPattern<pydm_d::CallOp> {
using OpConversionPattern::OpConversionPattern;
@@ -272,19 +320,6 @@
}
};
-class CondBranchConversion : public OpConversionPattern<std_d::CondBranchOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- std_d::CondBranchOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<std_d::CondBranchOp>(
- srcOp, adaptor.condition(), srcOp.trueDest(),
- adaptor.trueDestOperands(), srcOp.falseDest(),
- adaptor.falseDestOperands());
- return success();
- }
-};
-
class ConstantOpConversion : public OpConversionPattern<pydm_d::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
@@ -331,6 +366,22 @@
}
};
+/// Lowers pydm failure to a constant i32 exception code.
+/// Temporary: exists so that some libraries can signal exceptions.
+class FailureOpConversion : public OpConversionPattern<pydm_d::FailureOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::FailureOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type codeType = rewriter.getI32Type();
+    // '-3' == RuntimeError
+    auto codeAttr = rewriter.getIntegerAttr(codeType, -3);
+    rewriter.replaceOpWithNewOp<std_d::ConstantOp>(srcOp, codeType, codeAttr);
+    return success();
+  }
+};
+
class FuncOpConversion : public OpConversionPattern<pydm_d::FuncOp> {
using OpConversionPattern::OpConversionPattern;
@@ -603,19 +654,66 @@
}
};
+//------------------------------------------------------------------------------
+// Outside pydm op conversions
+// These are largely identity conversions for CFG related standard ops, and
+// those that can be emitted as part of canonicalizations.
+//------------------------------------------------------------------------------
+
+/// Re-creates a std branch to the same destination block, using the
+/// destination operands supplied by the conversion adaptor.
+class BuiltinBranchConversion : public OpConversionPattern<std_d::BranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ std_d::BranchOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<std_d::BranchOp>(srcOp, srcOp.dest(),
+ adaptor.destOperands());
+ return success();
+ }
+};
+
+/// Re-creates a std cond_br with the same true/false destination blocks,
+/// using the condition and destination operands supplied by the conversion
+/// adaptor.
+class BuiltinCondBranchConversion
+ : public OpConversionPattern<std_d::CondBranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ std_d::CondBranchOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<std_d::CondBranchOp>(
+ srcOp, adaptor.condition(), srcOp.trueDest(),
+ adaptor.trueDestOperands(), srcOp.falseDest(),
+ adaptor.falseDestOperands());
+ return success();
+ }
+};
+
+/// Re-creates a std select from the condition, true value and false value
+/// supplied by the conversion adaptor.
+class BuiltinSelectConversion : public OpConversionPattern<std_d::SelectOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ std_d::SelectOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<std_d::SelectOp>(srcOp, adaptor.condition(),
+ adaptor.true_value(),
+ adaptor.false_value());
+ return success();
+ }
+};
+
} // namespace
void mlir::iree_pydm::populatePyDMToIREELoweringPatterns(
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- // Structural.
- patterns.insert<AllocFreeVarOpConversion, ApplyCompareNumericConversion,
- BoolToPredConversion, BoxOpConversion, BranchConversion,
- CallOpConversion, CondBranchConversion, ConstantOpConversion,
- FuncOpConversion, GetTypeCodeConversion, LoadVarOpConversion,
- RaiseOnFailureOpConversion, ReturnOpConversion,
- StoreVarOpConversion, UnboxOpConversion>(typeConverter,
- context);
+ // PyDM conversions.
+ patterns.insert<AllocFreeVarOpConversion, ApplyBinaryNumericConversion,
+ ApplyCompareNumericConversion, BoolToPredConversion,
+ BoxOpConversion, CallOpConversion, ConstantOpConversion,
+ FailureOpConversion, FuncOpConversion, GetTypeCodeConversion,
+ LoadVarOpConversion, RaiseOnFailureOpConversion,
+ ReturnOpConversion, StoreVarOpConversion, UnboxOpConversion>(
+ typeConverter, context);
+
+ // External CFG ops.
+ patterns.insert<BuiltinBranchConversion, BuiltinCondBranchConversion,
+ BuiltinSelectConversion>(typeConverter, context);
// Constants and constructors.
patterns.insert<NoneOpConversion>(typeConverter, context);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
index 0ee4fb7..eac25d4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -45,13 +45,19 @@
// Integer type hierarchy.
addConversion([&](pydm_d::IntegerType t) -> Optional<Type> {
Builder b(t.getContext());
- return getWeakIntegerType(b);
+ if (t.isWeak()) {
+ return getWeakIntegerType(b);
+ }
+ return b.getIntegerType(t.getBitWidth());
});
// Real type hierarchy.
addConversion([&](pydm_d::RealType t) -> Optional<Type> {
Builder b(t.getContext());
- return getWeakFloatType(b);
+ if (t.isWeak()) {
+ return getWeakFloatType(b);
+ }
+ return t.getFloatType();
});
// Variable references.
diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
index 067c565..058adf5 100644
--- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
+++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
@@ -159,7 +159,6 @@
DEFINE_IREEPYDM_NULLARY_TYPE(Bytes)
DEFINE_IREEPYDM_NULLARY_TYPE(ExceptionResult)
DEFINE_IREEPYDM_NULLARY_TYPE(FreeVarRef)
- DEFINE_IREEPYDM_NULLARY_TYPE(Integer)
DEFINE_IREEPYDM_NULLARY_TYPE(List)
DEFINE_IREEPYDM_NULLARY_TYPE(None)
DEFINE_IREEPYDM_NULLARY_TYPE(Real)
@@ -167,6 +166,25 @@
DEFINE_IREEPYDM_NULLARY_TYPE(Tuple)
DEFINE_IREEPYDM_NULLARY_TYPE(Type)
+ // IntegerType.
+ mlir_type_subclass(iree_pydm_m, "IntegerType", mlirTypeIsAIREEPyDMInteger,
+ typeClass)
+ .def_classmethod(
+ "get",
+ [](py::object cls, MlirContext context) {
+ return cls(mlirIREEPyDMIntegerTypeGet(context));
+ },
+ py::arg("cls"), py::arg("context") = py::none())
+ .def_classmethod(
+ "get_explicit",
+ [](py::object cls, int bitWidth, bool isSigned, MlirContext context) {
+ return cls(mlirIREEPyDMIntegerTypeGetExplicit(context, bitWidth,
+ isSigned));
+ },
+ py::arg("cls"), py::arg("bit_width"), py::arg("is_signed") = true,
+ py::arg("context") = py::none());
+
+ // ObjectType.
mlir_type_subclass(iree_pydm_m, "ObjectType", mlirTypeIsAIREEPyDMObject,
typeClass)
.def_classmethod(
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/booleans.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/booleans.py
index d1f5443..df7060f 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/booleans.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/booleans.py
@@ -14,13 +14,41 @@
@RTL_MODULE.export_pyfunc
def object_as_bool(v) -> bool:
- if is_type(v, TYPE_BOOL):
- return unbox_unchecked_bool(v)
- elif is_type(v, TYPE_NONE):
+ type_code = get_type_code(v)
+ if unbox_i32(type_code) == TYPE_NONE:
return False
- elif is_type(v, TYPE_INTEGER):
- return raw_compare_ne(unbox_unchecked_integer(v), 0)
- elif is_type(v, TYPE_REAL):
- return raw_compare_ne(unbox_unchecked_real(v), 0.0)
+ elif is_numeric_type_code(type_code):
+ numeric_cat = get_type_code_numeric_category(type_code)
+ if unbox_i32(numeric_cat) == TYPE_NUMERIC_CATEGORY_BOOL:
+ return unbox_bool(v)
+ numeric_subtype = get_type_code_numeric_subtype(type_code)
+
+ # Switch based on numeric category and subtype. Generally, we either
+ # promote to the 32bit variant of a category, or if v is the 64bit
+ # variant, then we compare directly with 64bit comparison functions.
+ if unbox_i32(numeric_cat) == TYPE_NUMERIC_CATEGORY_SIGNED:
+ if unbox_i32(numeric_subtype) == TYPE_NUMERIC_SUBTYPE_INTEGER32:
+ # 32bit comparison.
+ return cmpnz_i32(v)
+ if unbox_i32(numeric_subtype) == TYPE_NUMERIC_SUBTYPE_INTEGER64:
+ # Do 64 bit comparison.
+ return cmpnz_i64(v)
+ else:
+ # TODO: Support 8/16 bit promotion.
+ return raise_value_error(False)
+ elif unbox_i32(numeric_cat) == TYPE_NUMERIC_CATEGORY_REAL:
+ # NOTE(review): every real subtype below yields True regardless of the
+ # value, whereas the code this replaces compared the unboxed real
+ # against 0.0. Presumably a placeholder until real comparison macros
+ # exist; confirm.
+ if unbox_i32(numeric_subtype) == TYPE_NUMERIC_SUBTYPE_FP32:
+ return True
+ elif unbox_i32(numeric_subtype) == TYPE_NUMERIC_SUBTYPE_FP64:
+ return True
+ elif unbox_i32(numeric_subtype) == TYPE_NUMERIC_SUBTYPE_FP16:
+ return True
+ else:
+ # TODO: BF16?
+ return raise_value_error(False)
+ else:
+ # TODO: Unsigned, apsigned, weak.
+ return raise_value_error(False)
+
# TODO: List, Str, Bytes, Tuple, user objects, etc.
- return True
+ return raise_value_error(False)
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/constants.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/constants.py
index 49fdf9a..659f46f 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/constants.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/constants.py
@@ -14,8 +14,46 @@
TYPE_BYTES = 5
TYPE_EXCEPTION_RESULT = 6
TYPE_TYPE = 7
-TYPE_BOOL = 8
-TYPE_INTEGER = 9
-TYPE_REAL = 0xa
-TYPE_COMPLEX = 0xb
TYPE_OBJECT = 0x100
+
+# Numeric types exist in a range between 0x20 and 0x7f (inclusive).
+# NOTE(review): the shift test below (code >> 6 == 1) only holds for
+# 0x40..0x7f, so codes 0x20..0x3f would not be classified as numeric by it.
+# Confirm the intended start of the range.
+TYPE_NUMERIC_START = 0x20
+TYPE_NUMERIC_END = 0x7f
+
+# To test if numeric, shift right by this amount and compare to
+# TYPE_NUMERIC_SHIFTED_VALUE.
+TYPE_NUMERIC_SHIFT = 6
+TYPE_NUMERIC_SHIFTED_VALUE = 1
+
+# The lower 6 bits represent bit-packed numeric category and sub-type codes:
+# C C C C S S
+TYPE_NUMERIC_MASK = 0x3f
+
+# Mask of just the category bits.
+TYPE_NUMERIC_CATEGORY_MASK = 0x3c
+TYPE_NUMERIC_CATEGORY_SHIFT = 2
+TYPE_NUMERIC_CATEGORY_BOOL = 0x0
+TYPE_NUMERIC_CATEGORY_WEAK_INTEGER = 0x1
+TYPE_NUMERIC_CATEGORY_UNSIGNED = 0x2
+TYPE_NUMERIC_CATEGORY_SIGNED = 0x3
+TYPE_NUMERIC_CATEGORY_APSIGNED = 0x4
+TYPE_NUMERIC_CATEGORY_WEAK_REAL = 0x5
+TYPE_NUMERIC_CATEGORY_REAL = 0x6
+TYPE_NUMERIC_CATEGORY_WEAK_COMPLEX = 0x7
+TYPE_NUMERIC_CATEGORY_COMPLEX = 0x8
+
+# Mask of the sub-type bits.
+TYPE_NUMERIC_SUBTYPE_MASK = 0x3
+# Integer subtypes (applies to UNSIGNED and SIGNED categories).
+TYPE_NUMERIC_SUBTYPE_INTEGER8 = 0x0
+TYPE_NUMERIC_SUBTYPE_INTEGER16 = 0x1
+TYPE_NUMERIC_SUBTYPE_INTEGER32 = 0x2
+TYPE_NUMERIC_SUBTYPE_INTEGER64 = 0x3
+# Real subtypes.
+TYPE_NUMERIC_SUBTYPE_FP16 = 0x0
+TYPE_NUMERIC_SUBTYPE_BF16 = 0x1
+TYPE_NUMERIC_SUBTYPE_FP32 = 0x2
+TYPE_NUMERIC_SUBTYPE_FP64 = 0x3
+# Complex subtypes.
+TYPE_NUMERIC_SUBTYPE_COMPLEX64 = 0x2
+TYPE_NUMERIC_SUBTYPE_COMPLEX128 = 0x3
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/macros.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/macros.py
index 2c0b083..0410a72 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/macros.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/macros.py
@@ -5,6 +5,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Macros that can be used freely when building RTL modules."""
+from .constants import *
+
from ...importer import (
def_ir_macro_intrinsic,
ImportStage,
@@ -14,99 +16,132 @@
from ..... import ir
+def _constant_i32(value: int):
+ """Emits a constant i32 value."""
+ return d.ConstantOp(
+ d.IntegerType.get_explicit(32),
+ ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value)).result
+
+
+def _constant_i64(value: int):
+ """Emits a constant i64 value."""
+ return d.ConstantOp(
+ d.IntegerType.get_explicit(64),
+ ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)).result
+
+
+def _unbox_i32(stage: ImportStage, value: ir.Value) -> ir.Value:
+ """Returns `value` as an integer<32>.
+
+ Object-typed values are unboxed with an UnboxOp. Any other value must
+ already have type integer<32>, otherwise the import is aborted through
+ `stage.ic.abort`; it is then returned as-is.
+ """
+ i32_type = d.IntegerType.get_explicit(32)
+ if d.ObjectType.isinstance(value.type):
+ return d.UnboxOp(d.ExceptionResultType.get(), i32_type, value).primitive
+ else:
+ if value.type != i32_type:
+ stage.ic.abort(
+ f"Type error unbox a non object type -> integer<32>: {value.type}")
+ return value
+
+
+def _unbox_i64(stage: ImportStage, value: ir.Value) -> ir.Value:
+ """Returns `value` as an integer<64>.
+
+ Object-typed values are unboxed with an UnboxOp. Any other value must
+ already have type integer<64>, otherwise the import is aborted through
+ `stage.ic.abort`; it is then returned as-is.
+ """
+ i64_type = d.IntegerType.get_explicit(64)
+ if d.ObjectType.isinstance(value.type):
+ return d.UnboxOp(d.ExceptionResultType.get(), i64_type, value).primitive
+ else:
+ if value.type != i64_type:
+ stage.ic.abort(
+ f"Type error unbox a non object type -> integer<64>: {value.type}")
+ return value
+
+
+@def_ir_macro_intrinsic
+def unbox_i32(stage: ImportStage, value: ir.Value) -> ir.Value:
+ """Performs an unchecked unbox of an integer<32> value.
+
+ Typically this will be the result of a variable lookup. It is important that
+ the variable was stored with an integer<32> or else it is UB.
+
+ If the value is not an ObjectType, it is returned as-is, assuming that it
+ is already an integer<32>.
+
+ This shouldn't be needed in the fullness of time but gets around type
+ inference limitations in contexts where we don't want to (or can't) be
+ on the general path.
+ """
+ return _unbox_i32(stage, value)
+
+
@def_ir_macro_intrinsic
def get_type_code(stage: ImportStage, value: ir.Value) -> ir.Value:
- """Gets the TypeCode (see C++ BuiltinTypeCode) associated with a value."""
- return d.GetTypeCodeOp(d.IntegerType.get(), value).result
+ """Gets the TypeCode (see C++ BuiltinTypeCode) associated with a value.
-
-@def_ir_macro_intrinsic
-def is_type(stage: ImportStage, value: ir.Value,
- type_code: ir.Value) -> ir.Value:
- """Efficiently checks whether a value has a given type code."""
- ic = stage.ic
- if not d.IntegerType.isinstance(type_code.type):
- ic.abort(f"is_type() macro must be called with a constant type_code. "
- f"Got {type_code}")
- actual_type_code = get_type_code(stage, value)
- cmp_result = d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("eq"),
- type_code, actual_type_code).result
- return cmp_result
-
-
-@def_ir_macro_intrinsic
-def get_numeric_promotion_order(stage: ImportStage,
- value: ir.Value) -> ir.Value:
- """Gets the numeric promotion order.
-
- See get_numeric_promotion_order op.
+ This always returns an integer<32>, which is expected by macros which
+ operate on type codes.
"""
- return d.GetNumericPromotionOrderOp(d.IntegerType.get(), value).result
+ return d.GetTypeCodeOp(d.IntegerType.get_explicit(32), value).result
@def_ir_macro_intrinsic
-def promote_numeric_to_integer(stage: ImportStage, value: ir.Value) -> ir.Value:
- """Promotes the value to IntegerType."""
- return d.PromoteNumericOp(d.IntegerType.get(), value).result
+def is_numeric_type_code(stage: ImportStage, type_code: ir.Value):
+ """Determines whether the type code is part of the numeric hierarchy."""
+ type_code_i32 = _unbox_i32(stage, type_code)
+ t = type_code_i32.type
+ shifted = d.ApplyBinaryOp(t, ir.StringAttr.get("rshift"), type_code_i32,
+ _constant_i32(TYPE_NUMERIC_SHIFT)).result
+ return d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("eq"), shifted,
+ _constant_i32(TYPE_NUMERIC_SHIFTED_VALUE)).result
@def_ir_macro_intrinsic
-def promote_numeric_to_real(stage: ImportStage, value: ir.Value) -> ir.Value:
- """Promotes the value to RealType."""
- return d.PromoteNumericOp(d.RealType.get(), value).result
+def get_type_code_numeric_category(stage: ImportStage, type_code: ir.Value):
+ type_code_i32 = _unbox_i32(stage, type_code)
+ t = type_code_i32.type
+ masked = d.ApplyBinaryOp(t, ir.StringAttr.get("and"), type_code_i32,
+ _constant_i32(TYPE_NUMERIC_CATEGORY_MASK)).result
+ shifted = d.ApplyBinaryOp(t, ir.StringAttr.get("rshift"), masked,
+ _constant_i32(TYPE_NUMERIC_CATEGORY_SHIFT)).result
+ return shifted
@def_ir_macro_intrinsic
-def unbox_unchecked_bool(stage: ImportStage, value: ir.Value) -> ir.Value:
+def get_type_code_numeric_subtype(stage: ImportStage, type_code: ir.Value):
+ type_code_i32 = _unbox_i32(stage, type_code)
+ t = type_code_i32.type
+ return d.ApplyBinaryOp(t, ir.StringAttr.get("and"), type_code_i32,
+ _constant_i32(TYPE_NUMERIC_SUBTYPE_MASK)).result
+
+
+@def_ir_macro_intrinsic
+def unbox_bool(stage: ImportStage, value: ir.Value) -> ir.Value:
"""Unboxes an object value to a bool, not checking for success."""
return d.UnboxOp(d.ExceptionResultType.get(), d.BoolType.get(),
value).primitive
@def_ir_macro_intrinsic
-def unbox_unchecked_integer(stage: ImportStage, value: ir.Value) -> ir.Value:
- """Unboxes an object value to an integer, not checking for success."""
- return d.UnboxOp(d.ExceptionResultType.get(), d.IntegerType.get(),
- value).primitive
+def cmpnz_i32(stage: ImportStage, value: ir.Value) -> ir.Value:
+ """Promotes a numeric value to i32 and compares it to zero.
+
+ Returns True if not zero.
+ This should not be needed in the fullness of time but works around type
+ inference limitations in low level code.
+ """
+ value_i32 = _unbox_i32(stage, value)
+ zero = _constant_i32(0)
+ return d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("ne"), value_i32,
+ zero).result
@def_ir_macro_intrinsic
-def unbox_unchecked_real(stage: ImportStage, value: ir.Value) -> ir.Value:
- """Unboxes an object value to a real, not checking for success."""
- return d.UnboxOp(d.ExceptionResultType.get(), d.RealType.get(),
- value).primitive
+def cmpnz_i64(stage: ImportStage, value: ir.Value) -> ir.Value:
+ """Promotes a numeric value to i64 and compares it to zero.
-
-@def_ir_macro_intrinsic
-def raw_compare_eq(stage: ImportStage, left: ir.Value,
- right: ir.Value) -> ir.Value:
- """Emits an ApplyCompareOp for 'eq'."""
- return d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("eq"), left,
- right).result
-
-
-@def_ir_macro_intrinsic
-def raw_compare_gt(stage: ImportStage, left: ir.Value,
- right: ir.Value) -> ir.Value:
- """Emits an ApplyCompareOp for 'gt'."""
- return d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("gt"), left,
- right).result
-
-
-@def_ir_macro_intrinsic
-def raw_compare_ge(stage: ImportStage, left: ir.Value,
- right: ir.Value) -> ir.Value:
- """Emits an ApplyCompareOp for 'ge'."""
- return d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("ge"), left,
- right).result
-
-
-@def_ir_macro_intrinsic
-def raw_compare_ne(stage: ImportStage, left: ir.Value,
- right: ir.Value) -> ir.Value:
- """Emits an ApplyCompareOp for 'ne'."""
- return d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("ne"), left,
- right).result
+ Returns True if not zero.
+ This should not be needed in the fullness of time but works around type
+ inference limitations in low level code.
+ """
+ value_i64 = _unbox_i64(stage, value)
+ zero = _constant_i64(0)
+ return d.ApplyCompareOp(d.BoolType.get(), ir.StringAttr.get("ne"), value_i64,
+ zero).result
@def_ir_macro_intrinsic
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/numerics.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/numerics.py
index 1b29779..cb90760 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/numerics.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/modules/numerics.py
@@ -11,42 +11,40 @@
RTL_MODULE = RtlModule("numerics")
+# TODO: This has drifted with the full numeric hierarchy and needs to be
+# rewritten.
+# @RTL_MODULE.export_pyfunc
+# def dynamic_binary_promote(left, right) -> tuple:
+# left_order = get_numeric_promotion_order(left)
+# right_order = get_numeric_promotion_order(right)
+# # Note that since we are defining the numeric promotion rules, we have to
+# # use raw functions to compare (or else we would be using the thing we are
+# # defining).
+# if raw_compare_eq(left_order, right_order):
+# return left, right
+# elif raw_compare_gt(left_order, right_order):
+# return left, _promote_to(get_type_code(left), right)
+# else:
+# return _promote_to(get_type_code(right), left), right
-@RTL_MODULE.export_pyfunc
-def dynamic_binary_promote(left, right) -> tuple:
- left_order = get_numeric_promotion_order(left)
- right_order = get_numeric_promotion_order(right)
- # Note that since we are defining the numeric promotion rules, we have to
- # use raw functions to compare (or else we would be using the thing we are
- # defining).
- if raw_compare_eq(left_order, right_order):
- return left, right
- elif raw_compare_gt(left_order, right_order):
- return left, _promote_to(get_type_code(left), right)
- else:
- return _promote_to(get_type_code(right), left), right
+# @RTL_MODULE.internal_pyfunc
+# def _promote_to(type_code: int, value):
+# if raw_compare_eq(type_code, TYPE_INTEGER):
+# return _promote_to_integer(value)
+# elif raw_compare_eq(type_code, TYPE_REAL):
+# return _promote_to_real(value)
+# return raise_value_error(None)
+# @RTL_MODULE.internal_pyfunc
+# def _promote_to_integer(value) -> int:
+# if is_type(value, TYPE_BOOL):
+# return promote_numeric_to_integer(unbox_unchecked_bool(value))
+# return raise_value_error(0)
-@RTL_MODULE.internal_pyfunc
-def _promote_to(type_code: int, value):
- if raw_compare_eq(type_code, TYPE_INTEGER):
- return _promote_to_integer(value)
- elif raw_compare_eq(type_code, TYPE_REAL):
- return _promote_to_real(value)
- return raise_value_error(None)
-
-
-@RTL_MODULE.internal_pyfunc
-def _promote_to_integer(value) -> int:
- if is_type(value, TYPE_BOOL):
- return promote_numeric_to_integer(unbox_unchecked_bool(value))
- return raise_value_error(0)
-
-
-@RTL_MODULE.internal_pyfunc
-def _promote_to_real(value) -> float:
- if is_type(value, TYPE_BOOL):
- return promote_numeric_to_real(unbox_unchecked_bool(value))
- elif is_type(value, TYPE_INTEGER):
- return promote_numeric_to_real(unbox_unchecked_integer(value))
- return raise_value_error(0.0)
+# @RTL_MODULE.internal_pyfunc
+# def _promote_to_real(value) -> float:
+# if is_type(value, TYPE_BOOL):
+# return promote_numeric_to_real(unbox_unchecked_bool(value))
+# elif is_type(value, TYPE_INTEGER):
+# return promote_numeric_to_real(unbox_unchecked_integer(value))
+# return raise_value_error(0.0)
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/ops_types_parse.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/ops_types_parse.mlir
index 5b908ce..9668beb 100644
--- a/llvm-external-projects/iree-dialects/test/iree_pydm/ops_types_parse.mlir
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/ops_types_parse.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s | iree-dialects-opt
+// RUN: iree-dialects-opt %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
iree_pydm.func @free_var(%arg0 : !iree_pydm.bool) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
%var = alloc_free_var "foo" -> !iree_pydm.free_var_ref
@@ -13,3 +13,21 @@
%0 = load_var %var : !iree_pydm.free_var_ref -> !iree_pydm.bool
return %0 : !iree_pydm.bool
}
+
+// CHECK-LABEL: @integer_types
+// CHECK-SAME: !iree_pydm.integer
+// CHECK-SAME: !iree_pydm.integer<32>
+// CHECK-SAME: !iree_pydm.integer<unsigned 32>
+// CHECK-SAME: !iree_pydm.integer<*>
+iree_pydm.func private @integer_types(
+ !iree_pydm.integer,
+ !iree_pydm.integer<32>,
+ !iree_pydm.integer<unsigned 32>,
+ !iree_pydm.integer<*>)-> (!iree_pydm.exception_result, !iree_pydm.bool)
+
+// CHECK-LABEL: @real_types
+// CHECK-SAME: !iree_pydm.real
+// CHECK-SAME: !iree_pydm.real<f32>
+iree_pydm.func private @real_types(
+ !iree_pydm.real,
+ !iree_pydm.real<f32>)-> (!iree_pydm.exception_result, !iree_pydm.bool)
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
index 46cafc7..b54dae6 100644
--- a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
@@ -143,6 +143,6 @@
iree_pydm.func @get_type_code(%arg0 : !iree_pydm.object) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
// CHECK: %[[c0:.*]] = constant 0 : index
// CHECK: %[[R:.*]] = iree.list.get %arg0[%[[c0]]] : !iree.list<!iree.variant> -> i32
- %0 = get_type_code %arg0 : !iree_pydm.object
+ %0 = get_type_code %arg0 : !iree_pydm.object -> !iree_pydm.integer
return %0 : !iree_pydm.integer
}
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 53ab9d9..c721847 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -7,8 +7,7 @@
MLIRTransforms
IREEDialectsIREEDialect
IREEDialectsIREEPyDMDialect
- IREEDialectsIREEPyDMRTLPasses
- IREEDialectsIREEPyDMToIREEPasses
+ IREEDialectsIREEPyDMPasses
)
add_llvm_tool(iree-dialects-opt