Merge pull request #3467 from google/benvanik-vm-fold-arithmetic

diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 819373f..0f97aae 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -202,10 +202,9 @@
 namespace {
 
 template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
   using OpRewritePattern<INDIRECT>::OpRewritePattern;
 
- public:
   LogicalResult matchAndRewrite(INDIRECT op,
                                 PatternRewriter &rewriter) const override {
     if (auto addressOp =
@@ -244,10 +243,9 @@
 namespace {
 
 template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
   using OpRewritePattern<INDIRECT>::OpRewritePattern;
 
- public:
   LogicalResult matchAndRewrite(INDIRECT op,
                                 PatternRewriter &rewriter) const override {
     if (auto addressOp =
@@ -382,15 +380,13 @@
 // Native integer arithmetic
 //===----------------------------------------------------------------------===//
 
-namespace {
-
 /// Performs const folding `calculate` with element-wise behavior on the given
 /// attribute in `operands` and returns the result if possible.
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
-                           const CalculationT &calculate) {
+static Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                                  const CalculationT &calculate) {
   assert(operands.size() == 1 && "unary op takes one operand");
   if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
     return AttrElementT::get(operand.getType(), calculate(operand.getValue()));
@@ -414,8 +410,8 @@
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
               std::function<ElementValueT(ElementValueT, ElementValueT)>>
-Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
-                            const CalculationT &calculate) {
+static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
+                                   const CalculationT &calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
   if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
     auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
@@ -448,42 +444,54 @@
   return {};
 }
 
-}  // namespace
-
-template <typename T>
-static OpFoldResult foldAddOp(T op, ArrayRef<Attribute> operands) {
+template <typename ADD, typename SUB>
+static OpFoldResult foldAddOp(ADD op, ArrayRef<Attribute> operands) {
   if (matchPattern(op.rhs(), m_Zero())) {
     // x + 0 = x or 0 + y = y (commutative)
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a + b; });
+  // (x - y) + y = x. Only the subtrahend cancels: (x - y) + x is 2x - y, so
+  // matching the sub's lhs against the other addend would be a miscompile.
+  auto lhsSub = dyn_cast_or_null<SUB>(op.lhs().getDefiningOp());
+  if (lhsSub && lhsSub.rhs() == op.rhs()) return lhsSub.lhs();
+  // y + (x - y) = x (same identity with the addends swapped).
+  auto rhsSub = dyn_cast_or_null<SUB>(op.rhs().getDefiningOp());
+  if (rhsSub && rhsSub.rhs() == op.lhs()) return rhsSub.lhs();
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a + b; });
 }
 
 OpFoldResult AddI32Op::fold(ArrayRef<Attribute> operands) {
-  return foldAddOp(*this, operands);
+  return foldAddOp<AddI32Op, SubI32Op>(*this, operands);
 }
 
 OpFoldResult AddI64Op::fold(ArrayRef<Attribute> operands) {
-  return foldAddOp(*this, operands);
+  return foldAddOp<AddI64Op, SubI64Op>(*this, operands);
 }
 
-template <typename T>
-static OpFoldResult foldSubOp(T op, ArrayRef<Attribute> operands) {
+template <typename SUB, typename ADD>
+static OpFoldResult foldSubOp(SUB op, ArrayRef<Attribute> operands) {
   if (matchPattern(op.rhs(), m_Zero())) {
     // x - 0 = x
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a - b; });
+  // (x + y) - x = y and (x + y) - y = x. Subtraction is not commutative: the
+  // mirrored form x - (x + y) equals -y, which is not an existing SSA value,
+  // so a fold (which can only return operands/attributes) must not match it.
+  if (auto addOp = dyn_cast_or_null<ADD>(op.lhs().getDefiningOp())) {
+    if (addOp.lhs() == op.rhs()) return addOp.rhs();
+    if (addOp.rhs() == op.rhs()) return addOp.lhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a - b; });
 }
 
 OpFoldResult SubI32Op::fold(ArrayRef<Attribute> operands) {
-  return foldSubOp(*this, operands);
+  return foldSubOp<SubI32Op, AddI32Op>(*this, operands);
 }
 
 OpFoldResult SubI64Op::fold(ArrayRef<Attribute> operands) {
-  return foldSubOp(*this, operands);
+  return foldSubOp<SubI64Op, AddI64Op>(*this, operands);
 }
 
 template <typename T>
@@ -495,18 +503,51 @@
     // x * 1 = x or 1 * y = y (commutative)
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a * b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+/// Folds (x * c2) * c1 into x * (c1 * c2) when c1 and c2 are constants.
+template <typename T, typename CONST_OP>
+struct FoldConstantMulOperand : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    IntegerAttr outerAttr, innerAttr;
+    if (!matchPattern(op.rhs(), m_Constant(&outerAttr))) return failure();
+    auto innerOp = dyn_cast_or_null<T>(op.lhs().getDefiningOp());
+    if (!innerOp) return failure();
+    if (!matchPattern(innerOp.rhs(), m_Constant(&innerAttr))) return failure();
+    auto fusedLoc = FusedLoc::get(
+        {innerOp.getLoc(), op.getLoc()}, rewriter.getContext());
+    auto product = constFoldBinaryOp<IntegerAttr>(
+        {outerAttr, innerAttr},
+        [](const APInt &a, const APInt &b) { return a * b; });
+    auto folded = rewriter.createOrFold<CONST_OP>(fusedLoc, product);
+    rewriter.replaceOpWithNewOp<T>(op, op.getType(), innerOp.lhs(), folded);
+    return success();
+  }
+};
+
 OpFoldResult MulI32Op::fold(ArrayRef<Attribute> operands) {
   return foldMulOp(*this, operands);
 }
 
