blob: 20106882d3adc79ede17972d93339438ca657b6c [file] [log] [blame]
Stella Laurenzoa69124b2021-09-06 13:14:10 -07001// Copyright 2021 The IREE Authors
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
Stella Laurenzo02cfcd12021-11-14 13:20:53 -08007#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
Stella Laurenzoa69124b2021-09-06 13:14:10 -07008
Stella Laurenzo02cfcd12021-11-14 13:20:53 -08009#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070010#include "llvm/ADT/SmallSet.h"
Stella Laurenzoa69124b2021-09-06 13:14:10 -070011#include "llvm/ADT/TypeSwitch.h"
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070012#include "llvm/Support/Debug.h"
Stella Laurenzoa69124b2021-09-06 13:14:10 -070013#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/FunctionImplementation.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/TypeUtilities.h"
18
19using namespace mlir;
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070020namespace PYDM = mlir::iree_compiler::IREE::PYDM;
21using namespace PYDM;
Stella Laurenzoa69124b2021-09-06 13:14:10 -070022
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070023using llvm::dbgs;
24
25using PyBoolType = PYDM::BoolType;
26using PyConstantOp = PYDM::ConstantOp;
27using PyIntegerType = PYDM::IntegerType;
28using PyRealType = PYDM::RealType;
29using PyCallOp = PYDM::CallOp;
30using PyFuncOp = PYDM::FuncOp;
31
Stella Laurenzoa69124b2021-09-06 13:14:10 -070032//===----------------------------------------------------------------------===//
33// Utilities
34//===----------------------------------------------------------------------===//
35
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070036namespace {
37
38/// Generic pattern to unbox any operands that are a specific object
39/// type (i.e. object<integer>).
40struct UnboxOperands : public RewritePattern {
41 UnboxOperands(StringRef rootName, MLIRContext *context,
42 Optional<llvm::SmallSet<int, 4>> operandIndices = None)
43 : RewritePattern(rootName, 1, context), operandIndices(operandIndices) {}
44 LogicalResult matchAndRewrite(Operation *op,
45 PatternRewriter &rewriter) const override {
46 Location loc = op->getLoc();
47 bool changed = false;
48 SmallVector<Value> operands(op->getOperands());
49 auto excResultType = rewriter.getType<ExceptionResultType>();
50 for (int operandIndex = 0, e = operands.size(); operandIndex < e;
51 ++operandIndex) {
52 Value &operand = operands[operandIndex];
53 if (operandIndices && !operandIndices->contains(operandIndex)) continue;
54 if (auto objectType = operand.getType().dyn_cast<ObjectType>()) {
55 Type primitiveType = objectType.getPrimitiveType();
56 if (primitiveType) {
57 // Unbox.
58 auto unboxOp = rewriter.create<UnboxOp>(
59 loc, TypeRange{excResultType, primitiveType}, operand);
60 operand = unboxOp.primitive();
61 changed = true;
62 }
63 }
64 }
65
66 if (changed) {
67 rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
68 return success();
69 }
70
71 return failure();
72 }
73 Optional<llvm::SmallSet<int, 4>> operandIndices;
74};
75
76} // namespace
77
Stella Laurenzoa69124b2021-09-06 13:14:10 -070078static Value getNumericZeroConstant(Location loc, Type numericType,
79 OpBuilder &builder) {
80 return TypeSwitch<Type, Value>(numericType)
81 .Case([&](PyBoolType t) -> Value {
82 return builder.create<PyConstantOp>(loc, t, builder.getBoolAttr(false));
83 })
84 .Case([&](PyIntegerType t) -> Value {
85 return builder.create<PyConstantOp>(loc, t,
86 builder.getI64IntegerAttr(0));
87 })
88 .Case([&](PyRealType t) -> Value {
89 return builder.create<PyConstantOp>(loc, t,
90 builder.getF64FloatAttr(0.0));
91 });
92}
93
94static Value getBoolConstant(Location loc, bool pred, OpBuilder &builder) {
95 return builder.create<PyConstantOp>(loc, builder.getType<BoolType>(),
96 builder.getBoolAttr(pred));
97}
98
99//===----------------------------------------------------------------------===//
100// Constants
101//===----------------------------------------------------------------------===//
102
103OpFoldResult PyConstantOp::fold(ArrayRef<Attribute> operands) {
104 assert(operands.empty() && "constant has no operands");
105 return getValue();
106}
107
108OpFoldResult NoneOp::fold(ArrayRef<Attribute> operands) {
109 assert(operands.empty() && "constant has no operands");
110 return UnitAttr::get(getContext());
111}
112
113OpFoldResult SuccessOp::fold(ArrayRef<Attribute> operands) {
114 assert(operands.empty() && "constant has no operands");
115 return UnitAttr::get(getContext());
116}
117
118//===----------------------------------------------------------------------===//
119// Variables
120//===----------------------------------------------------------------------===//
121
122void AllocFreeVarOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
123 setNameFn(getResult(), name());
124}
125
126//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700127// ApplyBinaryOp
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700128//===----------------------------------------------------------------------===//
129
130namespace {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700131struct ApplyBinaryToSequenceClone : public OpRewritePattern<ApplyBinaryOp> {
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700132 using OpRewritePattern::OpRewritePattern;
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700133 LogicalResult matchAndRewrite(ApplyBinaryOp op,
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700134 PatternRewriter &rewriter) const override {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700135 if (op.dunder_name() != "mul") return failure();
136 Value listOperand;
137 Value countOperand;
138 if (isBuiltinSequence(op.left()) && isInteger(op.right())) {
139 listOperand = op.left();
140 countOperand = op.right();
141 } else if (isInteger(op.left()) && isBuiltinSequence(op.right())) {
142 countOperand = op.left();
143 listOperand = op.right();
144 } else {
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700145 return failure();
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700146 }
147 Type resultType = op.getResult().getType();
148 rewriter.replaceOpWithNewOp<SequenceCloneOp>(op, resultType, listOperand,
149 countOperand);
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700150 return success();
151 }
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700152
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700153 static bool isBuiltinSequence(Value operand) {
154 return operand.getType().isa<PYDM::ListType, PYDM::TupleType>();
155 }
156 static bool isInteger(Value operand) {
157 return operand.getType().isa<PYDM::IntegerType>();
158 }
159};
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700160} // namespace
161
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700162void ApplyBinaryOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
163 MLIRContext *context) {
164 patterns.add<UnboxOperands>(getOperationName(), context);
165 patterns.add<ApplyBinaryToSequenceClone>(context);
166}
167
168bool ApplyBinaryOp::refineResultTypes() {
169 auto leftType = left().getType();
170 auto rightType = right().getType();
171 auto applyUpdates = [&](Type newResultType) -> bool {
172 if (newResultType != getResult().getType()) {
173 getResult().setType(newResultType);
174 return true;
175 }
176 return false;
177 };
178
179 // Both numeric types. It is only dynamically legal for statically known
180 // numeric types to be the same, in which case the result type must be the
181 // same as well.
182 auto ptLeft = leftType.dyn_cast<PythonTypeInterface>();
183 auto ptRight = rightType.dyn_cast<PythonTypeInterface>();
184 if (ptLeft && ptRight && ptLeft.getNumericPromotionOrder() &&
185 ptRight.getNumericPromotionOrder()) {
186 if (leftType == rightType) {
187 return applyUpdates(leftType);
188 }
189 }
190
191 // (list, integer) or (integer, list) refine to the list type.
192 if (dunder_name() == "mul") {
193 auto leftList = leftType.dyn_cast<ListType>();
194 auto rightList = rightType.dyn_cast<ListType>();
195 auto leftInteger = leftType.dyn_cast<IntegerType>();
196 auto rightInteger = rightType.dyn_cast<IntegerType>();
197 if (leftList && rightInteger) {
198 return applyUpdates(leftList);
199 } else if (leftInteger && rightList) {
200 return applyUpdates(rightList);
201 }
202 }
203
204 return false;
205}
206
207//===----------------------------------------------------------------------===//
208// ApplyCompareOp
209//===----------------------------------------------------------------------===//
210
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700211void ApplyCompareOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
212 MLIRContext *context) {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700213 patterns.add<UnboxOperands>(getOperationName(), context);
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700214}
215
216//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700217// AsBoolOp
218//===----------------------------------------------------------------------===//
219
220namespace {
221struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> {
222 public:
223 using OpRewritePattern::OpRewritePattern;
224 LogicalResult matchAndRewrite(AsBoolOp op,
225 PatternRewriter &rewriter) const override {
226 if (op.value().getType().isa<BoolType>()) {
227 rewriter.replaceOp(op, op.value());
228 return success();
229 }
230 return failure();
231 }
232};
233
234struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> {
235 public:
236 using OpRewritePattern::OpRewritePattern;
237 LogicalResult matchAndRewrite(AsBoolOp op,
238 PatternRewriter &rewriter) const override {
239 auto loc = op.getLoc();
240 auto ptType = op.value().getType().dyn_cast<PythonTypeInterface>();
241 if (!ptType) return failure();
242 if (!ptType.getNumericPromotionOrder()) return failure();
243
244 auto boolType = rewriter.getType<BoolType>();
245 Value zeroValue =
246 getNumericZeroConstant(loc, op.value().getType(), rewriter);
247 Value trueValue = getBoolConstant(loc, true, rewriter);
248 Value falseValue = getBoolConstant(loc, false, rewriter);
249 Value cmpResult = rewriter.create<ApplyCompareOp>(
250 loc, boolType, rewriter.getStringAttr("eq"), op.value(), zeroValue);
251 rewriter.replaceOpWithNewOp<SelectOp>(op, boolType, cmpResult, falseValue,
252 trueValue);
253 return success();
254 }
255};
256
257} // namespace
258
259void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
260 MLIRContext *context) {
261 patterns.add<FoldAsBoolFromBool, FoldAsBoolFromNumeric>(context);
262}
263
264OpFoldResult AsBoolOp::fold(ArrayRef<Attribute> operands) {
265 Builder b(getContext());
266 // Fold NoneType to False.
267 if (value().getType().isa<NoneType>()) {
268 return b.getBoolAttr(false);
269 }
270
271 return {};
272}
273
274//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700275// AssignSubscriptOp
276//===----------------------------------------------------------------------===//
277
278void AssignSubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
279 MLIRContext *context) {
280 llvm::SmallSet<int, 4> unboxIndices;
281 unboxIndices.insert(0);
282 unboxIndices.insert(1);
283 patterns.add<UnboxOperands>(getOperationName(), context, unboxIndices);
284}
285
286//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700287// BoolToPredOp
288//===----------------------------------------------------------------------===//
289
290OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) {
291 if (!operands[0]) return {};
292 // Since both BoolType and I1 share the attribute form (an IntegerAttr of I1),
293 // we can just return it.
294 return operands[0];
295}
296
297//===----------------------------------------------------------------------===//
298// BoxOp and UnboxOp
299//===----------------------------------------------------------------------===//
300
301LogicalResult BoxOp::canonicalize(BoxOp op, PatternRewriter &rewriter) {
302 // Sometimes boxes are emitted when the input is an object. Just remove.
303 if (op.primitive().getType().isa<ObjectType>()) {
304 rewriter.replaceOp(op, op.primitive());
305 return success();
306 }
307
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700308 // Box to an appropriate type and static info cast.
309 ObjectType objectType = rewriter.getType<ObjectType>(nullptr);
310 if (op.object().getType() == objectType &&
311 !op.primitive().getType().isa<ObjectType>()) {
312 auto refinedBox = rewriter.create<BoxOp>(
313 op.getLoc(),
314 rewriter.getType<ObjectType>(
315 op.primitive().getType().cast<PrimitiveType>()),
316 op.primitive());
317 rewriter.replaceOpWithNewOp<StaticInfoCastOp>(op, op.object().getType(),
318 refinedBox);
319 return success();
320 }
321
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700322 return failure();
323}
324
325LogicalResult UnboxOp::canonicalize(UnboxOp unboxOp,
326 PatternRewriter &rewriter) {
327 auto loc = unboxOp.getLoc();
328
329 // Handle the case of an immediate BoxOp producer.
330 if (auto boxProducer =
331 dyn_cast_or_null<BoxOp>(unboxOp.object().getDefiningOp())) {
332 // If the producer is boxing to the same type we are unboxing, then
333 // just elide everything.
334 if (boxProducer.primitive().getType() == unboxOp.primitive().getType()) {
335 auto successValue = rewriter.create<SuccessOp>(
336 loc, rewriter.getType<ExceptionResultType>());
337 rewriter.replaceOp(unboxOp, {successValue, boxProducer.primitive()});
338 return success();
339 }
340 }
341 return failure();
342}
343
344//===----------------------------------------------------------------------===//
345// DynamicBinaryPromoteOp
346//===----------------------------------------------------------------------===//
347
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700348namespace {
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700349
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700350/// Resolves a DynamicBinaryPromote over numeric operands to either elide
351/// or insert specific PromoteNumeric ops.
352struct ResolveNumericDynamicBinaryPromote
353 : public OpRewritePattern<DynamicBinaryPromoteOp> {
354 public:
355 using OpRewritePattern::OpRewritePattern;
356 LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
357 PatternRewriter &rewriter) const override {
358 auto loc = op.getLoc();
359 auto leftType = op.left().getType();
360 auto rightType = op.right().getType();
361 auto leftResultType = op.getResultTypes()[0];
362 auto rightResultType = op.getResultTypes()[1];
363 auto leftPt = leftType.dyn_cast<PythonTypeInterface>();
364 auto rightPt = rightType.dyn_cast<PythonTypeInterface>();
365 if (!leftPt || !rightPt) return failure();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700366
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700367 Optional<int> leftOrder = leftPt.getNumericPromotionOrder();
368 Optional<int> rightOrder = rightPt.getNumericPromotionOrder();
369 Value newLeft = op.left();
370 Value newRight = op.right();
371
372 if (leftOrder && rightOrder) {
373 // Both numeric.
374 if (*leftOrder > *rightOrder) {
375 newRight = rewriter.create<PromoteNumericOp>(loc, leftType, newRight);
376 }
377 if (*rightOrder > *leftOrder) {
378 newLeft = rewriter.create<PromoteNumericOp>(loc, rightType, newLeft);
379 }
380 } else {
381 return failure();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700382 }
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700383
384 // Need to box back to the original type (which will always be a generic
385 // object).
386 newLeft = rewriter.create<BoxOp>(loc, leftResultType, newLeft);
387 newRight = rewriter.create<BoxOp>(loc, rightResultType, newRight);
388
389 rewriter.replaceOp(op, {newLeft, newRight});
390 return success();
391 }
392};
393
394/// If we statically determine one of the arguments to be a concrete, non
395/// numeric type, then the op has no meaning and is elided.
396struct ElideNonNumericDynamicBinaryPromote
397 : public OpRewritePattern<DynamicBinaryPromoteOp> {
398 public:
399 using OpRewritePattern::OpRewritePattern;
400 LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
401 PatternRewriter &rewriter) const override {
402 if ((!isConcreteNonNumericType(op.left().getType()) &&
403 !isConcreteNonNumericType(op.right().getType())))
404 return failure();
405
406 // Since DynamicBinaryPromote already returns object, and we only match
407 // non-object operands, box them back.
408 auto loc = op.getLoc();
409 auto leftResultType = op.getResultTypes()[0];
410 auto rightResultType = op.getResultTypes()[1];
411 Value newLeft = rewriter.create<BoxOp>(loc, leftResultType, op.left());
412 Value newRight = rewriter.create<BoxOp>(loc, rightResultType, op.right());
413 rewriter.replaceOp(op, {newLeft, newRight});
414 return success();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700415 }
416
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700417 static bool isConcreteNonNumericType(Type t) {
418 if (t.isa<ObjectType>()) return false;
419 auto pt = t.dyn_cast<PythonTypeInterface>();
420 if (!pt || pt.getNumericPromotionOrder()) return false;
421 return true;
422 }
423};
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700424
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700425} // namespace
426
427void DynamicBinaryPromoteOp::getCanonicalizationPatterns(
428 RewritePatternSet &patterns, MLIRContext *context) {
429 patterns.add<ResolveNumericDynamicBinaryPromote>(context);
430 patterns.add<UnboxOperands>(getOperationName(), context);
431 patterns.add<ElideNonNumericDynamicBinaryPromote>(context);
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700432}
433
434//===----------------------------------------------------------------------===//
435// FunctionalIfOp
436//===----------------------------------------------------------------------===//
437
438::llvm::StringRef FunctionalIfOp::getDefaultDialect() { return "iree_pydm"; }
439
MaheshRavishankarf488f172022-03-18 18:19:02 -0700440LogicalResult FunctionalIfOp::verify() {
441 if (getNumResults() != 0 && elseRegion().empty())
442 return emitOpError("must have an else block if defining values");
Han-Chung Wang94dcb462022-03-09 10:51:31 -0800443 return success();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700444}
445
Jacques Pienaarff38cb42022-02-12 08:42:27 -0800446ParseResult FunctionalIfOp::parse(OpAsmParser &parser, OperationState &result) {
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700447 // Create the regions for 'then'.
448 result.regions.reserve(2);
449 Region *thenRegion = result.addRegion();
450 Region *elseRegion = result.addRegion();
451
452 auto &builder = parser.getBuilder();
453 OpAsmParser::OperandType cond;
454 Type conditionType = builder.getType<PyBoolType>();
455 if (parser.parseOperand(cond) ||
456 parser.resolveOperand(cond, conditionType, result.operands))
457 return failure();
458 // Parse optional results type list.
459 if (parser.parseOptionalArrowTypeList(result.types)) return failure();
460 // Parse the 'then' region.
461 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
462 return failure();
463 // IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
464
465 // If we find an 'else' keyword then parse the 'else' region.
466 if (!parser.parseOptionalKeyword("else")) {
467 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
468 return failure();
469 // IfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
470 // result.location);
471 }
472
473 // Parse the optional attribute list.
474 if (parser.parseOptionalAttrDict(result.attributes)) return failure();
475 return success();
476}
477
Jacques Pienaarff38cb42022-02-12 08:42:27 -0800478void FunctionalIfOp::print(OpAsmPrinter &p) {
479 FunctionalIfOp op = *this;
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700480 bool printBlockTerminators = false;
481
482 p << " " << op.condition();
483 if (!op.results().empty()) {
484 p << " -> (" << op.getResultTypes() << ")";
485 // Print yield explicitly if the op defines values.
486 printBlockTerminators = true;
487 }
Stella Laurenzo44c41872022-01-19 19:39:08 -0800488 p << " ";
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700489 p.printRegion(op.thenRegion(),
490 /*printEntryBlockArgs=*/false,
491 /*printBlockTerminators=*/printBlockTerminators);
492
493 // Print the 'else' regions if it exists and has a block.
494 auto &elseRegion = op.elseRegion();
495 if (!elseRegion.empty()) {
Stella Laurenzo44c41872022-01-19 19:39:08 -0800496 p << " else ";
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700497 p.printRegion(elseRegion,
498 /*printEntryBlockArgs=*/false,
499 /*printBlockTerminators=*/printBlockTerminators);
500 }
501
502 p.printOptionalAttrDict(op->getAttrs());
503}
504
505/// Given the region at `index`, or the parent operation if `index` is None,
506/// return the successor regions. These are the regions that may be selected
507/// during the flow of control. `operands` is a set of optional attributes that
508/// correspond to a constant value for each operand, or null if that operand is
509/// not a constant.
510void FunctionalIfOp::getSuccessorRegions(
511 Optional<unsigned> index, ArrayRef<Attribute> operands,
512 SmallVectorImpl<RegionSuccessor> &regions) {
513 // The `then` and the `else` region branch back to the parent operation.
514 if (index.hasValue()) {
515 regions.push_back(RegionSuccessor(getResults()));
516 return;
517 }
518
519 // Don't consider the else region if it is empty.
520 Region *elseRegion = &this->elseRegion();
521 if (elseRegion->empty()) elseRegion = nullptr;
522
523 // Otherwise, the successor is dependent on the condition.
524 if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
525 bool condition = condAttr.getValue();
526 // Add the successor regions using the condition.
527 regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
528 } else {
529 // If the condition isn't constant, both regions may be executed.
530 regions.push_back(RegionSuccessor(&thenRegion()));
531 // If the else region does not exist, it is not a viable successor.
532 if (elseRegion) regions.push_back(RegionSuccessor(elseRegion));
533 }
534}
535
536//===----------------------------------------------------------------------===//
537// FuncOp
538//===----------------------------------------------------------------------===//
539
540::llvm::StringRef PyFuncOp::getDefaultDialect() { return "iree_pydm"; }
541
542LogicalResult PyFuncOp::verifyType() {
543 // TODO: Enforce arg/result invariants.
544 return success();
545}
546
Jacques Pienaarff38cb42022-02-12 08:42:27 -0800547ParseResult PyFuncOp::parse(OpAsmParser &parser, OperationState &result) {
Stella Laurenzo68e78102022-01-25 22:24:56 -0800548 auto buildFuncType =
549 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
550 function_interface_impl::VariadicFlag,
551 std::string &) { return builder.getFunctionType(argTypes, results); };
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700552
Stella Laurenzo68e78102022-01-25 22:24:56 -0800553 return function_interface_impl::parseFunctionOp(
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700554 parser, result, /*allowVariadic=*/false, buildFuncType);
555}
556
Jacques Pienaarff38cb42022-02-12 08:42:27 -0800557void PyFuncOp::print(OpAsmPrinter &p) {
558 FunctionType fnType = getType();
Stella Laurenzo68e78102022-01-25 22:24:56 -0800559 function_interface_impl::printFunctionOp(
Jacques Pienaarff38cb42022-02-12 08:42:27 -0800560 p, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700561}
562
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700563//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700564// MakeListOp
565//===----------------------------------------------------------------------===//
566
MaheshRavishankarf488f172022-03-18 18:19:02 -0700567LogicalResult MakeListOp::verify() {
568 auto listType = list().getType().cast<ListType>();
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700569 switch (listType.getStorageClass()) {
570 case CollectionStorageClass::Boxed:
MaheshRavishankarf488f172022-03-18 18:19:02 -0700571 for (auto element : elements()) {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700572 if (!element.getType().isa<ObjectType>()) {
MaheshRavishankarf488f172022-03-18 18:19:02 -0700573 return emitOpError() << "making a list with boxed storage class "
574 "must have object elements. Got: "
575 << element.getType();
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700576 }
577 }
578 break;
579 case CollectionStorageClass::Unboxed:
MaheshRavishankarf488f172022-03-18 18:19:02 -0700580 for (auto element : elements()) {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700581 if (element.getType().isa<ObjectType>()) {
MaheshRavishankarf488f172022-03-18 18:19:02 -0700582 return emitOpError() << "making a list with unboxed storage class "
583 "must not have object elements. Got: "
584 << element.getType();
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700585 }
586 }
587 break;
588 case CollectionStorageClass::Empty:
MaheshRavishankarf488f172022-03-18 18:19:02 -0700589 if (!elements().empty()) {
590 return emitOpError()
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700591 << "making a list with empty storage class must have zero "
592 "elements";
593 }
594 break;
595 }
596 return success();
597}
598
599//===----------------------------------------------------------------------===//
600// NegOp
601//===----------------------------------------------------------------------===//
602
603void NegOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604 MLIRContext *context) {
605 patterns.add<UnboxOperands>(getOperationName(), context);
606}
607
608bool NegOp::refineResultTypes() {
609 if (value().getType() != getResult().getType()) {
610 getResult().setType(value().getType());
611 return true;
612 }
613 return false;
614}
615
616//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700617// PatternMatchCallOp
618//===----------------------------------------------------------------------===//
619
620LogicalResult PatternMatchCallOp::verifySymbolUses(
621 SymbolTableCollection &symbolTable) {
622 auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
623 for (auto symbolAttr : symbols) {
624 auto symbol = symbolAttr.cast<FlatSymbolRefAttr>();
625 PyFuncOp fn =
626 symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol);
627 if (!fn)
628 return emitOpError() << "'" << symbol.getValue()
629 << "' does not reference a valid function";
630 }
631 return success();
632 };
633 auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match");
634 if (!genericsAttr)
635 return emitOpError(
636 "requires a 'generic_match' array of symbol reference attributes");
637 if (failed(verifySymbols(genericsAttr))) return failure();
638
639 auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
640 if (!specificsAttr)
641 return emitOpError(
642 "requires a 'specific_match' array of symbol reference attributes");
643 if (failed(verifySymbols(specificsAttr))) return failure();
644
645 return success();
646}
647
648//===----------------------------------------------------------------------===//
649// PromoteNumericOp
650//===----------------------------------------------------------------------===//
651
652OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) {
653 if (!operands[0]) return {};
654
655 Builder b(getContext());
656 Attribute fromAttr = operands[0];
657 return TypeSwitch<Type, OpFoldResult>(getResult().getType())
658 .Case([&](PyIntegerType toType) -> OpFoldResult {
659 return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
660 .Case([&](BoolAttr fromBool) -> OpFoldResult {
661 return b.getI64IntegerAttr(fromBool.getValue() ? 1 : 0);
662 })
663 .Default([](Attribute) -> OpFoldResult { return {}; });
664 })
665 .Case([&](PyRealType toType) -> OpFoldResult {
666 return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
667 .Case([&](BoolAttr fromBool) -> OpFoldResult {
668 return b.getF64FloatAttr(fromBool.getValue() ? 1.0 : 0.0);
669 })
670 .Case([&](IntegerAttr fromInteger) -> OpFoldResult {
671 APInt value = fromInteger.getValue();
672 return b.getF64FloatAttr(value.getSExtValue());
673 })
674 .Default([](Attribute) -> OpFoldResult { return {}; });
675 })
676 .Default([](Type) -> OpFoldResult { return {}; });
677}
678
679LogicalResult PromoteNumericOp::canonicalize(PromoteNumericOp op,
680 PatternRewriter &rewriter) {
681 if (op.input().getType() == op.getResult().getType()) {
682 rewriter.replaceOp(op, op.input());
683 return success();
684 }
685 return failure();
686}
687
688//===----------------------------------------------------------------------===//
689// RaiseOnFailureOp
690//===----------------------------------------------------------------------===//
691
Stella Laurenzo826f1db2021-11-19 12:54:23 -0800692LogicalResult PYDM::RaiseOnFailureOp::canonicalize(RaiseOnFailureOp op,
693 PatternRewriter &rewriter) {
694 if (op.exc_result().getDefiningOp<SuccessOp>()) {
695 op.getOperation()->erase();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700696 return success();
697 }
698 return failure();
699}
700
701//===----------------------------------------------------------------------===//
702// SelectOp
703//===----------------------------------------------------------------------===//
704
705OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
706 if (!operands[0]) return {};
707
708 BoolAttr boolAttr = operands[0].cast<BoolAttr>();
709 if (boolAttr.getValue())
710 return true_value();
711 else
712 return false_value();
713}
714
715//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700716// SequenceCloneOp
717//===----------------------------------------------------------------------===//
718
719bool SequenceCloneOp::refineResultTypes() {
720 if (sequence().getType() != getResult().getType()) {
721 getResult().setType(sequence().getType());
722 return true;
723 }
724 return false;
725}
726
727//===----------------------------------------------------------------------===//
728// SubscriptOp
729//===----------------------------------------------------------------------===//
730
731void SubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
732 MLIRContext *context) {
733 patterns.add<UnboxOperands>(getOperationName(), context);
734}
735
736//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700737// CallOp
738//===----------------------------------------------------------------------===//
739
740LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
741 // Check that the callee attribute was specified.
742 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
743 if (!fnAttr)
744 return emitOpError("requires a 'callee' symbol reference attribute");
745 PyFuncOp fn = symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, fnAttr);
746 if (!fn)
747 return emitOpError() << "'" << fnAttr.getValue()
748 << "' does not reference a valid function";
749
750 // Verify that the operand and result types match the callee.
751 auto fnType = fn.getType();
752 if (fnType.getNumInputs() != getNumOperands())
753 return emitOpError("incorrect number of operands for callee");
754
755 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
756 if (getOperand(i).getType() != fnType.getInput(i)) {
757 return emitOpError("operand type mismatch: expected operand type ")
758 << fnType.getInput(i) << ", but provided "
759 << getOperand(i).getType() << " for operand number " << i;
760 }
761 }
762
763 if (fnType.getNumResults() != getNumResults())
764 return emitOpError("incorrect number of results for callee");
765
766 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
767 if (getResult(i).getType() != fnType.getResult(i)) {
768 auto diag = emitOpError("result type mismatch at index ") << i;
769 diag.attachNote() << " op result types: " << getResultTypes();
770 diag.attachNote() << "function result types: " << fnType.getResults();
771 return diag;
772 }
773 }
774
775 return success();
776}
777
778FunctionType PyCallOp::getCalleeType() {
779 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
780}
781
782//===----------------------------------------------------------------------===//
783// DynamicCallOp
784//===----------------------------------------------------------------------===//
785
786LogicalResult DynamicCallOp::verifySymbolUses(
787 SymbolTableCollection &symbolTable) {
788 // Check that the callee attribute was specified.
789 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
790 if (!fnAttr)
791 return emitOpError("requires a 'callee' symbol reference attribute");
792 Operation *fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr);
793 if (!fn || !isa<PyFuncOp>(fn))
794 return emitOpError() << "'" << fnAttr.getValue()
795 << "' does not reference a valid function";
796 return success();
797}
798
799#define GET_OP_CLASSES
Stella Laurenzo02cfcd12021-11-14 13:20:53 -0800800#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"