Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 1 | // Copyright 2021 The IREE Authors |
| 2 | // |
| 3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | |
Stella Laurenzo | 02cfcd1 | 2021-11-14 13:20:53 -0800 | [diff] [blame] | 7 | #include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h" |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 8 | |
Stella Laurenzo | 02cfcd1 | 2021-11-14 13:20:53 -0800 | [diff] [blame] | 9 | #include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h" |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 10 | #include "llvm/ADT/SmallSet.h" |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 11 | #include "llvm/ADT/TypeSwitch.h" |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 12 | #include "llvm/Support/Debug.h" |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 13 | #include "mlir/IR/Builders.h" |
| 14 | #include "mlir/IR/BuiltinTypes.h" |
| 15 | #include "mlir/IR/FunctionImplementation.h" |
| 16 | #include "mlir/IR/OpImplementation.h" |
| 17 | #include "mlir/IR/TypeUtilities.h" |
| 18 | |
| 19 | using namespace mlir; |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 20 | namespace PYDM = mlir::iree_compiler::IREE::PYDM; |
| 21 | using namespace PYDM; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 22 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 23 | using llvm::dbgs; |
| 24 | |
| 25 | using PyBoolType = PYDM::BoolType; |
| 26 | using PyConstantOp = PYDM::ConstantOp; |
| 27 | using PyIntegerType = PYDM::IntegerType; |
| 28 | using PyRealType = PYDM::RealType; |
| 29 | using PyCallOp = PYDM::CallOp; |
| 30 | using PyFuncOp = PYDM::FuncOp; |
| 31 | |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 32 | //===----------------------------------------------------------------------===// |
| 33 | // Utilities |
| 34 | //===----------------------------------------------------------------------===// |
| 35 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 36 | namespace { |
| 37 | |
| 38 | /// Generic pattern to unbox any operands that are a specific object |
| 39 | /// type (i.e. object<integer>). |
| 40 | struct UnboxOperands : public RewritePattern { |
| 41 | UnboxOperands(StringRef rootName, MLIRContext *context, |
| 42 | Optional<llvm::SmallSet<int, 4>> operandIndices = None) |
| 43 | : RewritePattern(rootName, 1, context), operandIndices(operandIndices) {} |
| 44 | LogicalResult matchAndRewrite(Operation *op, |
| 45 | PatternRewriter &rewriter) const override { |
| 46 | Location loc = op->getLoc(); |
| 47 | bool changed = false; |
| 48 | SmallVector<Value> operands(op->getOperands()); |
| 49 | auto excResultType = rewriter.getType<ExceptionResultType>(); |
| 50 | for (int operandIndex = 0, e = operands.size(); operandIndex < e; |
| 51 | ++operandIndex) { |
| 52 | Value &operand = operands[operandIndex]; |
| 53 | if (operandIndices && !operandIndices->contains(operandIndex)) continue; |
| 54 | if (auto objectType = operand.getType().dyn_cast<ObjectType>()) { |
| 55 | Type primitiveType = objectType.getPrimitiveType(); |
| 56 | if (primitiveType) { |
| 57 | // Unbox. |
| 58 | auto unboxOp = rewriter.create<UnboxOp>( |
| 59 | loc, TypeRange{excResultType, primitiveType}, operand); |
| 60 | operand = unboxOp.primitive(); |
| 61 | changed = true; |
| 62 | } |
| 63 | } |
| 64 | } |
| 65 | |
| 66 | if (changed) { |
| 67 | rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); |
| 68 | return success(); |
| 69 | } |
| 70 | |
| 71 | return failure(); |
| 72 | } |
| 73 | Optional<llvm::SmallSet<int, 4>> operandIndices; |
| 74 | }; |
| 75 | |
| 76 | } // namespace |
| 77 | |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 78 | static Value getNumericZeroConstant(Location loc, Type numericType, |
| 79 | OpBuilder &builder) { |
| 80 | return TypeSwitch<Type, Value>(numericType) |
| 81 | .Case([&](PyBoolType t) -> Value { |
| 82 | return builder.create<PyConstantOp>(loc, t, builder.getBoolAttr(false)); |
| 83 | }) |
| 84 | .Case([&](PyIntegerType t) -> Value { |
| 85 | return builder.create<PyConstantOp>(loc, t, |
| 86 | builder.getI64IntegerAttr(0)); |
| 87 | }) |
| 88 | .Case([&](PyRealType t) -> Value { |
| 89 | return builder.create<PyConstantOp>(loc, t, |
| 90 | builder.getF64FloatAttr(0.0)); |
| 91 | }); |
| 92 | } |
| 93 | |
| 94 | static Value getBoolConstant(Location loc, bool pred, OpBuilder &builder) { |
| 95 | return builder.create<PyConstantOp>(loc, builder.getType<BoolType>(), |
| 96 | builder.getBoolAttr(pred)); |
| 97 | } |
| 98 | |
| 99 | //===----------------------------------------------------------------------===// |
| 100 | // Constants |
| 101 | //===----------------------------------------------------------------------===// |
| 102 | |
| 103 | OpFoldResult PyConstantOp::fold(ArrayRef<Attribute> operands) { |
| 104 | assert(operands.empty() && "constant has no operands"); |
| 105 | return getValue(); |
| 106 | } |
| 107 | |
| 108 | OpFoldResult NoneOp::fold(ArrayRef<Attribute> operands) { |
| 109 | assert(operands.empty() && "constant has no operands"); |
| 110 | return UnitAttr::get(getContext()); |
| 111 | } |
| 112 | |
| 113 | OpFoldResult SuccessOp::fold(ArrayRef<Attribute> operands) { |
| 114 | assert(operands.empty() && "constant has no operands"); |
| 115 | return UnitAttr::get(getContext()); |
| 116 | } |
| 117 | |
| 118 | //===----------------------------------------------------------------------===// |
| 119 | // Variables |
| 120 | //===----------------------------------------------------------------------===// |
| 121 | |
| 122 | void AllocFreeVarOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { |
| 123 | setNameFn(getResult(), name()); |
| 124 | } |
| 125 | |
| 126 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 127 | // ApplyBinaryOp |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 128 | //===----------------------------------------------------------------------===// |
| 129 | |
| 130 | namespace { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 131 | struct ApplyBinaryToSequenceClone : public OpRewritePattern<ApplyBinaryOp> { |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 132 | using OpRewritePattern::OpRewritePattern; |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 133 | LogicalResult matchAndRewrite(ApplyBinaryOp op, |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 134 | PatternRewriter &rewriter) const override { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 135 | if (op.dunder_name() != "mul") return failure(); |
| 136 | Value listOperand; |
| 137 | Value countOperand; |
| 138 | if (isBuiltinSequence(op.left()) && isInteger(op.right())) { |
| 139 | listOperand = op.left(); |
| 140 | countOperand = op.right(); |
| 141 | } else if (isInteger(op.left()) && isBuiltinSequence(op.right())) { |
| 142 | countOperand = op.left(); |
| 143 | listOperand = op.right(); |
| 144 | } else { |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 145 | return failure(); |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 146 | } |
| 147 | Type resultType = op.getResult().getType(); |
| 148 | rewriter.replaceOpWithNewOp<SequenceCloneOp>(op, resultType, listOperand, |
| 149 | countOperand); |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 150 | return success(); |
| 151 | } |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 152 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 153 | static bool isBuiltinSequence(Value operand) { |
| 154 | return operand.getType().isa<PYDM::ListType, PYDM::TupleType>(); |
| 155 | } |
| 156 | static bool isInteger(Value operand) { |
| 157 | return operand.getType().isa<PYDM::IntegerType>(); |
| 158 | } |
| 159 | }; |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 160 | } // namespace |
| 161 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 162 | void ApplyBinaryOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 163 | MLIRContext *context) { |
| 164 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 165 | patterns.add<ApplyBinaryToSequenceClone>(context); |
| 166 | } |
| 167 | |
| 168 | bool ApplyBinaryOp::refineResultTypes() { |
| 169 | auto leftType = left().getType(); |
| 170 | auto rightType = right().getType(); |
| 171 | auto applyUpdates = [&](Type newResultType) -> bool { |
| 172 | if (newResultType != getResult().getType()) { |
| 173 | getResult().setType(newResultType); |
| 174 | return true; |
| 175 | } |
| 176 | return false; |
| 177 | }; |
| 178 | |
| 179 | // Both numeric types. It is only dynamically legal for statically known |
| 180 | // numeric types to be the same, in which case the result type must be the |
| 181 | // same as well. |
| 182 | auto ptLeft = leftType.dyn_cast<PythonTypeInterface>(); |
| 183 | auto ptRight = rightType.dyn_cast<PythonTypeInterface>(); |
| 184 | if (ptLeft && ptRight && ptLeft.getNumericPromotionOrder() && |
| 185 | ptRight.getNumericPromotionOrder()) { |
| 186 | if (leftType == rightType) { |
| 187 | return applyUpdates(leftType); |
| 188 | } |
| 189 | } |
| 190 | |
| 191 | // (list, integer) or (integer, list) refine to the list type. |
| 192 | if (dunder_name() == "mul") { |
| 193 | auto leftList = leftType.dyn_cast<ListType>(); |
| 194 | auto rightList = rightType.dyn_cast<ListType>(); |
| 195 | auto leftInteger = leftType.dyn_cast<IntegerType>(); |
| 196 | auto rightInteger = rightType.dyn_cast<IntegerType>(); |
| 197 | if (leftList && rightInteger) { |
| 198 | return applyUpdates(leftList); |
| 199 | } else if (leftInteger && rightList) { |
| 200 | return applyUpdates(rightList); |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | return false; |
| 205 | } |
| 206 | |
| 207 | //===----------------------------------------------------------------------===// |
| 208 | // ApplyCompareOp |
| 209 | //===----------------------------------------------------------------------===// |
| 210 | |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 211 | void ApplyCompareOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 212 | MLIRContext *context) { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 213 | patterns.add<UnboxOperands>(getOperationName(), context); |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 214 | } |
| 215 | |
| 216 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 217 | // AsBoolOp |
| 218 | //===----------------------------------------------------------------------===// |
| 219 | |
| 220 | namespace { |
| 221 | struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> { |
| 222 | public: |
| 223 | using OpRewritePattern::OpRewritePattern; |
| 224 | LogicalResult matchAndRewrite(AsBoolOp op, |
| 225 | PatternRewriter &rewriter) const override { |
| 226 | if (op.value().getType().isa<BoolType>()) { |
| 227 | rewriter.replaceOp(op, op.value()); |
| 228 | return success(); |
| 229 | } |
| 230 | return failure(); |
| 231 | } |
| 232 | }; |
| 233 | |
| 234 | struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> { |
| 235 | public: |
| 236 | using OpRewritePattern::OpRewritePattern; |
| 237 | LogicalResult matchAndRewrite(AsBoolOp op, |
| 238 | PatternRewriter &rewriter) const override { |
| 239 | auto loc = op.getLoc(); |
| 240 | auto ptType = op.value().getType().dyn_cast<PythonTypeInterface>(); |
| 241 | if (!ptType) return failure(); |
| 242 | if (!ptType.getNumericPromotionOrder()) return failure(); |
| 243 | |
| 244 | auto boolType = rewriter.getType<BoolType>(); |
| 245 | Value zeroValue = |
| 246 | getNumericZeroConstant(loc, op.value().getType(), rewriter); |
| 247 | Value trueValue = getBoolConstant(loc, true, rewriter); |
| 248 | Value falseValue = getBoolConstant(loc, false, rewriter); |
| 249 | Value cmpResult = rewriter.create<ApplyCompareOp>( |
| 250 | loc, boolType, rewriter.getStringAttr("eq"), op.value(), zeroValue); |
| 251 | rewriter.replaceOpWithNewOp<SelectOp>(op, boolType, cmpResult, falseValue, |
| 252 | trueValue); |
| 253 | return success(); |
| 254 | } |
| 255 | }; |
| 256 | |
| 257 | } // namespace |
| 258 | |
| 259 | void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 260 | MLIRContext *context) { |
| 261 | patterns.add<FoldAsBoolFromBool, FoldAsBoolFromNumeric>(context); |
| 262 | } |
| 263 | |
| 264 | OpFoldResult AsBoolOp::fold(ArrayRef<Attribute> operands) { |
| 265 | Builder b(getContext()); |
| 266 | // Fold NoneType to False. |
| 267 | if (value().getType().isa<NoneType>()) { |
| 268 | return b.getBoolAttr(false); |
| 269 | } |
| 270 | |
| 271 | return {}; |
| 272 | } |
| 273 | |
| 274 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 275 | // AssignSubscriptOp |
| 276 | //===----------------------------------------------------------------------===// |
| 277 | |
| 278 | void AssignSubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 279 | MLIRContext *context) { |
| 280 | llvm::SmallSet<int, 4> unboxIndices; |
| 281 | unboxIndices.insert(0); |
| 282 | unboxIndices.insert(1); |
| 283 | patterns.add<UnboxOperands>(getOperationName(), context, unboxIndices); |
| 284 | } |
| 285 | |
| 286 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 287 | // BoolToPredOp |
| 288 | //===----------------------------------------------------------------------===// |
| 289 | |
| 290 | OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) { |
| 291 | if (!operands[0]) return {}; |
| 292 | // Since both BoolType and I1 share the attribute form (an IntegerAttr of I1), |
| 293 | // we can just return it. |
| 294 | return operands[0]; |
| 295 | } |
| 296 | |
| 297 | //===----------------------------------------------------------------------===// |
| 298 | // BoxOp and UnboxOp |
| 299 | //===----------------------------------------------------------------------===// |
| 300 | |
| 301 | LogicalResult BoxOp::canonicalize(BoxOp op, PatternRewriter &rewriter) { |
| 302 | // Sometimes boxes are emitted when the input is an object. Just remove. |
| 303 | if (op.primitive().getType().isa<ObjectType>()) { |
| 304 | rewriter.replaceOp(op, op.primitive()); |
| 305 | return success(); |
| 306 | } |
| 307 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 308 | // Box to an appropriate type and static info cast. |
| 309 | ObjectType objectType = rewriter.getType<ObjectType>(nullptr); |
| 310 | if (op.object().getType() == objectType && |
| 311 | !op.primitive().getType().isa<ObjectType>()) { |
| 312 | auto refinedBox = rewriter.create<BoxOp>( |
| 313 | op.getLoc(), |
| 314 | rewriter.getType<ObjectType>( |
| 315 | op.primitive().getType().cast<PrimitiveType>()), |
| 316 | op.primitive()); |
| 317 | rewriter.replaceOpWithNewOp<StaticInfoCastOp>(op, op.object().getType(), |
| 318 | refinedBox); |
| 319 | return success(); |
| 320 | } |
| 321 | |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 322 | return failure(); |
| 323 | } |
| 324 | |
| 325 | LogicalResult UnboxOp::canonicalize(UnboxOp unboxOp, |
| 326 | PatternRewriter &rewriter) { |
| 327 | auto loc = unboxOp.getLoc(); |
| 328 | |
| 329 | // Handle the case of an immediate BoxOp producer. |
| 330 | if (auto boxProducer = |
| 331 | dyn_cast_or_null<BoxOp>(unboxOp.object().getDefiningOp())) { |
| 332 | // If the producer is boxing to the same type we are unboxing, then |
| 333 | // just elide everything. |
| 334 | if (boxProducer.primitive().getType() == unboxOp.primitive().getType()) { |
| 335 | auto successValue = rewriter.create<SuccessOp>( |
| 336 | loc, rewriter.getType<ExceptionResultType>()); |
| 337 | rewriter.replaceOp(unboxOp, {successValue, boxProducer.primitive()}); |
| 338 | return success(); |
| 339 | } |
| 340 | } |
| 341 | return failure(); |
| 342 | } |
| 343 | |
| 344 | //===----------------------------------------------------------------------===// |
| 345 | // DynamicBinaryPromoteOp |
| 346 | //===----------------------------------------------------------------------===// |
| 347 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 348 | namespace { |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 349 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 350 | /// Resolves a DynamicBinaryPromote over numeric operands to either elide |
| 351 | /// or insert specific PromoteNumeric ops. |
| 352 | struct ResolveNumericDynamicBinaryPromote |
| 353 | : public OpRewritePattern<DynamicBinaryPromoteOp> { |
| 354 | public: |
| 355 | using OpRewritePattern::OpRewritePattern; |
| 356 | LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op, |
| 357 | PatternRewriter &rewriter) const override { |
| 358 | auto loc = op.getLoc(); |
| 359 | auto leftType = op.left().getType(); |
| 360 | auto rightType = op.right().getType(); |
| 361 | auto leftResultType = op.getResultTypes()[0]; |
| 362 | auto rightResultType = op.getResultTypes()[1]; |
| 363 | auto leftPt = leftType.dyn_cast<PythonTypeInterface>(); |
| 364 | auto rightPt = rightType.dyn_cast<PythonTypeInterface>(); |
| 365 | if (!leftPt || !rightPt) return failure(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 366 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 367 | Optional<int> leftOrder = leftPt.getNumericPromotionOrder(); |
| 368 | Optional<int> rightOrder = rightPt.getNumericPromotionOrder(); |
| 369 | Value newLeft = op.left(); |
| 370 | Value newRight = op.right(); |
| 371 | |
| 372 | if (leftOrder && rightOrder) { |
| 373 | // Both numeric. |
| 374 | if (*leftOrder > *rightOrder) { |
| 375 | newRight = rewriter.create<PromoteNumericOp>(loc, leftType, newRight); |
| 376 | } |
| 377 | if (*rightOrder > *leftOrder) { |
| 378 | newLeft = rewriter.create<PromoteNumericOp>(loc, rightType, newLeft); |
| 379 | } |
| 380 | } else { |
| 381 | return failure(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 382 | } |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 383 | |
| 384 | // Need to box back to the original type (which will always be a generic |
| 385 | // object). |
| 386 | newLeft = rewriter.create<BoxOp>(loc, leftResultType, newLeft); |
| 387 | newRight = rewriter.create<BoxOp>(loc, rightResultType, newRight); |
| 388 | |
| 389 | rewriter.replaceOp(op, {newLeft, newRight}); |
| 390 | return success(); |
| 391 | } |
| 392 | }; |
| 393 | |
| 394 | /// If we statically determine one of the arguments to be a concrete, non |
| 395 | /// numeric type, then the op has no meaning and is elided. |
| 396 | struct ElideNonNumericDynamicBinaryPromote |
| 397 | : public OpRewritePattern<DynamicBinaryPromoteOp> { |
| 398 | public: |
| 399 | using OpRewritePattern::OpRewritePattern; |
| 400 | LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op, |
| 401 | PatternRewriter &rewriter) const override { |
| 402 | if ((!isConcreteNonNumericType(op.left().getType()) && |
| 403 | !isConcreteNonNumericType(op.right().getType()))) |
| 404 | return failure(); |
| 405 | |
| 406 | // Since DynamicBinaryPromote already returns object, and we only match |
| 407 | // non-object operands, box them back. |
| 408 | auto loc = op.getLoc(); |
| 409 | auto leftResultType = op.getResultTypes()[0]; |
| 410 | auto rightResultType = op.getResultTypes()[1]; |
| 411 | Value newLeft = rewriter.create<BoxOp>(loc, leftResultType, op.left()); |
| 412 | Value newRight = rewriter.create<BoxOp>(loc, rightResultType, op.right()); |
| 413 | rewriter.replaceOp(op, {newLeft, newRight}); |
| 414 | return success(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 415 | } |
| 416 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 417 | static bool isConcreteNonNumericType(Type t) { |
| 418 | if (t.isa<ObjectType>()) return false; |
| 419 | auto pt = t.dyn_cast<PythonTypeInterface>(); |
| 420 | if (!pt || pt.getNumericPromotionOrder()) return false; |
| 421 | return true; |
| 422 | } |
| 423 | }; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 424 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 425 | } // namespace |
| 426 | |
| 427 | void DynamicBinaryPromoteOp::getCanonicalizationPatterns( |
| 428 | RewritePatternSet &patterns, MLIRContext *context) { |
| 429 | patterns.add<ResolveNumericDynamicBinaryPromote>(context); |
| 430 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 431 | patterns.add<ElideNonNumericDynamicBinaryPromote>(context); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 432 | } |
| 433 | |
| 434 | //===----------------------------------------------------------------------===// |
| 435 | // FunctionalIfOp |
| 436 | //===----------------------------------------------------------------------===// |
| 437 | |
| 438 | ::llvm::StringRef FunctionalIfOp::getDefaultDialect() { return "iree_pydm"; } |
| 439 | |
MaheshRavishankar | f488f17 | 2022-03-18 18:19:02 -0700 | [diff] [blame^] | 440 | LogicalResult FunctionalIfOp::verify() { |
| 441 | if (getNumResults() != 0 && elseRegion().empty()) |
| 442 | return emitOpError("must have an else block if defining values"); |
Han-Chung Wang | 94dcb46 | 2022-03-09 10:51:31 -0800 | [diff] [blame] | 443 | return success(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 444 | } |
| 445 | |
Jacques Pienaar | ff38cb4 | 2022-02-12 08:42:27 -0800 | [diff] [blame] | 446 | ParseResult FunctionalIfOp::parse(OpAsmParser &parser, OperationState &result) { |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 447 | // Create the regions for 'then'. |
| 448 | result.regions.reserve(2); |
| 449 | Region *thenRegion = result.addRegion(); |
| 450 | Region *elseRegion = result.addRegion(); |
| 451 | |
| 452 | auto &builder = parser.getBuilder(); |
| 453 | OpAsmParser::OperandType cond; |
| 454 | Type conditionType = builder.getType<PyBoolType>(); |
| 455 | if (parser.parseOperand(cond) || |
| 456 | parser.resolveOperand(cond, conditionType, result.operands)) |
| 457 | return failure(); |
| 458 | // Parse optional results type list. |
| 459 | if (parser.parseOptionalArrowTypeList(result.types)) return failure(); |
| 460 | // Parse the 'then' region. |
| 461 | if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) |
| 462 | return failure(); |
| 463 | // IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); |
| 464 | |
| 465 | // If we find an 'else' keyword then parse the 'else' region. |
| 466 | if (!parser.parseOptionalKeyword("else")) { |
| 467 | if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) |
| 468 | return failure(); |
| 469 | // IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), |
| 470 | // result.location); |
| 471 | } |
| 472 | |
| 473 | // Parse the optional attribute list. |
| 474 | if (parser.parseOptionalAttrDict(result.attributes)) return failure(); |
| 475 | return success(); |
| 476 | } |
| 477 | |
Jacques Pienaar | ff38cb4 | 2022-02-12 08:42:27 -0800 | [diff] [blame] | 478 | void FunctionalIfOp::print(OpAsmPrinter &p) { |
| 479 | FunctionalIfOp op = *this; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 480 | bool printBlockTerminators = false; |
| 481 | |
| 482 | p << " " << op.condition(); |
| 483 | if (!op.results().empty()) { |
| 484 | p << " -> (" << op.getResultTypes() << ")"; |
| 485 | // Print yield explicitly if the op defines values. |
| 486 | printBlockTerminators = true; |
| 487 | } |
Stella Laurenzo | 44c4187 | 2022-01-19 19:39:08 -0800 | [diff] [blame] | 488 | p << " "; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 489 | p.printRegion(op.thenRegion(), |
| 490 | /*printEntryBlockArgs=*/false, |
| 491 | /*printBlockTerminators=*/printBlockTerminators); |
| 492 | |
| 493 | // Print the 'else' regions if it exists and has a block. |
| 494 | auto &elseRegion = op.elseRegion(); |
| 495 | if (!elseRegion.empty()) { |
Stella Laurenzo | 44c4187 | 2022-01-19 19:39:08 -0800 | [diff] [blame] | 496 | p << " else "; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 497 | p.printRegion(elseRegion, |
| 498 | /*printEntryBlockArgs=*/false, |
| 499 | /*printBlockTerminators=*/printBlockTerminators); |
| 500 | } |
| 501 | |
| 502 | p.printOptionalAttrDict(op->getAttrs()); |
| 503 | } |
| 504 | |
| 505 | /// Given the region at `index`, or the parent operation if `index` is None, |
| 506 | /// return the successor regions. These are the regions that may be selected |
| 507 | /// during the flow of control. `operands` is a set of optional attributes that |
| 508 | /// correspond to a constant value for each operand, or null if that operand is |
| 509 | /// not a constant. |
| 510 | void FunctionalIfOp::getSuccessorRegions( |
| 511 | Optional<unsigned> index, ArrayRef<Attribute> operands, |
| 512 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 513 | // The `then` and the `else` region branch back to the parent operation. |
| 514 | if (index.hasValue()) { |
| 515 | regions.push_back(RegionSuccessor(getResults())); |
| 516 | return; |
| 517 | } |
| 518 | |
| 519 | // Don't consider the else region if it is empty. |
| 520 | Region *elseRegion = &this->elseRegion(); |
| 521 | if (elseRegion->empty()) elseRegion = nullptr; |
| 522 | |
| 523 | // Otherwise, the successor is dependent on the condition. |
| 524 | if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) { |
| 525 | bool condition = condAttr.getValue(); |
| 526 | // Add the successor regions using the condition. |
| 527 | regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion)); |
| 528 | } else { |
| 529 | // If the condition isn't constant, both regions may be executed. |
| 530 | regions.push_back(RegionSuccessor(&thenRegion())); |
| 531 | // If the else region does not exist, it is not a viable successor. |
| 532 | if (elseRegion) regions.push_back(RegionSuccessor(elseRegion)); |
| 533 | } |
| 534 | } |
| 535 | |
| 536 | //===----------------------------------------------------------------------===// |
| 537 | // FuncOp |
| 538 | //===----------------------------------------------------------------------===// |
| 539 | |
| 540 | ::llvm::StringRef PyFuncOp::getDefaultDialect() { return "iree_pydm"; } |
| 541 | |
| 542 | LogicalResult PyFuncOp::verifyType() { |
| 543 | // TODO: Enforce arg/result invariants. |
| 544 | return success(); |
| 545 | } |
| 546 | |
Jacques Pienaar | ff38cb4 | 2022-02-12 08:42:27 -0800 | [diff] [blame] | 547 | ParseResult PyFuncOp::parse(OpAsmParser &parser, OperationState &result) { |
Stella Laurenzo | 68e7810 | 2022-01-25 22:24:56 -0800 | [diff] [blame] | 548 | auto buildFuncType = |
| 549 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
| 550 | function_interface_impl::VariadicFlag, |
| 551 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 552 | |
Stella Laurenzo | 68e7810 | 2022-01-25 22:24:56 -0800 | [diff] [blame] | 553 | return function_interface_impl::parseFunctionOp( |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 554 | parser, result, /*allowVariadic=*/false, buildFuncType); |
| 555 | } |
| 556 | |
Jacques Pienaar | ff38cb4 | 2022-02-12 08:42:27 -0800 | [diff] [blame] | 557 | void PyFuncOp::print(OpAsmPrinter &p) { |
| 558 | FunctionType fnType = getType(); |
Stella Laurenzo | 68e7810 | 2022-01-25 22:24:56 -0800 | [diff] [blame] | 559 | function_interface_impl::printFunctionOp( |
Jacques Pienaar | ff38cb4 | 2022-02-12 08:42:27 -0800 | [diff] [blame] | 560 | p, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 561 | } |
| 562 | |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 563 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 564 | // MakeListOp |
| 565 | //===----------------------------------------------------------------------===// |
| 566 | |
MaheshRavishankar | f488f17 | 2022-03-18 18:19:02 -0700 | [diff] [blame^] | 567 | LogicalResult MakeListOp::verify() { |
| 568 | auto listType = list().getType().cast<ListType>(); |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 569 | switch (listType.getStorageClass()) { |
| 570 | case CollectionStorageClass::Boxed: |
MaheshRavishankar | f488f17 | 2022-03-18 18:19:02 -0700 | [diff] [blame^] | 571 | for (auto element : elements()) { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 572 | if (!element.getType().isa<ObjectType>()) { |
MaheshRavishankar | f488f17 | 2022-03-18 18:19:02 -0700 | [diff] [blame^] | 573 | return emitOpError() << "making a list with boxed storage class " |
| 574 | "must have object elements. Got: " |
| 575 | << element.getType(); |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 576 | } |
| 577 | } |
| 578 | break; |
| 579 | case CollectionStorageClass::Unboxed: |
MaheshRavishankar | f488f17 | 2022-03-18 18:19:02 -0700 | [diff] [blame^] | 580 | for (auto element : elements()) { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 581 | if (element.getType().isa<ObjectType>()) { |
MaheshRavishankar | f488f17 | 2022-03-18 18:19:02 -0700 | [diff] [blame^] | 582 | return emitOpError() << "making a list with unboxed storage class " |
| 583 | "must not have object elements. Got: " |
| 584 | << element.getType(); |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 585 | } |
| 586 | } |
| 587 | break; |
| 588 | case CollectionStorageClass::Empty: |
MaheshRavishankar | f488f17 | 2022-03-18 18:19:02 -0700 | [diff] [blame^] | 589 | if (!elements().empty()) { |
| 590 | return emitOpError() |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 591 | << "making a list with empty storage class must have zero " |
| 592 | "elements"; |
| 593 | } |
| 594 | break; |
| 595 | } |
| 596 | return success(); |
| 597 | } |
| 598 | |
| 599 | //===----------------------------------------------------------------------===// |
| 600 | // NegOp |
| 601 | //===----------------------------------------------------------------------===// |
| 602 | |
| 603 | void NegOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 604 | MLIRContext *context) { |
| 605 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 606 | } |
| 607 | |
| 608 | bool NegOp::refineResultTypes() { |
| 609 | if (value().getType() != getResult().getType()) { |
| 610 | getResult().setType(value().getType()); |
| 611 | return true; |
| 612 | } |
| 613 | return false; |
| 614 | } |
| 615 | |
| 616 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 617 | // PatternMatchCallOp |
| 618 | //===----------------------------------------------------------------------===// |
| 619 | |
| 620 | LogicalResult PatternMatchCallOp::verifySymbolUses( |
| 621 | SymbolTableCollection &symbolTable) { |
| 622 | auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult { |
| 623 | for (auto symbolAttr : symbols) { |
| 624 | auto symbol = symbolAttr.cast<FlatSymbolRefAttr>(); |
| 625 | PyFuncOp fn = |
| 626 | symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol); |
| 627 | if (!fn) |
| 628 | return emitOpError() << "'" << symbol.getValue() |
| 629 | << "' does not reference a valid function"; |
| 630 | } |
| 631 | return success(); |
| 632 | }; |
| 633 | auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match"); |
| 634 | if (!genericsAttr) |
| 635 | return emitOpError( |
| 636 | "requires a 'generic_match' array of symbol reference attributes"); |
| 637 | if (failed(verifySymbols(genericsAttr))) return failure(); |
| 638 | |
| 639 | auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match"); |
| 640 | if (!specificsAttr) |
| 641 | return emitOpError( |
| 642 | "requires a 'specific_match' array of symbol reference attributes"); |
| 643 | if (failed(verifySymbols(specificsAttr))) return failure(); |
| 644 | |
| 645 | return success(); |
| 646 | } |
| 647 | |
| 648 | //===----------------------------------------------------------------------===// |
| 649 | // PromoteNumericOp |
| 650 | //===----------------------------------------------------------------------===// |
| 651 | |
| 652 | OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) { |
| 653 | if (!operands[0]) return {}; |
| 654 | |
| 655 | Builder b(getContext()); |
| 656 | Attribute fromAttr = operands[0]; |
| 657 | return TypeSwitch<Type, OpFoldResult>(getResult().getType()) |
| 658 | .Case([&](PyIntegerType toType) -> OpFoldResult { |
| 659 | return TypeSwitch<Attribute, OpFoldResult>(fromAttr) |
| 660 | .Case([&](BoolAttr fromBool) -> OpFoldResult { |
| 661 | return b.getI64IntegerAttr(fromBool.getValue() ? 1 : 0); |
| 662 | }) |
| 663 | .Default([](Attribute) -> OpFoldResult { return {}; }); |
| 664 | }) |
| 665 | .Case([&](PyRealType toType) -> OpFoldResult { |
| 666 | return TypeSwitch<Attribute, OpFoldResult>(fromAttr) |
| 667 | .Case([&](BoolAttr fromBool) -> OpFoldResult { |
| 668 | return b.getF64FloatAttr(fromBool.getValue() ? 1.0 : 0.0); |
| 669 | }) |
| 670 | .Case([&](IntegerAttr fromInteger) -> OpFoldResult { |
| 671 | APInt value = fromInteger.getValue(); |
| 672 | return b.getF64FloatAttr(value.getSExtValue()); |
| 673 | }) |
| 674 | .Default([](Attribute) -> OpFoldResult { return {}; }); |
| 675 | }) |
| 676 | .Default([](Type) -> OpFoldResult { return {}; }); |
| 677 | } |
| 678 | |
| 679 | LogicalResult PromoteNumericOp::canonicalize(PromoteNumericOp op, |
| 680 | PatternRewriter &rewriter) { |
| 681 | if (op.input().getType() == op.getResult().getType()) { |
| 682 | rewriter.replaceOp(op, op.input()); |
| 683 | return success(); |
| 684 | } |
| 685 | return failure(); |
| 686 | } |
| 687 | |
| 688 | //===----------------------------------------------------------------------===// |
| 689 | // RaiseOnFailureOp |
| 690 | //===----------------------------------------------------------------------===// |
| 691 | |
Stella Laurenzo | 826f1db | 2021-11-19 12:54:23 -0800 | [diff] [blame] | 692 | LogicalResult PYDM::RaiseOnFailureOp::canonicalize(RaiseOnFailureOp op, |
| 693 | PatternRewriter &rewriter) { |
| 694 | if (op.exc_result().getDefiningOp<SuccessOp>()) { |
| 695 | op.getOperation()->erase(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 696 | return success(); |
| 697 | } |
| 698 | return failure(); |
| 699 | } |
| 700 | |
| 701 | //===----------------------------------------------------------------------===// |
| 702 | // SelectOp |
| 703 | //===----------------------------------------------------------------------===// |
| 704 | |
| 705 | OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { |
| 706 | if (!operands[0]) return {}; |
| 707 | |
| 708 | BoolAttr boolAttr = operands[0].cast<BoolAttr>(); |
| 709 | if (boolAttr.getValue()) |
| 710 | return true_value(); |
| 711 | else |
| 712 | return false_value(); |
| 713 | } |
| 714 | |
| 715 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 716 | // SequenceCloneOp |
| 717 | //===----------------------------------------------------------------------===// |
| 718 | |
| 719 | bool SequenceCloneOp::refineResultTypes() { |
| 720 | if (sequence().getType() != getResult().getType()) { |
| 721 | getResult().setType(sequence().getType()); |
| 722 | return true; |
| 723 | } |
| 724 | return false; |
| 725 | } |
| 726 | |
| 727 | //===----------------------------------------------------------------------===// |
| 728 | // SubscriptOp |
| 729 | //===----------------------------------------------------------------------===// |
| 730 | |
| 731 | void SubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 732 | MLIRContext *context) { |
| 733 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 734 | } |
| 735 | |
| 736 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 737 | // CallOp |
| 738 | //===----------------------------------------------------------------------===// |
| 739 | |
| 740 | LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| 741 | // Check that the callee attribute was specified. |
| 742 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); |
| 743 | if (!fnAttr) |
| 744 | return emitOpError("requires a 'callee' symbol reference attribute"); |
| 745 | PyFuncOp fn = symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, fnAttr); |
| 746 | if (!fn) |
| 747 | return emitOpError() << "'" << fnAttr.getValue() |
| 748 | << "' does not reference a valid function"; |
| 749 | |
| 750 | // Verify that the operand and result types match the callee. |
| 751 | auto fnType = fn.getType(); |
| 752 | if (fnType.getNumInputs() != getNumOperands()) |
| 753 | return emitOpError("incorrect number of operands for callee"); |
| 754 | |
| 755 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { |
| 756 | if (getOperand(i).getType() != fnType.getInput(i)) { |
| 757 | return emitOpError("operand type mismatch: expected operand type ") |
| 758 | << fnType.getInput(i) << ", but provided " |
| 759 | << getOperand(i).getType() << " for operand number " << i; |
| 760 | } |
| 761 | } |
| 762 | |
| 763 | if (fnType.getNumResults() != getNumResults()) |
| 764 | return emitOpError("incorrect number of results for callee"); |
| 765 | |
| 766 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { |
| 767 | if (getResult(i).getType() != fnType.getResult(i)) { |
| 768 | auto diag = emitOpError("result type mismatch at index ") << i; |
| 769 | diag.attachNote() << " op result types: " << getResultTypes(); |
| 770 | diag.attachNote() << "function result types: " << fnType.getResults(); |
| 771 | return diag; |
| 772 | } |
| 773 | } |
| 774 | |
| 775 | return success(); |
| 776 | } |
| 777 | |
| 778 | FunctionType PyCallOp::getCalleeType() { |
| 779 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
| 780 | } |
| 781 | |
| 782 | //===----------------------------------------------------------------------===// |
| 783 | // DynamicCallOp |
| 784 | //===----------------------------------------------------------------------===// |
| 785 | |
| 786 | LogicalResult DynamicCallOp::verifySymbolUses( |
| 787 | SymbolTableCollection &symbolTable) { |
| 788 | // Check that the callee attribute was specified. |
| 789 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); |
| 790 | if (!fnAttr) |
| 791 | return emitOpError("requires a 'callee' symbol reference attribute"); |
| 792 | Operation *fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); |
| 793 | if (!fn || !isa<PyFuncOp>(fn)) |
| 794 | return emitOpError() << "'" << fnAttr.getValue() |
| 795 | << "' does not reference a valid function"; |
| 796 | return success(); |
| 797 | } |
| 798 | |
| 799 | #define GET_OP_CLASSES |
Stella Laurenzo | 02cfcd1 | 2021-11-14 13:20:53 -0800 | [diff] [blame] | 800 | #include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc" |