Merge pull request #3467 from google/benvanik-vm-fold-arithmetic
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 819373f..0f97aae 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -202,10 +202,9 @@
namespace {
+// Rewrite pattern rooted on INDIRECT global loads; declared as a struct so the
+// inherited constructor and matchAndRewrite are public by default.
+// NOTE(review): DIRECT is presumably the op emitted in place of INDIRECT once
+// the address op is known -- confirm against the full matchAndRewrite body.
template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
using OpRewritePattern<INDIRECT>::OpRewritePattern;
- public:
LogicalResult matchAndRewrite(INDIRECT op,
PatternRewriter &rewriter) const override {
if (auto addressOp =
@@ -244,10 +243,9 @@
namespace {
+// Rewrite pattern rooted on INDIRECT global stores; declared as a struct so the
+// inherited constructor and matchAndRewrite are public by default.
+// NOTE(review): DIRECT is presumably the op emitted in place of INDIRECT once
+// the address op is known -- confirm against the full matchAndRewrite body.
template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
using OpRewritePattern<INDIRECT>::OpRewritePattern;
- public:
LogicalResult matchAndRewrite(INDIRECT op,
PatternRewriter &rewriter) const override {
if (auto addressOp =
@@ -382,15 +380,13 @@
// Native integer arithmetic
//===----------------------------------------------------------------------===//
-namespace {
-
+// The free-function helpers below are file-local via `static` linkage rather
+// than an anonymous namespace.
/// Performs const folding `calculate` with element-wise behavior on the given
/// attribute in `operands` and returns the result if possible.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
-                           const CalculationT &calculate) {
+static Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                                  const CalculationT &calculate) {
assert(operands.size() == 1 && "unary op takes one operand");
if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
+    // The folded attribute keeps the operand's type; use constFoldConversionOp
+    // when the result type must differ from the operand type.
return AttrElementT::get(operand.getType(), calculate(operand.getValue()));
@@ -414,8 +410,8 @@
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT =
std::function<ElementValueT(ElementValueT, ElementValueT)>>
-Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
-                            const CalculationT &calculate) {
+static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
+                                   const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
@@ -448,42 +444,66 @@
return {};
}
-}  // namespace
-
-template <typename T>
-static OpFoldResult foldAddOp(T op, ArrayRef<Attribute> operands) {
+/// Folds an integer ADD op; SUB is the subtraction op of the same width.
+///
+/// Besides `x + 0` and full constant folding, cancels a SUB feeding either
+/// operand: `(x - y) + y` and `y + (x - y)` both fold to `x`. The mirrored
+/// forms `(y - x) + y` and `y + (y - x)` equal `2y - x`, not `x`, so they
+/// must not be matched.
+template <typename ADD, typename SUB>
+static OpFoldResult foldAddOp(ADD op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
// x + 0 = x or 0 + y = y (commutative)
return op.lhs();
}
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a + b; });
+  if (auto subOp = dyn_cast_or_null<SUB>(op.lhs().getDefiningOp())) {
+    // (x - y) + y = x
+    if (subOp.rhs() == op.rhs()) return subOp.lhs();
+  }
+  if (auto subOp = dyn_cast_or_null<SUB>(op.rhs().getDefiningOp())) {
+    // y + (x - y) = x
+    if (subOp.rhs() == op.lhs()) return subOp.lhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a + b; });
}
OpFoldResult AddI32Op::fold(ArrayRef<Attribute> operands) {
-  return foldAddOp(*this, operands);
+  return foldAddOp<AddI32Op, SubI32Op>(*this, operands);
}
OpFoldResult AddI64Op::fold(ArrayRef<Attribute> operands) {
-  return foldAddOp(*this, operands);
+  return foldAddOp<AddI64Op, SubI64Op>(*this, operands);
}
-template <typename T>
-static OpFoldResult foldSubOp(T op, ArrayRef<Attribute> operands) {
+/// Folds an integer SUB op; ADD is the addition op of the same width.
+///
+/// Besides `x - 0` and full constant folding, cancels an ADD on the lhs:
+/// `(y + x) - y` and `(x + y) - y` both fold to `x`. An ADD on the rhs is
+/// deliberately not folded: `y - (y + x)` is `-x`, and there is no existing
+/// value to return for it.
+template <typename SUB, typename ADD>
+static OpFoldResult foldSubOp(SUB op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
// x - 0 = x
return op.lhs();
}
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a - b; });
+  if (auto addOp = dyn_cast_or_null<ADD>(op.lhs().getDefiningOp())) {
+    // (y + x) - y = x
+    if (addOp.lhs() == op.rhs()) return addOp.rhs();
+    // (x + y) - y = x
+    if (addOp.rhs() == op.rhs()) return addOp.lhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a - b; });
}
OpFoldResult SubI32Op::fold(ArrayRef<Attribute> operands) {
-  return foldSubOp(*this, operands);
+  return foldSubOp<SubI32Op, AddI32Op>(*this, operands);
}
OpFoldResult SubI64Op::fold(ArrayRef<Attribute> operands) {
-  return foldSubOp(*this, operands);
+  return foldSubOp<SubI64Op, AddI64Op>(*this, operands);
}
template <typename T>
@@ -495,18 +503,64 @@
// x * 1 = x or 1 * y = y (commutative)
return op.lhs();
}
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a * b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a * b; });
}
+namespace {
+
+/// Canonicalizes `(x * c2) * c1` into `x * (c2 * c1)`.
+///
+/// Matches a T whose rhs is a constant and whose lhs is produced by another T
+/// with a constant rhs. The two constants are multiplied at compile time,
+/// materialized with CONST_OP, and `op` is replaced by a single T multiplying
+/// the inner op's lhs by that product. Only rhs operands are inspected;
+/// NOTE(review): this presumably relies on the Commutative trait having
+/// already moved constant operands to the rhs -- confirm.
+template <typename T, typename CONST_OP>
+struct FoldConstantMulOperand : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    IntegerAttr c1, c2;
+    if (!matchPattern(op.rhs(), m_Constant(&c1))) return failure();
+    if (auto mulOp = dyn_cast_or_null<T>(op.lhs().getDefiningOp())) {
+      if (matchPattern(mulOp.rhs(), m_Constant(&c2))) {
+        // Fuse both locations so the new constant is attributed to both muls.
+        auto c = rewriter.createOrFold<CONST_OP>(
+            FusedLoc::get({mulOp.getLoc(), op.getLoc()}, rewriter.getContext()),
+            constFoldBinaryOp<IntegerAttr>(
+                {c1, c2},
+                [](const APInt &a, const APInt &b) { return a * b; }));
+        rewriter.replaceOpWithNewOp<T>(op, op.getType(), mulOp.lhs(), c);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+}  // namespace
+
OpFoldResult MulI32Op::fold(ArrayRef<Attribute> operands) {
return foldMulOp(*this, operands);
}
+void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<FoldConstantMulOperand<MulI32Op, ConstI32Op>>(context);
+}
+
OpFoldResult MulI64Op::fold(ArrayRef<Attribute> operands) {
return foldMulOp(*this, operands);
}
+void MulI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<FoldConstantMulOperand<MulI64Op, ConstI64Op>>(context);
+}
+
template <typename T>
static OpFoldResult foldDivSOp(T op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
@@ -521,7 +562,7 @@
return op.lhs();
}
+  // Constant-fold with APInt::sdiv (signed division).
return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.sdiv(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.sdiv(b); });
}
OpFoldResult DivI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -546,7 +587,7 @@
return op.lhs();
}
+  // Constant-fold with APInt::udiv (unsigned division).
return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.udiv(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.udiv(b); });
}
OpFoldResult DivI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -570,7 +611,7 @@
return zeroOfType(op.getType());
}
+  // Constant-fold with APInt::srem (signed remainder).
return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.srem(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.srem(b); });
}
OpFoldResult RemI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -589,7 +630,7 @@
return zeroOfType(op.getType());
}
+  // Constant-fold with APInt::urem (unsigned remainder).
return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.urem(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.urem(b); });
}
OpFoldResult RemI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -625,8 +666,8 @@
// x & x = x
return op.lhs();
}
+  // Constant-fold with a bitwise AND.
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a & b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a & b; });
}
OpFoldResult AndI32Op::fold(ArrayRef<Attribute> operands) {
@@ -646,8 +687,8 @@
// x | x = x
return op.lhs();
}
+  // Constant-fold with a bitwise OR.
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a | b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a | b; });
}
OpFoldResult OrI32Op::fold(ArrayRef<Attribute> operands) {
@@ -667,8 +708,8 @@
// x ^ x = 0
return zeroOfType(op.getType());
}
+  // Constant-fold with a bitwise XOR.
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a ^ b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a ^ b; });
}
OpFoldResult XorI32Op::fold(ArrayRef<Attribute> operands) {
@@ -693,7 +734,7 @@
return op.operand();
}
+  // Constant-fold with a left shift by `op.amount()`.
return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.shl(op.amount()); });
+      operands, [&](const APInt &a) { return a.shl(op.amount()); });
}
OpFoldResult ShlI32Op::fold(ArrayRef<Attribute> operands) {
@@ -714,7 +755,7 @@
return op.operand();
}
+  // Constant-fold with an arithmetic (sign-filling) right shift.
return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.ashr(op.amount()); });
+      operands, [&](const APInt &a) { return a.ashr(op.amount()); });
}
OpFoldResult ShrI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -735,7 +776,7 @@
return op.operand();
}
+  // Constant-fold with a logical (zero-filling) right shift.
return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.lshr(op.amount()); });
+      operands, [&](const APInt &a) { return a.lshr(op.amount()); });
}
OpFoldResult ShrI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -755,8 +796,9 @@
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldConversionOp(Type resultType, ArrayRef<Attribute> operands,
-                                const CalculationT &calculate) {
+static Attribute constFoldConversionOp(Type resultType,
+                                       ArrayRef<Attribute> operands,
+                                       const CalculationT &calculate) {
assert(operands.size() == 1 && "unary op takes one operand");
if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
+    // Unlike constFoldUnaryOp, the folded attribute uses the explicit
+    // `resultType`, so operand and result widths may differ.
return AttrElementT::get(resultType, calculate(operand.getValue()));
@@ -767,101 +809,100 @@
+// Truncations all fold to an i32 attribute: results narrower than 32 bits are
+// truncated to their width and then zero-extended back to 32 bits.
OpFoldResult TruncI32I8Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
}
OpFoldResult TruncI32I16Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
}
OpFoldResult TruncI64I8Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
}
OpFoldResult TruncI64I16Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
}
OpFoldResult TruncI64I32Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(32); });
+      [&](const APInt &a) { return a.trunc(32); });
}
+// Extensions from 8/16 bits first truncate to the source width, so any higher
+// bits present in the operand value are ignored before sign/zero extension.
OpFoldResult ExtI8I32SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).sext(32); });
+      [&](const APInt &a) { return a.trunc(8).sext(32); });
}
OpFoldResult ExtI8I32UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
}
OpFoldResult ExtI16I32SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).sext(32); });
+      [&](const APInt &a) { return a.trunc(16).sext(32); });
}
OpFoldResult ExtI16I32UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
}
OpFoldResult ExtI8I64SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).sext(64); });
+      [&](const APInt &a) { return a.trunc(8).sext(64); });
}
OpFoldResult ExtI8I64UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(64); });
+      [&](const APInt &a) { return a.trunc(8).zext(64); });
}
OpFoldResult ExtI16I64SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).sext(64); });
+      [&](const APInt &a) { return a.trunc(16).sext(64); });
}
OpFoldResult ExtI16I64UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(64); });
+      [&](const APInt &a) { return a.trunc(16).zext(64); });
}
OpFoldResult ExtI32I64SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.sext(64); });
+      [&](const APInt &a) { return a.sext(64); });
}
OpFoldResult ExtI32I64UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.zext(64); });
+      [&](const APInt &a) { return a.zext(64); });
}
namespace {
template <typename SRC_OP, typename OP_A, int SZ_T, typename OP_B>
-class PseudoIntegerConversionToSplitConversionOp
+struct PseudoIntegerConversionToSplitConversionOp
: public OpRewritePattern<SRC_OP> {
using OpRewritePattern<SRC_OP>::OpRewritePattern;
- public:
LogicalResult matchAndRewrite(SRC_OP op,
PatternRewriter &rewriter) const override {
auto tmp = rewriter.createOrFold<OP_A>(
@@ -956,7 +997,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.eq(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.eq(b); });
}
OpFoldResult CmpEQI32Op::fold(ArrayRef<Attribute> operands) {
@@ -984,7 +1025,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ne(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ne(b); });
}
OpFoldResult CmpNEI32Op::fold(ArrayRef<Attribute> operands) {
@@ -1032,7 +1073,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.slt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.slt(b); });
}
OpFoldResult CmpLTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1056,7 +1097,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ult(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ult(b); });
}
OpFoldResult CmpLTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1100,7 +1141,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sle(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sle(b); });
}
OpFoldResult CmpLTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1130,7 +1171,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ule(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ule(b); });
}
OpFoldResult CmpLTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1176,7 +1217,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sgt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sgt(b); });
}
OpFoldResult CmpGTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1206,7 +1247,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ugt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ugt(b); });
}
OpFoldResult CmpGTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1256,7 +1297,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sge(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sge(b); });
}
OpFoldResult CmpGTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1286,7 +1327,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.uge(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.uge(b); });
}
OpFoldResult CmpGTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1311,12 +1352,12 @@
+// cmp.nz folds to 1 when the constant operand is non-zero and to 0 otherwise.
OpFoldResult CmpNZI32Op::fold(ArrayRef<Attribute> operands) {
return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return APInt(32, a.getBoolValue()); });
+      operands, [&](const APInt &a) { return APInt(32, a.getBoolValue()); });
}
+// NOTE(review): constFoldUnaryOp gives the folded attribute the i64 operand
+// type. If vm.cmp.nz.i64 has an i32 result, this should fold through
+// constFoldConversionOp with an i32 result type instead -- confirm in VMOps.td.
OpFoldResult CmpNZI64Op::fold(ArrayRef<Attribute> operands) {
return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return APInt(64, a.getBoolValue()); });
+      operands, [&](const APInt &a) { return APInt(64, a.getBoolValue()); });
}
OpFoldResult CmpEQRefOp::fold(ArrayRef<Attribute> operands) {
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 727cbf0..d613c8b 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -1311,6 +1311,7 @@
VM_BinaryArithmeticOp<I32, "mul.i32", VM_OPC_MulI32, [Commutative]> {
let summary = [{integer multiplication operation}];
let hasFolder = 1;
+  // Patterns are registered in MulI32Op::getCanonicalizationPatterns.
+  let hasCanonicalizer = 1;
}
def VM_MulI64Op :
@@ -1318,6 +1319,7 @@
[VM_ExtI64, Commutative]> {
let summary = [{integer multiplication operation}];
let hasFolder = 1;
+  // Patterns are registered in MulI64Op::getCanonicalizationPatterns.
+  let hasCanonicalizer = 1;
}
def VM_DivI32SOp :
diff --git a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
index 153389c..38e49b2 100644
--- a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
@@ -56,6 +56,61 @@
// -----
+// CHECK-LABEL: @add_sub_i32_folds
+vm.module @add_sub_i32_folds {
+  // CHECK-LABEL: @add_sub_x
+  vm.func @add_sub_x(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.add.i32 %arg0, %arg1 : i32
+    %1 = vm.sub.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+  // CHECK-LABEL: @add_sub_x_rev
+  vm.func @add_sub_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.add.i32 %arg1, %arg0 : i32
+    %1 = vm.sub.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+  // y - (y + x) is -x, not x; there is no existing value to fold it to.
+  // CHECK-LABEL: @add_sub_no_fold
+  vm.func @add_sub_no_fold(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: %0 = vm.add.i32 %arg1, %arg0 : i32
+    // CHECK-NEXT: %1 = vm.sub.i32 %arg1, %0 : i32
+    // CHECK-NEXT: vm.return %1 : i32
+    %0 = vm.add.i32 %arg1, %arg0 : i32
+    %1 = vm.sub.i32 %arg1, %0 : i32
+    vm.return %1 : i32
+  }
+
+  // CHECK-LABEL: @sub_add_x
+  vm.func @sub_add_x(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.sub.i32 %arg0, %arg1 : i32
+    %1 = vm.add.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+  // CHECK-LABEL: @sub_add_x_rev
+  vm.func @sub_add_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.sub.i32 %arg0, %arg1 : i32
+    %1 = vm.add.i32 %arg1, %0 : i32
+    vm.return %1 : i32
+  }
+  // (y - x) + y is 2y - x, not x; it must not fold.
+  // CHECK-LABEL: @sub_add_no_fold
+  vm.func @sub_add_no_fold(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: %0 = vm.sub.i32 %arg1, %arg0 : i32
+    // CHECK-NEXT: %1 = vm.add.i32 %0, %arg1 : i32
+    // CHECK-NEXT: vm.return %1 : i32
+    %0 = vm.sub.i32 %arg1, %arg0 : i32
+    %1 = vm.add.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
// CHECK-LABEL: @mul_i32_folds
vm.module @mul_i32_folds {
// CHECK-LABEL: @mul_i32_by_0
@@ -96,6 +131,23 @@
// -----
+// CHECK-LABEL: @mul_mul_i32_folds
+vm.module @mul_mul_i32_folds {
+  // CHECK-LABEL: @mul_mul_i32_const
+  // (x * 4) * 10 is canonicalized into the single multiply x * 40.
+  vm.func @mul_mul_i32_const(%arg0 : i32) -> i32 {
+    // CHECK: %c40 = vm.const.i32 40 : i32
+    %c4 = vm.const.i32 4 : i32
+    %c10 = vm.const.i32 10 : i32
+    // CHECK: %0 = vm.mul.i32 %arg0, %c40 : i32
+    %0 = vm.mul.i32 %arg0, %c4 : i32
+    %1 = vm.mul.i32 %0, %c10 : i32
+    // CHECK-NEXT: vm.return %0 : i32
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
// CHECK-LABEL: @div_i32_folds
vm.module @div_i32_folds {
// CHECK-LABEL: @div_i32_0_y