+void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<FoldConstantMulOperand<MulI32Op, ConstI32Op>>(context);
+}
+
 OpFoldResult MulI64Op::fold(ArrayRef<Attribute> operands) {
   return foldMulOp(*this, operands);
 }
 
+void MulI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<FoldConstantMulOperand<MulI64Op, ConstI64Op>>(context);
+}
+
 template <typename T>
 static OpFoldResult foldDivSOp(T op, ArrayRef<Attribute> operands) {
   if (matchPattern(op.rhs(), m_Zero())) {
@@ -521,7 +562,7 @@
     return op.lhs();
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.sdiv(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.sdiv(b); });
 }
 
 OpFoldResult DivI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -546,7 +587,7 @@
     return op.lhs();
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.udiv(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.udiv(b); });
 }
 
 OpFoldResult DivI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -570,7 +611,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.srem(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.srem(b); });
 }
 
 OpFoldResult RemI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -589,7 +630,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.urem(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.urem(b); });
 }
 
 OpFoldResult RemI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -625,8 +666,8 @@
     // x & x = x
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a & b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a & b; });
 }
 
 OpFoldResult AndI32Op::fold(ArrayRef<Attribute> operands) {
@@ -646,8 +687,8 @@
     // x | x = x
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a | b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a | b; });
 }
 
 OpFoldResult OrI32Op::fold(ArrayRef<Attribute> operands) {
@@ -667,8 +708,8 @@
     // x ^ x = 0
     return zeroOfType(op.getType());
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a ^ b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a ^ b; });
 }
 
 OpFoldResult XorI32Op::fold(ArrayRef<Attribute> operands) {
@@ -693,7 +734,7 @@
     return op.operand();
   }
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.shl(op.amount()); });
+      operands, [&](const APInt &a) { return a.shl(op.amount()); });
 }
 
 OpFoldResult ShlI32Op::fold(ArrayRef<Attribute> operands) {
@@ -714,7 +755,7 @@
     return op.operand();
   }
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.ashr(op.amount()); });
+      operands, [&](const APInt &a) { return a.ashr(op.amount()); });
 }
 
 OpFoldResult ShrI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -735,7 +776,7 @@
     return op.operand();
   }
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.lshr(op.amount()); });
+      operands, [&](const APInt &a) { return a.lshr(op.amount()); });
 }
 
 OpFoldResult ShrI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -755,8 +796,9 @@
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldConversionOp(Type resultType, ArrayRef<Attribute> operands,
-                                const CalculationT &calculate) {
+static Attribute constFoldConversionOp(Type resultType,
+                                       ArrayRef<Attribute> operands,
+                                       const CalculationT &calculate) {
   assert(operands.size() == 1 && "unary op takes one operand");
   if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
     return AttrElementT::get(resultType, calculate(operand.getValue()));
@@ -767,101 +809,100 @@
 OpFoldResult TruncI32I8Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
 }
 
 OpFoldResult TruncI32I16Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
 }
 
 OpFoldResult TruncI64I8Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
 }
 
 OpFoldResult TruncI64I16Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
 }
 
 OpFoldResult TruncI64I32Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(32); });
+      [&](const APInt &a) { return a.trunc(32); });
 }
 
 OpFoldResult ExtI8I32SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).sext(32); });
+      [&](const APInt &a) { return a.trunc(8).sext(32); });
 }
 
 OpFoldResult ExtI8I32UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
 }
 
 OpFoldResult ExtI16I32SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).sext(32); });
+      [&](const APInt &a) { return a.trunc(16).sext(32); });
 }
 
 OpFoldResult ExtI16I32UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
 }
 
 OpFoldResult ExtI8I64SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).sext(64); });
+      [&](const APInt &a) { return a.trunc(8).sext(64); });
 }
 
 OpFoldResult ExtI8I64UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(64); });
+      [&](const APInt &a) { return a.trunc(8).zext(64); });
 }
 
 OpFoldResult ExtI16I64SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).sext(64); });
+      [&](const APInt &a) { return a.trunc(16).sext(64); });
 }
 
 OpFoldResult ExtI16I64UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(64); });
+      [&](const APInt &a) { return a.trunc(16).zext(64); });
 }
 
 OpFoldResult ExtI32I64SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.sext(64); });
