Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 1 | // Copyright 2021 The IREE Authors |
| 2 | // |
| 3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | |
Stella Laurenzo | 02cfcd1 | 2021-11-14 13:20:53 -0800 | [diff] [blame] | 7 | #include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h" |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 8 | |
Stella Laurenzo | 02cfcd1 | 2021-11-14 13:20:53 -0800 | [diff] [blame] | 9 | #include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h" |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 10 | #include "llvm/ADT/SmallSet.h" |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 11 | #include "llvm/ADT/TypeSwitch.h" |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 12 | #include "llvm/Support/Debug.h" |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 13 | #include "mlir/IR/Builders.h" |
| 14 | #include "mlir/IR/BuiltinTypes.h" |
| 15 | #include "mlir/IR/FunctionImplementation.h" |
| 16 | #include "mlir/IR/OpImplementation.h" |
| 17 | #include "mlir/IR/TypeUtilities.h" |
| 18 | |
| 19 | using namespace mlir; |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 20 | namespace PYDM = mlir::iree_compiler::IREE::PYDM; |
| 21 | using namespace PYDM; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 22 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 23 | using llvm::dbgs; |
| 24 | |
| 25 | using PyBoolType = PYDM::BoolType; |
| 26 | using PyConstantOp = PYDM::ConstantOp; |
| 27 | using PyIntegerType = PYDM::IntegerType; |
| 28 | using PyRealType = PYDM::RealType; |
| 29 | using PyCallOp = PYDM::CallOp; |
| 30 | using PyFuncOp = PYDM::FuncOp; |
| 31 | |
| 32 | static LogicalResult verify(Operation *) { return success(); } |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 33 | |
| 34 | //===----------------------------------------------------------------------===// |
| 35 | // Utilities |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 38 | namespace { |
| 39 | |
| 40 | /// Generic pattern to unbox any operands that are a specific object |
| 41 | /// type (i.e. object<integer>). |
| 42 | struct UnboxOperands : public RewritePattern { |
| 43 | UnboxOperands(StringRef rootName, MLIRContext *context, |
| 44 | Optional<llvm::SmallSet<int, 4>> operandIndices = None) |
| 45 | : RewritePattern(rootName, 1, context), operandIndices(operandIndices) {} |
| 46 | LogicalResult matchAndRewrite(Operation *op, |
| 47 | PatternRewriter &rewriter) const override { |
| 48 | Location loc = op->getLoc(); |
| 49 | bool changed = false; |
| 50 | SmallVector<Value> operands(op->getOperands()); |
| 51 | auto excResultType = rewriter.getType<ExceptionResultType>(); |
| 52 | for (int operandIndex = 0, e = operands.size(); operandIndex < e; |
| 53 | ++operandIndex) { |
| 54 | Value &operand = operands[operandIndex]; |
| 55 | if (operandIndices && !operandIndices->contains(operandIndex)) continue; |
| 56 | if (auto objectType = operand.getType().dyn_cast<ObjectType>()) { |
| 57 | Type primitiveType = objectType.getPrimitiveType(); |
| 58 | if (primitiveType) { |
| 59 | // Unbox. |
| 60 | auto unboxOp = rewriter.create<UnboxOp>( |
| 61 | loc, TypeRange{excResultType, primitiveType}, operand); |
| 62 | operand = unboxOp.primitive(); |
| 63 | changed = true; |
| 64 | } |
| 65 | } |
| 66 | } |
| 67 | |
| 68 | if (changed) { |
| 69 | rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); |
| 70 | return success(); |
| 71 | } |
| 72 | |
| 73 | return failure(); |
| 74 | } |
| 75 | Optional<llvm::SmallSet<int, 4>> operandIndices; |
| 76 | }; |
| 77 | |
| 78 | } // namespace |
| 79 | |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 80 | static Value getNumericZeroConstant(Location loc, Type numericType, |
| 81 | OpBuilder &builder) { |
| 82 | return TypeSwitch<Type, Value>(numericType) |
| 83 | .Case([&](PyBoolType t) -> Value { |
| 84 | return builder.create<PyConstantOp>(loc, t, builder.getBoolAttr(false)); |
| 85 | }) |
| 86 | .Case([&](PyIntegerType t) -> Value { |
| 87 | return builder.create<PyConstantOp>(loc, t, |
| 88 | builder.getI64IntegerAttr(0)); |
| 89 | }) |
| 90 | .Case([&](PyRealType t) -> Value { |
| 91 | return builder.create<PyConstantOp>(loc, t, |
| 92 | builder.getF64FloatAttr(0.0)); |
| 93 | }); |
| 94 | } |
| 95 | |
| 96 | static Value getBoolConstant(Location loc, bool pred, OpBuilder &builder) { |
| 97 | return builder.create<PyConstantOp>(loc, builder.getType<BoolType>(), |
| 98 | builder.getBoolAttr(pred)); |
| 99 | } |
| 100 | |
| 101 | //===----------------------------------------------------------------------===// |
| 102 | // Constants |
| 103 | //===----------------------------------------------------------------------===// |
| 104 | |
| 105 | OpFoldResult PyConstantOp::fold(ArrayRef<Attribute> operands) { |
| 106 | assert(operands.empty() && "constant has no operands"); |
| 107 | return getValue(); |
| 108 | } |
| 109 | |
| 110 | OpFoldResult NoneOp::fold(ArrayRef<Attribute> operands) { |
| 111 | assert(operands.empty() && "constant has no operands"); |
| 112 | return UnitAttr::get(getContext()); |
| 113 | } |
| 114 | |
| 115 | OpFoldResult SuccessOp::fold(ArrayRef<Attribute> operands) { |
| 116 | assert(operands.empty() && "constant has no operands"); |
| 117 | return UnitAttr::get(getContext()); |
| 118 | } |
| 119 | |
| 120 | //===----------------------------------------------------------------------===// |
| 121 | // Variables |
| 122 | //===----------------------------------------------------------------------===// |
| 123 | |
| 124 | void AllocFreeVarOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { |
| 125 | setNameFn(getResult(), name()); |
| 126 | } |
| 127 | |
| 128 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 129 | // ApplyBinaryOp |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 130 | //===----------------------------------------------------------------------===// |
| 131 | |
| 132 | namespace { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 133 | struct ApplyBinaryToSequenceClone : public OpRewritePattern<ApplyBinaryOp> { |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 134 | using OpRewritePattern::OpRewritePattern; |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 135 | LogicalResult matchAndRewrite(ApplyBinaryOp op, |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 136 | PatternRewriter &rewriter) const override { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 137 | if (op.dunder_name() != "mul") return failure(); |
| 138 | Value listOperand; |
| 139 | Value countOperand; |
| 140 | if (isBuiltinSequence(op.left()) && isInteger(op.right())) { |
| 141 | listOperand = op.left(); |
| 142 | countOperand = op.right(); |
| 143 | } else if (isInteger(op.left()) && isBuiltinSequence(op.right())) { |
| 144 | countOperand = op.left(); |
| 145 | listOperand = op.right(); |
| 146 | } else { |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 147 | return failure(); |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 148 | } |
| 149 | Type resultType = op.getResult().getType(); |
| 150 | rewriter.replaceOpWithNewOp<SequenceCloneOp>(op, resultType, listOperand, |
| 151 | countOperand); |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 152 | return success(); |
| 153 | } |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 154 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 155 | static bool isBuiltinSequence(Value operand) { |
| 156 | return operand.getType().isa<PYDM::ListType, PYDM::TupleType>(); |
| 157 | } |
| 158 | static bool isInteger(Value operand) { |
| 159 | return operand.getType().isa<PYDM::IntegerType>(); |
| 160 | } |
| 161 | }; |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 162 | } // namespace |
| 163 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 164 | void ApplyBinaryOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 165 | MLIRContext *context) { |
| 166 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 167 | patterns.add<ApplyBinaryToSequenceClone>(context); |
| 168 | } |
| 169 | |
| 170 | bool ApplyBinaryOp::refineResultTypes() { |
| 171 | auto leftType = left().getType(); |
| 172 | auto rightType = right().getType(); |
| 173 | auto applyUpdates = [&](Type newResultType) -> bool { |
| 174 | if (newResultType != getResult().getType()) { |
| 175 | getResult().setType(newResultType); |
| 176 | return true; |
| 177 | } |
| 178 | return false; |
| 179 | }; |
| 180 | |
| 181 | // Both numeric types. It is only dynamically legal for statically known |
| 182 | // numeric types to be the same, in which case the result type must be the |
| 183 | // same as well. |
| 184 | auto ptLeft = leftType.dyn_cast<PythonTypeInterface>(); |
| 185 | auto ptRight = rightType.dyn_cast<PythonTypeInterface>(); |
| 186 | if (ptLeft && ptRight && ptLeft.getNumericPromotionOrder() && |
| 187 | ptRight.getNumericPromotionOrder()) { |
| 188 | if (leftType == rightType) { |
| 189 | return applyUpdates(leftType); |
| 190 | } |
| 191 | } |
| 192 | |
| 193 | // (list, integer) or (integer, list) refine to the list type. |
| 194 | if (dunder_name() == "mul") { |
| 195 | auto leftList = leftType.dyn_cast<ListType>(); |
| 196 | auto rightList = rightType.dyn_cast<ListType>(); |
| 197 | auto leftInteger = leftType.dyn_cast<IntegerType>(); |
| 198 | auto rightInteger = rightType.dyn_cast<IntegerType>(); |
| 199 | if (leftList && rightInteger) { |
| 200 | return applyUpdates(leftList); |
| 201 | } else if (leftInteger && rightList) { |
| 202 | return applyUpdates(rightList); |
| 203 | } |
| 204 | } |
| 205 | |
| 206 | return false; |
| 207 | } |
| 208 | |
| 209 | //===----------------------------------------------------------------------===// |
| 210 | // ApplyCompareOp |
| 211 | //===----------------------------------------------------------------------===// |
| 212 | |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 213 | void ApplyCompareOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 214 | MLIRContext *context) { |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 215 | patterns.add<UnboxOperands>(getOperationName(), context); |
Stella Laurenzo | 813d3ae | 2021-10-06 15:39:03 -0700 | [diff] [blame] | 216 | } |
| 217 | |
| 218 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 219 | // AsBoolOp |
| 220 | //===----------------------------------------------------------------------===// |
| 221 | |
| 222 | namespace { |
| 223 | struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> { |
| 224 | public: |
| 225 | using OpRewritePattern::OpRewritePattern; |
| 226 | LogicalResult matchAndRewrite(AsBoolOp op, |
| 227 | PatternRewriter &rewriter) const override { |
| 228 | if (op.value().getType().isa<BoolType>()) { |
| 229 | rewriter.replaceOp(op, op.value()); |
| 230 | return success(); |
| 231 | } |
| 232 | return failure(); |
| 233 | } |
| 234 | }; |
| 235 | |
| 236 | struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> { |
| 237 | public: |
| 238 | using OpRewritePattern::OpRewritePattern; |
| 239 | LogicalResult matchAndRewrite(AsBoolOp op, |
| 240 | PatternRewriter &rewriter) const override { |
| 241 | auto loc = op.getLoc(); |
| 242 | auto ptType = op.value().getType().dyn_cast<PythonTypeInterface>(); |
| 243 | if (!ptType) return failure(); |
| 244 | if (!ptType.getNumericPromotionOrder()) return failure(); |
| 245 | |
| 246 | auto boolType = rewriter.getType<BoolType>(); |
| 247 | Value zeroValue = |
| 248 | getNumericZeroConstant(loc, op.value().getType(), rewriter); |
| 249 | Value trueValue = getBoolConstant(loc, true, rewriter); |
| 250 | Value falseValue = getBoolConstant(loc, false, rewriter); |
| 251 | Value cmpResult = rewriter.create<ApplyCompareOp>( |
| 252 | loc, boolType, rewriter.getStringAttr("eq"), op.value(), zeroValue); |
| 253 | rewriter.replaceOpWithNewOp<SelectOp>(op, boolType, cmpResult, falseValue, |
| 254 | trueValue); |
| 255 | return success(); |
| 256 | } |
| 257 | }; |
| 258 | |
| 259 | } // namespace |
| 260 | |
| 261 | void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 262 | MLIRContext *context) { |
| 263 | patterns.add<FoldAsBoolFromBool, FoldAsBoolFromNumeric>(context); |
| 264 | } |
| 265 | |
| 266 | OpFoldResult AsBoolOp::fold(ArrayRef<Attribute> operands) { |
| 267 | Builder b(getContext()); |
| 268 | // Fold NoneType to False. |
| 269 | if (value().getType().isa<NoneType>()) { |
| 270 | return b.getBoolAttr(false); |
| 271 | } |
| 272 | |
| 273 | return {}; |
| 274 | } |
| 275 | |
| 276 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 277 | // AssignSubscriptOp |
| 278 | //===----------------------------------------------------------------------===// |
| 279 | |
| 280 | void AssignSubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 281 | MLIRContext *context) { |
| 282 | llvm::SmallSet<int, 4> unboxIndices; |
| 283 | unboxIndices.insert(0); |
| 284 | unboxIndices.insert(1); |
| 285 | patterns.add<UnboxOperands>(getOperationName(), context, unboxIndices); |
| 286 | } |
| 287 | |
| 288 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 289 | // BoolToPredOp |
| 290 | //===----------------------------------------------------------------------===// |
| 291 | |
| 292 | OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) { |
| 293 | if (!operands[0]) return {}; |
| 294 | // Since both BoolType and I1 share the attribute form (an IntegerAttr of I1), |
| 295 | // we can just return it. |
| 296 | return operands[0]; |
| 297 | } |
| 298 | |
| 299 | //===----------------------------------------------------------------------===// |
| 300 | // BoxOp and UnboxOp |
| 301 | //===----------------------------------------------------------------------===// |
| 302 | |
| 303 | LogicalResult BoxOp::canonicalize(BoxOp op, PatternRewriter &rewriter) { |
| 304 | // Sometimes boxes are emitted when the input is an object. Just remove. |
| 305 | if (op.primitive().getType().isa<ObjectType>()) { |
| 306 | rewriter.replaceOp(op, op.primitive()); |
| 307 | return success(); |
| 308 | } |
| 309 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 310 | // Box to an appropriate type and static info cast. |
| 311 | ObjectType objectType = rewriter.getType<ObjectType>(nullptr); |
| 312 | if (op.object().getType() == objectType && |
| 313 | !op.primitive().getType().isa<ObjectType>()) { |
| 314 | auto refinedBox = rewriter.create<BoxOp>( |
| 315 | op.getLoc(), |
| 316 | rewriter.getType<ObjectType>( |
| 317 | op.primitive().getType().cast<PrimitiveType>()), |
| 318 | op.primitive()); |
| 319 | rewriter.replaceOpWithNewOp<StaticInfoCastOp>(op, op.object().getType(), |
| 320 | refinedBox); |
| 321 | return success(); |
| 322 | } |
| 323 | |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 324 | return failure(); |
| 325 | } |
| 326 | |
| 327 | LogicalResult UnboxOp::canonicalize(UnboxOp unboxOp, |
| 328 | PatternRewriter &rewriter) { |
| 329 | auto loc = unboxOp.getLoc(); |
| 330 | |
| 331 | // Handle the case of an immediate BoxOp producer. |
| 332 | if (auto boxProducer = |
| 333 | dyn_cast_or_null<BoxOp>(unboxOp.object().getDefiningOp())) { |
| 334 | // If the producer is boxing to the same type we are unboxing, then |
| 335 | // just elide everything. |
| 336 | if (boxProducer.primitive().getType() == unboxOp.primitive().getType()) { |
| 337 | auto successValue = rewriter.create<SuccessOp>( |
| 338 | loc, rewriter.getType<ExceptionResultType>()); |
| 339 | rewriter.replaceOp(unboxOp, {successValue, boxProducer.primitive()}); |
| 340 | return success(); |
| 341 | } |
| 342 | } |
| 343 | return failure(); |
| 344 | } |
| 345 | |
| 346 | //===----------------------------------------------------------------------===// |
| 347 | // DynamicBinaryPromoteOp |
| 348 | //===----------------------------------------------------------------------===// |
| 349 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 350 | namespace { |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 351 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 352 | /// Resolves a DynamicBinaryPromote over numeric operands to either elide |
| 353 | /// or insert specific PromoteNumeric ops. |
| 354 | struct ResolveNumericDynamicBinaryPromote |
| 355 | : public OpRewritePattern<DynamicBinaryPromoteOp> { |
| 356 | public: |
| 357 | using OpRewritePattern::OpRewritePattern; |
| 358 | LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op, |
| 359 | PatternRewriter &rewriter) const override { |
| 360 | auto loc = op.getLoc(); |
| 361 | auto leftType = op.left().getType(); |
| 362 | auto rightType = op.right().getType(); |
| 363 | auto leftResultType = op.getResultTypes()[0]; |
| 364 | auto rightResultType = op.getResultTypes()[1]; |
| 365 | auto leftPt = leftType.dyn_cast<PythonTypeInterface>(); |
| 366 | auto rightPt = rightType.dyn_cast<PythonTypeInterface>(); |
| 367 | if (!leftPt || !rightPt) return failure(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 368 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 369 | Optional<int> leftOrder = leftPt.getNumericPromotionOrder(); |
| 370 | Optional<int> rightOrder = rightPt.getNumericPromotionOrder(); |
| 371 | Value newLeft = op.left(); |
| 372 | Value newRight = op.right(); |
| 373 | |
| 374 | if (leftOrder && rightOrder) { |
| 375 | // Both numeric. |
| 376 | if (*leftOrder > *rightOrder) { |
| 377 | newRight = rewriter.create<PromoteNumericOp>(loc, leftType, newRight); |
| 378 | } |
| 379 | if (*rightOrder > *leftOrder) { |
| 380 | newLeft = rewriter.create<PromoteNumericOp>(loc, rightType, newLeft); |
| 381 | } |
| 382 | } else { |
| 383 | return failure(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 384 | } |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 385 | |
| 386 | // Need to box back to the original type (which will always be a generic |
| 387 | // object). |
| 388 | newLeft = rewriter.create<BoxOp>(loc, leftResultType, newLeft); |
| 389 | newRight = rewriter.create<BoxOp>(loc, rightResultType, newRight); |
| 390 | |
| 391 | rewriter.replaceOp(op, {newLeft, newRight}); |
| 392 | return success(); |
| 393 | } |
| 394 | }; |
| 395 | |
| 396 | /// If we statically determine one of the arguments to be a concrete, non |
| 397 | /// numeric type, then the op has no meaning and is elided. |
| 398 | struct ElideNonNumericDynamicBinaryPromote |
| 399 | : public OpRewritePattern<DynamicBinaryPromoteOp> { |
| 400 | public: |
| 401 | using OpRewritePattern::OpRewritePattern; |
| 402 | LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op, |
| 403 | PatternRewriter &rewriter) const override { |
| 404 | if ((!isConcreteNonNumericType(op.left().getType()) && |
| 405 | !isConcreteNonNumericType(op.right().getType()))) |
| 406 | return failure(); |
| 407 | |
| 408 | // Since DynamicBinaryPromote already returns object, and we only match |
| 409 | // non-object operands, box them back. |
| 410 | auto loc = op.getLoc(); |
| 411 | auto leftResultType = op.getResultTypes()[0]; |
| 412 | auto rightResultType = op.getResultTypes()[1]; |
| 413 | Value newLeft = rewriter.create<BoxOp>(loc, leftResultType, op.left()); |
| 414 | Value newRight = rewriter.create<BoxOp>(loc, rightResultType, op.right()); |
| 415 | rewriter.replaceOp(op, {newLeft, newRight}); |
| 416 | return success(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 417 | } |
| 418 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 419 | static bool isConcreteNonNumericType(Type t) { |
| 420 | if (t.isa<ObjectType>()) return false; |
| 421 | auto pt = t.dyn_cast<PythonTypeInterface>(); |
| 422 | if (!pt || pt.getNumericPromotionOrder()) return false; |
| 423 | return true; |
| 424 | } |
| 425 | }; |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 426 | |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 427 | } // namespace |
| 428 | |
| 429 | void DynamicBinaryPromoteOp::getCanonicalizationPatterns( |
| 430 | RewritePatternSet &patterns, MLIRContext *context) { |
| 431 | patterns.add<ResolveNumericDynamicBinaryPromote>(context); |
| 432 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 433 | patterns.add<ElideNonNumericDynamicBinaryPromote>(context); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 434 | } |
| 435 | |
| 436 | //===----------------------------------------------------------------------===// |
| 437 | // FunctionalIfOp |
| 438 | //===----------------------------------------------------------------------===// |
| 439 | |
| 440 | ::llvm::StringRef FunctionalIfOp::getDefaultDialect() { return "iree_pydm"; } |
| 441 | |
| 442 | static LogicalResult verify(FunctionalIfOp op) { |
| 443 | if (op.getNumResults() != 0 && op.elseRegion().empty()) |
| 444 | return op.emitOpError("must have an else block if defining values"); |
| 445 | |
| 446 | return RegionBranchOpInterface::verifyTypes(op); |
| 447 | } |
| 448 | |
| 449 | static ParseResult parseFunctionalIfOp(OpAsmParser &parser, |
| 450 | OperationState &result) { |
| 451 | // Create the regions for 'then'. |
| 452 | result.regions.reserve(2); |
| 453 | Region *thenRegion = result.addRegion(); |
| 454 | Region *elseRegion = result.addRegion(); |
| 455 | |
| 456 | auto &builder = parser.getBuilder(); |
| 457 | OpAsmParser::OperandType cond; |
| 458 | Type conditionType = builder.getType<PyBoolType>(); |
| 459 | if (parser.parseOperand(cond) || |
| 460 | parser.resolveOperand(cond, conditionType, result.operands)) |
| 461 | return failure(); |
| 462 | // Parse optional results type list. |
| 463 | if (parser.parseOptionalArrowTypeList(result.types)) return failure(); |
| 464 | // Parse the 'then' region. |
| 465 | if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) |
| 466 | return failure(); |
| 467 | // IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); |
| 468 | |
| 469 | // If we find an 'else' keyword then parse the 'else' region. |
| 470 | if (!parser.parseOptionalKeyword("else")) { |
| 471 | if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) |
| 472 | return failure(); |
| 473 | // IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), |
| 474 | // result.location); |
| 475 | } |
| 476 | |
| 477 | // Parse the optional attribute list. |
| 478 | if (parser.parseOptionalAttrDict(result.attributes)) return failure(); |
| 479 | return success(); |
| 480 | } |
| 481 | |
| 482 | static void print(OpAsmPrinter &p, FunctionalIfOp op) { |
| 483 | bool printBlockTerminators = false; |
| 484 | |
| 485 | p << " " << op.condition(); |
| 486 | if (!op.results().empty()) { |
| 487 | p << " -> (" << op.getResultTypes() << ")"; |
| 488 | // Print yield explicitly if the op defines values. |
| 489 | printBlockTerminators = true; |
| 490 | } |
| 491 | p.printRegion(op.thenRegion(), |
| 492 | /*printEntryBlockArgs=*/false, |
| 493 | /*printBlockTerminators=*/printBlockTerminators); |
| 494 | |
| 495 | // Print the 'else' regions if it exists and has a block. |
| 496 | auto &elseRegion = op.elseRegion(); |
| 497 | if (!elseRegion.empty()) { |
| 498 | p << " else"; |
| 499 | p.printRegion(elseRegion, |
| 500 | /*printEntryBlockArgs=*/false, |
| 501 | /*printBlockTerminators=*/printBlockTerminators); |
| 502 | } |
| 503 | |
| 504 | p.printOptionalAttrDict(op->getAttrs()); |
| 505 | } |
| 506 | |
| 507 | /// Given the region at `index`, or the parent operation if `index` is None, |
| 508 | /// return the successor regions. These are the regions that may be selected |
| 509 | /// during the flow of control. `operands` is a set of optional attributes that |
| 510 | /// correspond to a constant value for each operand, or null if that operand is |
| 511 | /// not a constant. |
| 512 | void FunctionalIfOp::getSuccessorRegions( |
| 513 | Optional<unsigned> index, ArrayRef<Attribute> operands, |
| 514 | SmallVectorImpl<RegionSuccessor> ®ions) { |
| 515 | // The `then` and the `else` region branch back to the parent operation. |
| 516 | if (index.hasValue()) { |
| 517 | regions.push_back(RegionSuccessor(getResults())); |
| 518 | return; |
| 519 | } |
| 520 | |
| 521 | // Don't consider the else region if it is empty. |
| 522 | Region *elseRegion = &this->elseRegion(); |
| 523 | if (elseRegion->empty()) elseRegion = nullptr; |
| 524 | |
| 525 | // Otherwise, the successor is dependent on the condition. |
| 526 | if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) { |
| 527 | bool condition = condAttr.getValue(); |
| 528 | // Add the successor regions using the condition. |
| 529 | regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion)); |
| 530 | } else { |
| 531 | // If the condition isn't constant, both regions may be executed. |
| 532 | regions.push_back(RegionSuccessor(&thenRegion())); |
| 533 | // If the else region does not exist, it is not a viable successor. |
| 534 | if (elseRegion) regions.push_back(RegionSuccessor(elseRegion)); |
| 535 | } |
| 536 | } |
| 537 | |
| 538 | //===----------------------------------------------------------------------===// |
| 539 | // FuncOp |
| 540 | //===----------------------------------------------------------------------===// |
| 541 | |
| 542 | ::llvm::StringRef PyFuncOp::getDefaultDialect() { return "iree_pydm"; } |
| 543 | |
| 544 | LogicalResult PyFuncOp::verifyType() { |
| 545 | // TODO: Enforce arg/result invariants. |
| 546 | return success(); |
| 547 | } |
| 548 | |
| 549 | static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { |
| 550 | auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, |
| 551 | ArrayRef<Type> results, |
| 552 | function_like_impl::VariadicFlag, std::string &) { |
| 553 | return builder.getFunctionType(argTypes, results); |
| 554 | }; |
| 555 | |
| 556 | return function_like_impl::parseFunctionLikeOp( |
| 557 | parser, result, /*allowVariadic=*/false, buildFuncType); |
| 558 | } |
| 559 | |
| 560 | static void print(PyFuncOp op, OpAsmPrinter &p) { |
| 561 | FunctionType fnType = op.getType(); |
| 562 | function_like_impl::printFunctionLikeOp( |
| 563 | p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); |
| 564 | } |
| 565 | |
| 566 | static LogicalResult verify(PyFuncOp op) { |
| 567 | // TODO: Enforce invariants. |
| 568 | return success(); |
| 569 | } |
| 570 | |
| 571 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 572 | // MakeListOp |
| 573 | //===----------------------------------------------------------------------===// |
| 574 | |
| 575 | static LogicalResult verify(MakeListOp op) { |
| 576 | auto listType = op.list().getType().cast<ListType>(); |
| 577 | switch (listType.getStorageClass()) { |
| 578 | case CollectionStorageClass::Boxed: |
| 579 | for (auto element : op.elements()) { |
| 580 | if (!element.getType().isa<ObjectType>()) { |
| 581 | return op.emitOpError() << "making a list with boxed storage class " |
| 582 | "must have object elements. Got: " |
| 583 | << element.getType(); |
| 584 | } |
| 585 | } |
| 586 | break; |
| 587 | case CollectionStorageClass::Unboxed: |
| 588 | for (auto element : op.elements()) { |
| 589 | if (element.getType().isa<ObjectType>()) { |
| 590 | return op.emitOpError() << "making a list with unboxed storage class " |
| 591 | "must not have object elements. Got: " |
| 592 | << element.getType(); |
| 593 | } |
| 594 | } |
| 595 | break; |
| 596 | case CollectionStorageClass::Empty: |
| 597 | if (!op.elements().empty()) { |
| 598 | return op.emitOpError() |
| 599 | << "making a list with empty storage class must have zero " |
| 600 | "elements"; |
| 601 | } |
| 602 | break; |
| 603 | } |
| 604 | return success(); |
| 605 | } |
| 606 | |
| 607 | //===----------------------------------------------------------------------===// |
| 608 | // NegOp |
| 609 | //===----------------------------------------------------------------------===// |
| 610 | |
| 611 | void NegOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 612 | MLIRContext *context) { |
| 613 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 614 | } |
| 615 | |
| 616 | bool NegOp::refineResultTypes() { |
| 617 | if (value().getType() != getResult().getType()) { |
| 618 | getResult().setType(value().getType()); |
| 619 | return true; |
| 620 | } |
| 621 | return false; |
| 622 | } |
| 623 | |
| 624 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 625 | // PatternMatchCallOp |
| 626 | //===----------------------------------------------------------------------===// |
| 627 | |
| 628 | LogicalResult PatternMatchCallOp::verifySymbolUses( |
| 629 | SymbolTableCollection &symbolTable) { |
| 630 | auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult { |
| 631 | for (auto symbolAttr : symbols) { |
| 632 | auto symbol = symbolAttr.cast<FlatSymbolRefAttr>(); |
| 633 | PyFuncOp fn = |
| 634 | symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol); |
| 635 | if (!fn) |
| 636 | return emitOpError() << "'" << symbol.getValue() |
| 637 | << "' does not reference a valid function"; |
| 638 | } |
| 639 | return success(); |
| 640 | }; |
| 641 | auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match"); |
| 642 | if (!genericsAttr) |
| 643 | return emitOpError( |
| 644 | "requires a 'generic_match' array of symbol reference attributes"); |
| 645 | if (failed(verifySymbols(genericsAttr))) return failure(); |
| 646 | |
| 647 | auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match"); |
| 648 | if (!specificsAttr) |
| 649 | return emitOpError( |
| 650 | "requires a 'specific_match' array of symbol reference attributes"); |
| 651 | if (failed(verifySymbols(specificsAttr))) return failure(); |
| 652 | |
| 653 | return success(); |
| 654 | } |
| 655 | |
| 656 | //===----------------------------------------------------------------------===// |
| 657 | // PromoteNumericOp |
| 658 | //===----------------------------------------------------------------------===// |
| 659 | |
| 660 | OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) { |
| 661 | if (!operands[0]) return {}; |
| 662 | |
| 663 | Builder b(getContext()); |
| 664 | Attribute fromAttr = operands[0]; |
| 665 | return TypeSwitch<Type, OpFoldResult>(getResult().getType()) |
| 666 | .Case([&](PyIntegerType toType) -> OpFoldResult { |
| 667 | return TypeSwitch<Attribute, OpFoldResult>(fromAttr) |
| 668 | .Case([&](BoolAttr fromBool) -> OpFoldResult { |
| 669 | return b.getI64IntegerAttr(fromBool.getValue() ? 1 : 0); |
| 670 | }) |
| 671 | .Default([](Attribute) -> OpFoldResult { return {}; }); |
| 672 | }) |
| 673 | .Case([&](PyRealType toType) -> OpFoldResult { |
| 674 | return TypeSwitch<Attribute, OpFoldResult>(fromAttr) |
| 675 | .Case([&](BoolAttr fromBool) -> OpFoldResult { |
| 676 | return b.getF64FloatAttr(fromBool.getValue() ? 1.0 : 0.0); |
| 677 | }) |
| 678 | .Case([&](IntegerAttr fromInteger) -> OpFoldResult { |
| 679 | APInt value = fromInteger.getValue(); |
| 680 | return b.getF64FloatAttr(value.getSExtValue()); |
| 681 | }) |
| 682 | .Default([](Attribute) -> OpFoldResult { return {}; }); |
| 683 | }) |
| 684 | .Default([](Type) -> OpFoldResult { return {}; }); |
| 685 | } |
| 686 | |
| 687 | LogicalResult PromoteNumericOp::canonicalize(PromoteNumericOp op, |
| 688 | PatternRewriter &rewriter) { |
| 689 | if (op.input().getType() == op.getResult().getType()) { |
| 690 | rewriter.replaceOp(op, op.input()); |
| 691 | return success(); |
| 692 | } |
| 693 | return failure(); |
| 694 | } |
| 695 | |
| 696 | //===----------------------------------------------------------------------===// |
| 697 | // RaiseOnFailureOp |
| 698 | //===----------------------------------------------------------------------===// |
| 699 | |
Stella Laurenzo | 826f1db | 2021-11-19 12:54:23 -0800 | [diff] [blame^] | 700 | LogicalResult PYDM::RaiseOnFailureOp::canonicalize(RaiseOnFailureOp op, |
| 701 | PatternRewriter &rewriter) { |
| 702 | if (op.exc_result().getDefiningOp<SuccessOp>()) { |
| 703 | op.getOperation()->erase(); |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 704 | return success(); |
| 705 | } |
| 706 | return failure(); |
| 707 | } |
| 708 | |
| 709 | //===----------------------------------------------------------------------===// |
| 710 | // SelectOp |
| 711 | //===----------------------------------------------------------------------===// |
| 712 | |
| 713 | OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { |
| 714 | if (!operands[0]) return {}; |
| 715 | |
| 716 | BoolAttr boolAttr = operands[0].cast<BoolAttr>(); |
| 717 | if (boolAttr.getValue()) |
| 718 | return true_value(); |
| 719 | else |
| 720 | return false_value(); |
| 721 | } |
| 722 | |
| 723 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | ec9d61f | 2021-11-06 16:14:31 -0700 | [diff] [blame] | 724 | // SequenceCloneOp |
| 725 | //===----------------------------------------------------------------------===// |
| 726 | |
| 727 | bool SequenceCloneOp::refineResultTypes() { |
| 728 | if (sequence().getType() != getResult().getType()) { |
| 729 | getResult().setType(sequence().getType()); |
| 730 | return true; |
| 731 | } |
| 732 | return false; |
| 733 | } |
| 734 | |
| 735 | //===----------------------------------------------------------------------===// |
| 736 | // SubscriptOp |
| 737 | //===----------------------------------------------------------------------===// |
| 738 | |
| 739 | void SubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| 740 | MLIRContext *context) { |
| 741 | patterns.add<UnboxOperands>(getOperationName(), context); |
| 742 | } |
| 743 | |
| 744 | //===----------------------------------------------------------------------===// |
Stella Laurenzo | a69124b | 2021-09-06 13:14:10 -0700 | [diff] [blame] | 745 | // CallOp |
| 746 | //===----------------------------------------------------------------------===// |
| 747 | |
| 748 | LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| 749 | // Check that the callee attribute was specified. |
| 750 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); |
| 751 | if (!fnAttr) |
| 752 | return emitOpError("requires a 'callee' symbol reference attribute"); |
| 753 | PyFuncOp fn = symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, fnAttr); |
| 754 | if (!fn) |
| 755 | return emitOpError() << "'" << fnAttr.getValue() |
| 756 | << "' does not reference a valid function"; |
| 757 | |
| 758 | // Verify that the operand and result types match the callee. |
| 759 | auto fnType = fn.getType(); |
| 760 | if (fnType.getNumInputs() != getNumOperands()) |
| 761 | return emitOpError("incorrect number of operands for callee"); |
| 762 | |
| 763 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { |
| 764 | if (getOperand(i).getType() != fnType.getInput(i)) { |
| 765 | return emitOpError("operand type mismatch: expected operand type ") |
| 766 | << fnType.getInput(i) << ", but provided " |
| 767 | << getOperand(i).getType() << " for operand number " << i; |
| 768 | } |
| 769 | } |
| 770 | |
| 771 | if (fnType.getNumResults() != getNumResults()) |
| 772 | return emitOpError("incorrect number of results for callee"); |
| 773 | |
| 774 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { |
| 775 | if (getResult(i).getType() != fnType.getResult(i)) { |
| 776 | auto diag = emitOpError("result type mismatch at index ") << i; |
| 777 | diag.attachNote() << " op result types: " << getResultTypes(); |
| 778 | diag.attachNote() << "function result types: " << fnType.getResults(); |
| 779 | return diag; |
| 780 | } |
| 781 | } |
| 782 | |
| 783 | return success(); |
| 784 | } |
| 785 | |
| 786 | FunctionType PyCallOp::getCalleeType() { |
| 787 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
| 788 | } |
| 789 | |
| 790 | //===----------------------------------------------------------------------===// |
| 791 | // DynamicCallOp |
| 792 | //===----------------------------------------------------------------------===// |
| 793 | |
| 794 | LogicalResult DynamicCallOp::verifySymbolUses( |
| 795 | SymbolTableCollection &symbolTable) { |
| 796 | // Check that the callee attribute was specified. |
| 797 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); |
| 798 | if (!fnAttr) |
| 799 | return emitOpError("requires a 'callee' symbol reference attribute"); |
| 800 | Operation *fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); |
| 801 | if (!fn || !isa<PyFuncOp>(fn)) |
| 802 | return emitOpError() << "'" << fnAttr.getValue() |
| 803 | << "' does not reference a valid function"; |
| 804 | return success(); |
| 805 | } |
| 806 | |
| 807 | #define GET_OP_CLASSES |
Stella Laurenzo | 02cfcd1 | 2021-11-14 13:20:53 -0800 | [diff] [blame] | 808 | #include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc" |