blob: 3b3e9611b7a1de102b49ec44f07d829444cad64c [file] [log] [blame]
Stella Laurenzoa69124b2021-09-06 13:14:10 -07001// Copyright 2021 The IREE Authors
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
Stella Laurenzo02cfcd12021-11-14 13:20:53 -08007#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
Stella Laurenzoa69124b2021-09-06 13:14:10 -07008
Stella Laurenzo02cfcd12021-11-14 13:20:53 -08009#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070010#include "llvm/ADT/SmallSet.h"
Stella Laurenzoa69124b2021-09-06 13:14:10 -070011#include "llvm/ADT/TypeSwitch.h"
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070012#include "llvm/Support/Debug.h"
Stella Laurenzoa69124b2021-09-06 13:14:10 -070013#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/FunctionImplementation.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/TypeUtilities.h"
18
19using namespace mlir;
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070020namespace PYDM = mlir::iree_compiler::IREE::PYDM;
21using namespace PYDM;
Stella Laurenzoa69124b2021-09-06 13:14:10 -070022
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070023using llvm::dbgs;
24
25using PyBoolType = PYDM::BoolType;
26using PyConstantOp = PYDM::ConstantOp;
27using PyIntegerType = PYDM::IntegerType;
28using PyRealType = PYDM::RealType;
29using PyCallOp = PYDM::CallOp;
30using PyFuncOp = PYDM::FuncOp;
31
32static LogicalResult verify(Operation *) { return success(); }
Stella Laurenzoa69124b2021-09-06 13:14:10 -070033
34//===----------------------------------------------------------------------===//
35// Utilities
36//===----------------------------------------------------------------------===//
37
Stella Laurenzoec9d61f2021-11-06 16:14:31 -070038namespace {
39
40/// Generic pattern to unbox any operands that are a specific object
41/// type (i.e. object<integer>).
42struct UnboxOperands : public RewritePattern {
43 UnboxOperands(StringRef rootName, MLIRContext *context,
44 Optional<llvm::SmallSet<int, 4>> operandIndices = None)
45 : RewritePattern(rootName, 1, context), operandIndices(operandIndices) {}
46 LogicalResult matchAndRewrite(Operation *op,
47 PatternRewriter &rewriter) const override {
48 Location loc = op->getLoc();
49 bool changed = false;
50 SmallVector<Value> operands(op->getOperands());
51 auto excResultType = rewriter.getType<ExceptionResultType>();
52 for (int operandIndex = 0, e = operands.size(); operandIndex < e;
53 ++operandIndex) {
54 Value &operand = operands[operandIndex];
55 if (operandIndices && !operandIndices->contains(operandIndex)) continue;
56 if (auto objectType = operand.getType().dyn_cast<ObjectType>()) {
57 Type primitiveType = objectType.getPrimitiveType();
58 if (primitiveType) {
59 // Unbox.
60 auto unboxOp = rewriter.create<UnboxOp>(
61 loc, TypeRange{excResultType, primitiveType}, operand);
62 operand = unboxOp.primitive();
63 changed = true;
64 }
65 }
66 }
67
68 if (changed) {
69 rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
70 return success();
71 }
72
73 return failure();
74 }
75 Optional<llvm::SmallSet<int, 4>> operandIndices;
76};
77
78} // namespace
79
Stella Laurenzoa69124b2021-09-06 13:14:10 -070080static Value getNumericZeroConstant(Location loc, Type numericType,
81 OpBuilder &builder) {
82 return TypeSwitch<Type, Value>(numericType)
83 .Case([&](PyBoolType t) -> Value {
84 return builder.create<PyConstantOp>(loc, t, builder.getBoolAttr(false));
85 })
86 .Case([&](PyIntegerType t) -> Value {
87 return builder.create<PyConstantOp>(loc, t,
88 builder.getI64IntegerAttr(0));
89 })
90 .Case([&](PyRealType t) -> Value {
91 return builder.create<PyConstantOp>(loc, t,
92 builder.getF64FloatAttr(0.0));
93 });
94}
95
96static Value getBoolConstant(Location loc, bool pred, OpBuilder &builder) {
97 return builder.create<PyConstantOp>(loc, builder.getType<BoolType>(),
98 builder.getBoolAttr(pred));
99}
100
101//===----------------------------------------------------------------------===//
102// Constants
103//===----------------------------------------------------------------------===//
104
105OpFoldResult PyConstantOp::fold(ArrayRef<Attribute> operands) {
106 assert(operands.empty() && "constant has no operands");
107 return getValue();
108}
109
110OpFoldResult NoneOp::fold(ArrayRef<Attribute> operands) {
111 assert(operands.empty() && "constant has no operands");
112 return UnitAttr::get(getContext());
113}
114
115OpFoldResult SuccessOp::fold(ArrayRef<Attribute> operands) {
116 assert(operands.empty() && "constant has no operands");
117 return UnitAttr::get(getContext());
118}
119
120//===----------------------------------------------------------------------===//
121// Variables
122//===----------------------------------------------------------------------===//
123
124void AllocFreeVarOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
125 setNameFn(getResult(), name());
126}
127
128//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700129// ApplyBinaryOp
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700130//===----------------------------------------------------------------------===//
131
132namespace {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700133struct ApplyBinaryToSequenceClone : public OpRewritePattern<ApplyBinaryOp> {
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700134 using OpRewritePattern::OpRewritePattern;
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700135 LogicalResult matchAndRewrite(ApplyBinaryOp op,
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700136 PatternRewriter &rewriter) const override {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700137 if (op.dunder_name() != "mul") return failure();
138 Value listOperand;
139 Value countOperand;
140 if (isBuiltinSequence(op.left()) && isInteger(op.right())) {
141 listOperand = op.left();
142 countOperand = op.right();
143 } else if (isInteger(op.left()) && isBuiltinSequence(op.right())) {
144 countOperand = op.left();
145 listOperand = op.right();
146 } else {
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700147 return failure();
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700148 }
149 Type resultType = op.getResult().getType();
150 rewriter.replaceOpWithNewOp<SequenceCloneOp>(op, resultType, listOperand,
151 countOperand);
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700152 return success();
153 }
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700154
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700155 static bool isBuiltinSequence(Value operand) {
156 return operand.getType().isa<PYDM::ListType, PYDM::TupleType>();
157 }
158 static bool isInteger(Value operand) {
159 return operand.getType().isa<PYDM::IntegerType>();
160 }
161};
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700162} // namespace
163
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700164void ApplyBinaryOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
165 MLIRContext *context) {
166 patterns.add<UnboxOperands>(getOperationName(), context);
167 patterns.add<ApplyBinaryToSequenceClone>(context);
168}
169
170bool ApplyBinaryOp::refineResultTypes() {
171 auto leftType = left().getType();
172 auto rightType = right().getType();
173 auto applyUpdates = [&](Type newResultType) -> bool {
174 if (newResultType != getResult().getType()) {
175 getResult().setType(newResultType);
176 return true;
177 }
178 return false;
179 };
180
181 // Both numeric types. It is only dynamically legal for statically known
182 // numeric types to be the same, in which case the result type must be the
183 // same as well.
184 auto ptLeft = leftType.dyn_cast<PythonTypeInterface>();
185 auto ptRight = rightType.dyn_cast<PythonTypeInterface>();
186 if (ptLeft && ptRight && ptLeft.getNumericPromotionOrder() &&
187 ptRight.getNumericPromotionOrder()) {
188 if (leftType == rightType) {
189 return applyUpdates(leftType);
190 }
191 }
192
193 // (list, integer) or (integer, list) refine to the list type.
194 if (dunder_name() == "mul") {
195 auto leftList = leftType.dyn_cast<ListType>();
196 auto rightList = rightType.dyn_cast<ListType>();
197 auto leftInteger = leftType.dyn_cast<IntegerType>();
198 auto rightInteger = rightType.dyn_cast<IntegerType>();
199 if (leftList && rightInteger) {
200 return applyUpdates(leftList);
201 } else if (leftInteger && rightList) {
202 return applyUpdates(rightList);
203 }
204 }
205
206 return false;
207}
208
209//===----------------------------------------------------------------------===//
210// ApplyCompareOp
211//===----------------------------------------------------------------------===//
212
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700213void ApplyCompareOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
214 MLIRContext *context) {
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700215 patterns.add<UnboxOperands>(getOperationName(), context);
Stella Laurenzo813d3ae2021-10-06 15:39:03 -0700216}
217
218//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700219// AsBoolOp
220//===----------------------------------------------------------------------===//
221
222namespace {
223struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> {
224 public:
225 using OpRewritePattern::OpRewritePattern;
226 LogicalResult matchAndRewrite(AsBoolOp op,
227 PatternRewriter &rewriter) const override {
228 if (op.value().getType().isa<BoolType>()) {
229 rewriter.replaceOp(op, op.value());
230 return success();
231 }
232 return failure();
233 }
234};
235
236struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> {
237 public:
238 using OpRewritePattern::OpRewritePattern;
239 LogicalResult matchAndRewrite(AsBoolOp op,
240 PatternRewriter &rewriter) const override {
241 auto loc = op.getLoc();
242 auto ptType = op.value().getType().dyn_cast<PythonTypeInterface>();
243 if (!ptType) return failure();
244 if (!ptType.getNumericPromotionOrder()) return failure();
245
246 auto boolType = rewriter.getType<BoolType>();
247 Value zeroValue =
248 getNumericZeroConstant(loc, op.value().getType(), rewriter);
249 Value trueValue = getBoolConstant(loc, true, rewriter);
250 Value falseValue = getBoolConstant(loc, false, rewriter);
251 Value cmpResult = rewriter.create<ApplyCompareOp>(
252 loc, boolType, rewriter.getStringAttr("eq"), op.value(), zeroValue);
253 rewriter.replaceOpWithNewOp<SelectOp>(op, boolType, cmpResult, falseValue,
254 trueValue);
255 return success();
256 }
257};
258
259} // namespace
260
261void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
262 MLIRContext *context) {
263 patterns.add<FoldAsBoolFromBool, FoldAsBoolFromNumeric>(context);
264}
265
266OpFoldResult AsBoolOp::fold(ArrayRef<Attribute> operands) {
267 Builder b(getContext());
268 // Fold NoneType to False.
269 if (value().getType().isa<NoneType>()) {
270 return b.getBoolAttr(false);
271 }
272
273 return {};
274}
275
276//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700277// AssignSubscriptOp
278//===----------------------------------------------------------------------===//
279
280void AssignSubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
281 MLIRContext *context) {
282 llvm::SmallSet<int, 4> unboxIndices;
283 unboxIndices.insert(0);
284 unboxIndices.insert(1);
285 patterns.add<UnboxOperands>(getOperationName(), context, unboxIndices);
286}
287
288//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700289// BoolToPredOp
290//===----------------------------------------------------------------------===//
291
292OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) {
293 if (!operands[0]) return {};
294 // Since both BoolType and I1 share the attribute form (an IntegerAttr of I1),
295 // we can just return it.
296 return operands[0];
297}
298
299//===----------------------------------------------------------------------===//
300// BoxOp and UnboxOp
301//===----------------------------------------------------------------------===//
302
303LogicalResult BoxOp::canonicalize(BoxOp op, PatternRewriter &rewriter) {
304 // Sometimes boxes are emitted when the input is an object. Just remove.
305 if (op.primitive().getType().isa<ObjectType>()) {
306 rewriter.replaceOp(op, op.primitive());
307 return success();
308 }
309
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700310 // Box to an appropriate type and static info cast.
311 ObjectType objectType = rewriter.getType<ObjectType>(nullptr);
312 if (op.object().getType() == objectType &&
313 !op.primitive().getType().isa<ObjectType>()) {
314 auto refinedBox = rewriter.create<BoxOp>(
315 op.getLoc(),
316 rewriter.getType<ObjectType>(
317 op.primitive().getType().cast<PrimitiveType>()),
318 op.primitive());
319 rewriter.replaceOpWithNewOp<StaticInfoCastOp>(op, op.object().getType(),
320 refinedBox);
321 return success();
322 }
323
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700324 return failure();
325}
326
327LogicalResult UnboxOp::canonicalize(UnboxOp unboxOp,
328 PatternRewriter &rewriter) {
329 auto loc = unboxOp.getLoc();
330
331 // Handle the case of an immediate BoxOp producer.
332 if (auto boxProducer =
333 dyn_cast_or_null<BoxOp>(unboxOp.object().getDefiningOp())) {
334 // If the producer is boxing to the same type we are unboxing, then
335 // just elide everything.
336 if (boxProducer.primitive().getType() == unboxOp.primitive().getType()) {
337 auto successValue = rewriter.create<SuccessOp>(
338 loc, rewriter.getType<ExceptionResultType>());
339 rewriter.replaceOp(unboxOp, {successValue, boxProducer.primitive()});
340 return success();
341 }
342 }
343 return failure();
344}
345
346//===----------------------------------------------------------------------===//
347// DynamicBinaryPromoteOp
348//===----------------------------------------------------------------------===//
349
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700350namespace {
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700351
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700352/// Resolves a DynamicBinaryPromote over numeric operands to either elide
353/// or insert specific PromoteNumeric ops.
354struct ResolveNumericDynamicBinaryPromote
355 : public OpRewritePattern<DynamicBinaryPromoteOp> {
356 public:
357 using OpRewritePattern::OpRewritePattern;
358 LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
359 PatternRewriter &rewriter) const override {
360 auto loc = op.getLoc();
361 auto leftType = op.left().getType();
362 auto rightType = op.right().getType();
363 auto leftResultType = op.getResultTypes()[0];
364 auto rightResultType = op.getResultTypes()[1];
365 auto leftPt = leftType.dyn_cast<PythonTypeInterface>();
366 auto rightPt = rightType.dyn_cast<PythonTypeInterface>();
367 if (!leftPt || !rightPt) return failure();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700368
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700369 Optional<int> leftOrder = leftPt.getNumericPromotionOrder();
370 Optional<int> rightOrder = rightPt.getNumericPromotionOrder();
371 Value newLeft = op.left();
372 Value newRight = op.right();
373
374 if (leftOrder && rightOrder) {
375 // Both numeric.
376 if (*leftOrder > *rightOrder) {
377 newRight = rewriter.create<PromoteNumericOp>(loc, leftType, newRight);
378 }
379 if (*rightOrder > *leftOrder) {
380 newLeft = rewriter.create<PromoteNumericOp>(loc, rightType, newLeft);
381 }
382 } else {
383 return failure();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700384 }
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700385
386 // Need to box back to the original type (which will always be a generic
387 // object).
388 newLeft = rewriter.create<BoxOp>(loc, leftResultType, newLeft);
389 newRight = rewriter.create<BoxOp>(loc, rightResultType, newRight);
390
391 rewriter.replaceOp(op, {newLeft, newRight});
392 return success();
393 }
394};
395
396/// If we statically determine one of the arguments to be a concrete, non
397/// numeric type, then the op has no meaning and is elided.
398struct ElideNonNumericDynamicBinaryPromote
399 : public OpRewritePattern<DynamicBinaryPromoteOp> {
400 public:
401 using OpRewritePattern::OpRewritePattern;
402 LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
403 PatternRewriter &rewriter) const override {
404 if ((!isConcreteNonNumericType(op.left().getType()) &&
405 !isConcreteNonNumericType(op.right().getType())))
406 return failure();
407
408 // Since DynamicBinaryPromote already returns object, and we only match
409 // non-object operands, box them back.
410 auto loc = op.getLoc();
411 auto leftResultType = op.getResultTypes()[0];
412 auto rightResultType = op.getResultTypes()[1];
413 Value newLeft = rewriter.create<BoxOp>(loc, leftResultType, op.left());
414 Value newRight = rewriter.create<BoxOp>(loc, rightResultType, op.right());
415 rewriter.replaceOp(op, {newLeft, newRight});
416 return success();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700417 }
418
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700419 static bool isConcreteNonNumericType(Type t) {
420 if (t.isa<ObjectType>()) return false;
421 auto pt = t.dyn_cast<PythonTypeInterface>();
422 if (!pt || pt.getNumericPromotionOrder()) return false;
423 return true;
424 }
425};
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700426
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700427} // namespace
428
429void DynamicBinaryPromoteOp::getCanonicalizationPatterns(
430 RewritePatternSet &patterns, MLIRContext *context) {
431 patterns.add<ResolveNumericDynamicBinaryPromote>(context);
432 patterns.add<UnboxOperands>(getOperationName(), context);
433 patterns.add<ElideNonNumericDynamicBinaryPromote>(context);
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700434}
435
436//===----------------------------------------------------------------------===//
437// FunctionalIfOp
438//===----------------------------------------------------------------------===//
439
440::llvm::StringRef FunctionalIfOp::getDefaultDialect() { return "iree_pydm"; }
441
442static LogicalResult verify(FunctionalIfOp op) {
443 if (op.getNumResults() != 0 && op.elseRegion().empty())
444 return op.emitOpError("must have an else block if defining values");
445
446 return RegionBranchOpInterface::verifyTypes(op);
447}
448
449static ParseResult parseFunctionalIfOp(OpAsmParser &parser,
450 OperationState &result) {
451 // Create the regions for 'then'.
452 result.regions.reserve(2);
453 Region *thenRegion = result.addRegion();
454 Region *elseRegion = result.addRegion();
455
456 auto &builder = parser.getBuilder();
457 OpAsmParser::OperandType cond;
458 Type conditionType = builder.getType<PyBoolType>();
459 if (parser.parseOperand(cond) ||
460 parser.resolveOperand(cond, conditionType, result.operands))
461 return failure();
462 // Parse optional results type list.
463 if (parser.parseOptionalArrowTypeList(result.types)) return failure();
464 // Parse the 'then' region.
465 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
466 return failure();
467 // IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
468
469 // If we find an 'else' keyword then parse the 'else' region.
470 if (!parser.parseOptionalKeyword("else")) {
471 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
472 return failure();
473 // IfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
474 // result.location);
475 }
476
477 // Parse the optional attribute list.
478 if (parser.parseOptionalAttrDict(result.attributes)) return failure();
479 return success();
480}
481
482static void print(OpAsmPrinter &p, FunctionalIfOp op) {
483 bool printBlockTerminators = false;
484
485 p << " " << op.condition();
486 if (!op.results().empty()) {
487 p << " -> (" << op.getResultTypes() << ")";
488 // Print yield explicitly if the op defines values.
489 printBlockTerminators = true;
490 }
491 p.printRegion(op.thenRegion(),
492 /*printEntryBlockArgs=*/false,
493 /*printBlockTerminators=*/printBlockTerminators);
494
495 // Print the 'else' regions if it exists and has a block.
496 auto &elseRegion = op.elseRegion();
497 if (!elseRegion.empty()) {
498 p << " else";
499 p.printRegion(elseRegion,
500 /*printEntryBlockArgs=*/false,
501 /*printBlockTerminators=*/printBlockTerminators);
502 }
503
504 p.printOptionalAttrDict(op->getAttrs());
505}
506
507/// Given the region at `index`, or the parent operation if `index` is None,
508/// return the successor regions. These are the regions that may be selected
509/// during the flow of control. `operands` is a set of optional attributes that
510/// correspond to a constant value for each operand, or null if that operand is
511/// not a constant.
512void FunctionalIfOp::getSuccessorRegions(
513 Optional<unsigned> index, ArrayRef<Attribute> operands,
514 SmallVectorImpl<RegionSuccessor> &regions) {
515 // The `then` and the `else` region branch back to the parent operation.
516 if (index.hasValue()) {
517 regions.push_back(RegionSuccessor(getResults()));
518 return;
519 }
520
521 // Don't consider the else region if it is empty.
522 Region *elseRegion = &this->elseRegion();
523 if (elseRegion->empty()) elseRegion = nullptr;
524
525 // Otherwise, the successor is dependent on the condition.
526 if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
527 bool condition = condAttr.getValue();
528 // Add the successor regions using the condition.
529 regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
530 } else {
531 // If the condition isn't constant, both regions may be executed.
532 regions.push_back(RegionSuccessor(&thenRegion()));
533 // If the else region does not exist, it is not a viable successor.
534 if (elseRegion) regions.push_back(RegionSuccessor(elseRegion));
535 }
536}
537
538//===----------------------------------------------------------------------===//
539// FuncOp
540//===----------------------------------------------------------------------===//
541
542::llvm::StringRef PyFuncOp::getDefaultDialect() { return "iree_pydm"; }
543
544LogicalResult PyFuncOp::verifyType() {
545 // TODO: Enforce arg/result invariants.
546 return success();
547}
548
549static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
550 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
551 ArrayRef<Type> results,
552 function_like_impl::VariadicFlag, std::string &) {
553 return builder.getFunctionType(argTypes, results);
554 };
555
556 return function_like_impl::parseFunctionLikeOp(
557 parser, result, /*allowVariadic=*/false, buildFuncType);
558}
559
560static void print(PyFuncOp op, OpAsmPrinter &p) {
561 FunctionType fnType = op.getType();
562 function_like_impl::printFunctionLikeOp(
563 p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
564}
565
566static LogicalResult verify(PyFuncOp op) {
567 // TODO: Enforce invariants.
568 return success();
569}
570
571//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700572// MakeListOp
573//===----------------------------------------------------------------------===//
574
575static LogicalResult verify(MakeListOp op) {
576 auto listType = op.list().getType().cast<ListType>();
577 switch (listType.getStorageClass()) {
578 case CollectionStorageClass::Boxed:
579 for (auto element : op.elements()) {
580 if (!element.getType().isa<ObjectType>()) {
581 return op.emitOpError() << "making a list with boxed storage class "
582 "must have object elements. Got: "
583 << element.getType();
584 }
585 }
586 break;
587 case CollectionStorageClass::Unboxed:
588 for (auto element : op.elements()) {
589 if (element.getType().isa<ObjectType>()) {
590 return op.emitOpError() << "making a list with unboxed storage class "
591 "must not have object elements. Got: "
592 << element.getType();
593 }
594 }
595 break;
596 case CollectionStorageClass::Empty:
597 if (!op.elements().empty()) {
598 return op.emitOpError()
599 << "making a list with empty storage class must have zero "
600 "elements";
601 }
602 break;
603 }
604 return success();
605}
606
607//===----------------------------------------------------------------------===//
608// NegOp
609//===----------------------------------------------------------------------===//
610
611void NegOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
612 MLIRContext *context) {
613 patterns.add<UnboxOperands>(getOperationName(), context);
614}
615
616bool NegOp::refineResultTypes() {
617 if (value().getType() != getResult().getType()) {
618 getResult().setType(value().getType());
619 return true;
620 }
621 return false;
622}
623
624//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700625// PatternMatchCallOp
626//===----------------------------------------------------------------------===//
627
628LogicalResult PatternMatchCallOp::verifySymbolUses(
629 SymbolTableCollection &symbolTable) {
630 auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
631 for (auto symbolAttr : symbols) {
632 auto symbol = symbolAttr.cast<FlatSymbolRefAttr>();
633 PyFuncOp fn =
634 symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol);
635 if (!fn)
636 return emitOpError() << "'" << symbol.getValue()
637 << "' does not reference a valid function";
638 }
639 return success();
640 };
641 auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match");
642 if (!genericsAttr)
643 return emitOpError(
644 "requires a 'generic_match' array of symbol reference attributes");
645 if (failed(verifySymbols(genericsAttr))) return failure();
646
647 auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
648 if (!specificsAttr)
649 return emitOpError(
650 "requires a 'specific_match' array of symbol reference attributes");
651 if (failed(verifySymbols(specificsAttr))) return failure();
652
653 return success();
654}
655
656//===----------------------------------------------------------------------===//
657// PromoteNumericOp
658//===----------------------------------------------------------------------===//
659
660OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) {
661 if (!operands[0]) return {};
662
663 Builder b(getContext());
664 Attribute fromAttr = operands[0];
665 return TypeSwitch<Type, OpFoldResult>(getResult().getType())
666 .Case([&](PyIntegerType toType) -> OpFoldResult {
667 return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
668 .Case([&](BoolAttr fromBool) -> OpFoldResult {
669 return b.getI64IntegerAttr(fromBool.getValue() ? 1 : 0);
670 })
671 .Default([](Attribute) -> OpFoldResult { return {}; });
672 })
673 .Case([&](PyRealType toType) -> OpFoldResult {
674 return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
675 .Case([&](BoolAttr fromBool) -> OpFoldResult {
676 return b.getF64FloatAttr(fromBool.getValue() ? 1.0 : 0.0);
677 })
678 .Case([&](IntegerAttr fromInteger) -> OpFoldResult {
679 APInt value = fromInteger.getValue();
680 return b.getF64FloatAttr(value.getSExtValue());
681 })
682 .Default([](Attribute) -> OpFoldResult { return {}; });
683 })
684 .Default([](Type) -> OpFoldResult { return {}; });
685}
686
687LogicalResult PromoteNumericOp::canonicalize(PromoteNumericOp op,
688 PatternRewriter &rewriter) {
689 if (op.input().getType() == op.getResult().getType()) {
690 rewriter.replaceOp(op, op.input());
691 return success();
692 }
693 return failure();
694}
695
696//===----------------------------------------------------------------------===//
697// RaiseOnFailureOp
698//===----------------------------------------------------------------------===//
699
Stella Laurenzo826f1db2021-11-19 12:54:23 -0800700LogicalResult PYDM::RaiseOnFailureOp::canonicalize(RaiseOnFailureOp op,
701 PatternRewriter &rewriter) {
702 if (op.exc_result().getDefiningOp<SuccessOp>()) {
703 op.getOperation()->erase();
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700704 return success();
705 }
706 return failure();
707}
708
709//===----------------------------------------------------------------------===//
710// SelectOp
711//===----------------------------------------------------------------------===//
712
713OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
714 if (!operands[0]) return {};
715
716 BoolAttr boolAttr = operands[0].cast<BoolAttr>();
717 if (boolAttr.getValue())
718 return true_value();
719 else
720 return false_value();
721}
722
723//===----------------------------------------------------------------------===//
Stella Laurenzoec9d61f2021-11-06 16:14:31 -0700724// SequenceCloneOp
725//===----------------------------------------------------------------------===//
726
727bool SequenceCloneOp::refineResultTypes() {
728 if (sequence().getType() != getResult().getType()) {
729 getResult().setType(sequence().getType());
730 return true;
731 }
732 return false;
733}
734
735//===----------------------------------------------------------------------===//
736// SubscriptOp
737//===----------------------------------------------------------------------===//
738
739void SubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
740 MLIRContext *context) {
741 patterns.add<UnboxOperands>(getOperationName(), context);
742}
743
744//===----------------------------------------------------------------------===//
Stella Laurenzoa69124b2021-09-06 13:14:10 -0700745// CallOp
746//===----------------------------------------------------------------------===//
747
748LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
749 // Check that the callee attribute was specified.
750 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
751 if (!fnAttr)
752 return emitOpError("requires a 'callee' symbol reference attribute");
753 PyFuncOp fn = symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, fnAttr);
754 if (!fn)
755 return emitOpError() << "'" << fnAttr.getValue()
756 << "' does not reference a valid function";
757
758 // Verify that the operand and result types match the callee.
759 auto fnType = fn.getType();
760 if (fnType.getNumInputs() != getNumOperands())
761 return emitOpError("incorrect number of operands for callee");
762
763 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
764 if (getOperand(i).getType() != fnType.getInput(i)) {
765 return emitOpError("operand type mismatch: expected operand type ")
766 << fnType.getInput(i) << ", but provided "
767 << getOperand(i).getType() << " for operand number " << i;
768 }
769 }
770
771 if (fnType.getNumResults() != getNumResults())
772 return emitOpError("incorrect number of results for callee");
773
774 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
775 if (getResult(i).getType() != fnType.getResult(i)) {
776 auto diag = emitOpError("result type mismatch at index ") << i;
777 diag.attachNote() << " op result types: " << getResultTypes();
778 diag.attachNote() << "function result types: " << fnType.getResults();
779 return diag;
780 }
781 }
782
783 return success();
784}
785
786FunctionType PyCallOp::getCalleeType() {
787 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
788}
789
790//===----------------------------------------------------------------------===//
791// DynamicCallOp
792//===----------------------------------------------------------------------===//
793
794LogicalResult DynamicCallOp::verifySymbolUses(
795 SymbolTableCollection &symbolTable) {
796 // Check that the callee attribute was specified.
797 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
798 if (!fnAttr)
799 return emitOpError("requires a 'callee' symbol reference attribute");
800 Operation *fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr);
801 if (!fn || !isa<PyFuncOp>(fn))
802 return emitOpError() << "'" << fnAttr.getValue()
803 << "' does not reference a valid function";
804 return success();
805}
806
807#define GET_OP_CLASSES
Stella Laurenzo02cfcd12021-11-14 13:20:53 -0800808#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"