+      [&](const APInt &a) { return a.sext(64); });
 }
 
 OpFoldResult ExtI32I64UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.zext(64); });
+      [&](const APInt &a) { return a.zext(64); });
 }
 
 namespace {
 
 template <typename SRC_OP, typename OP_A, int SZ_T, typename OP_B>
-class PseudoIntegerConversionToSplitConversionOp
+struct PseudoIntegerConversionToSplitConversionOp
     : public OpRewritePattern<SRC_OP> {
   using OpRewritePattern<SRC_OP>::OpRewritePattern;
 
- public:
   LogicalResult matchAndRewrite(SRC_OP op,
                                 PatternRewriter &rewriter) const override {
     auto tmp = rewriter.createOrFold<OP_A>(
@@ -956,7 +997,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.eq(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.eq(b); });
 }
 
 OpFoldResult CmpEQI32Op::fold(ArrayRef<Attribute> operands) {
@@ -984,7 +1025,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ne(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ne(b); });
 }
 
 OpFoldResult CmpNEI32Op::fold(ArrayRef<Attribute> operands) {
@@ -1032,7 +1073,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.slt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.slt(b); });
 }
 
 OpFoldResult CmpLTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1056,7 +1097,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ult(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ult(b); });
 }
 
 OpFoldResult CmpLTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1100,7 +1141,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sle(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sle(b); });
 }
 
 OpFoldResult CmpLTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1130,7 +1171,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ule(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ule(b); });
 }
 
 OpFoldResult CmpLTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1176,7 +1217,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sgt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sgt(b); });
 }
 
 OpFoldResult CmpGTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1206,7 +1247,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ugt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ugt(b); });
 }
 
 OpFoldResult CmpGTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1256,7 +1297,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sge(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sge(b); });
 }
 
 OpFoldResult CmpGTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1286,7 +1327,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.uge(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.uge(b); });
 }
 
 OpFoldResult CmpGTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1311,12 +1352,12 @@
 
 OpFoldResult CmpNZI32Op::fold(ArrayRef<Attribute> operands) {
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return APInt(32, a.getBoolValue()); });
+      operands, [&](const APInt &a) { return APInt(32, a.getBoolValue()); });
 }
 
 OpFoldResult CmpNZI64Op::fold(ArrayRef<Attribute> operands) {
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return APInt(64, a.getBoolValue()); });
+      operands, [&](const APInt &a) { return APInt(64, a.getBoolValue()); });
 }
 
 OpFoldResult CmpEQRefOp::fold(ArrayRef<Attribute> operands) {
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 727cbf0..d613c8b 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -1311,6 +1311,7 @@
     VM_BinaryArithmeticOp<I32, "mul.i32", VM_OPC_MulI32, [Commutative]> {
   let summary = [{integer multiplication operation}];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def VM_MulI64Op :
@@ -1318,6 +1319,7 @@
                           [VM_ExtI64, Commutative]> {
   let summary = [{integer multiplication operation}];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def VM_DivI32SOp :
diff --git a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
index 153389c..38e49b2 100644
--- a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
@@ -56,6 +56,41 @@
 
 // -----
 
+// CHECK-LABEL: @add_sub_i32_folds
+vm.module @add_sub_i32_folds {
+  // CHECK-LABEL: @add_sub_x
+  vm.func @add_sub_x(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.add.i32 %arg0, %arg1 : i32
+    %1 = vm.sub.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+  // CHECK-LABEL: @add_sub_x_rev
+  vm.func @add_sub_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.add.i32 %arg1, %arg0 : i32
+    %1 = vm.sub.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+
+  // CHECK-LABEL: @sub_add_x
+  vm.func @sub_add_x(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.sub.i32 %arg0, %arg1 : i32
+    %1 = vm.add.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+  // CHECK-LABEL: @sub_add_x_rev
+  vm.func @sub_add_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.sub.i32 %arg0, %arg1 : i32
+    %1 = vm.add.i32 %arg1, %0 : i32
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @mul_i32_folds
 vm.module @mul_i32_folds {
   // CHECK-LABEL: @mul_i32_by_0
@@ -96,6 +131,23 @@
 
 // -----
 
+// CHECK-LABEL: @mul_mul_i32_folds
+vm.module @mul_mul_i32_folds {
+  // CHECK-LABEL: @mul_mul_i32_const
+  vm.func @mul_mul_i32_const(%arg0 : i32) -> i32 {
+    // CHECK: %c40 = vm.const.i32 40 : i32
+    %c4 = vm.const.i32 4 : i32
+    %c10 = vm.const.i32 10 : i32
+    // CHECK: %0 = vm.mul.i32 %arg0, %c40 : i32
+    %0 = vm.mul.i32 %arg0, %c4 : i32
+    %1 = vm.mul.i32 %0, %c10 : i32
+    // CHECK-NEXT: vm.return %0 : i32
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @div_i32_folds
 vm.module @div_i32_folds {
   // CHECK-LABEL: @div_i32_0_y