Adding folders for mul/mul of constants.
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index d7217e5..0f97aae 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -507,14 +507,47 @@
operands, [](const APInt &a, const APInt &b) { return a * b; });
}
+template <typename T, typename CONST_OP>
+struct FoldConstantMulOperand : public OpRewritePattern<T> {
+ using OpRewritePattern<T>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(T op,
+ PatternRewriter &rewriter) const override {
+ IntegerAttr c1, c2;
+ if (!matchPattern(op.rhs(), m_Constant(&c1))) return failure();
+ if (auto mulOp = dyn_cast_or_null<T>(op.lhs().getDefiningOp())) {
+ if (matchPattern(mulOp.rhs(), m_Constant(&c2))) {
+ auto c = rewriter.createOrFold<CONST_OP>(
+ FusedLoc::get({mulOp.getLoc(), op.getLoc()}, rewriter.getContext()),
+ constFoldBinaryOp<IntegerAttr>(
+ {c1, c2},
+ [](const APInt &a, const APInt &b) { return a * b; }));
+ rewriter.replaceOpWithNewOp<T>(op, op.getType(), mulOp.lhs(), c);
+ return success();
+ }
+ }
+ return failure();
+ }
+};
+
OpFoldResult MulI32Op::fold(ArrayRef<Attribute> operands) {
return foldMulOp(*this, operands);
}
+void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldConstantMulOperand<MulI32Op, ConstI32Op>>(context);
+}
+
OpFoldResult MulI64Op::fold(ArrayRef<Attribute> operands) {
return foldMulOp(*this, operands);
}
+void MulI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldConstantMulOperand<MulI64Op, ConstI64Op>>(context);
+}
+
template <typename T>
static OpFoldResult foldDivSOp(T op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 727cbf0..d613c8b 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -1311,6 +1311,7 @@
VM_BinaryArithmeticOp<I32, "mul.i32", VM_OPC_MulI32, [Commutative]> {
let summary = [{integer multiplication operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_MulI64Op :
@@ -1318,6 +1319,7 @@
[VM_ExtI64, Commutative]> {
let summary = [{integer multiplication operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_DivI32SOp :
diff --git a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
index 130466f..38e49b2 100644
--- a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
@@ -131,6 +131,23 @@
// -----
+// CHECK-LABEL: @mul_mul_i32_folds
+vm.module @mul_mul_i32_folds {
+ // CHECK-LABEL: @mul_mul_i32_const
+ vm.func @mul_mul_i32_const(%arg0 : i32) -> i32 {
+ // CHECK: %c40 = vm.const.i32 40 : i32
+ %c4 = vm.const.i32 4 : i32
+ %c10 = vm.const.i32 10 : i32
+ // CHECK: %0 = vm.mul.i32 %arg0, %c40 : i32
+ %0 = vm.mul.i32 %arg0, %c4 : i32
+ %1 = vm.mul.i32 %0, %c10 : i32
+ // CHECK-NEXT: vm.return %0 : i32
+ vm.return %1 : i32
+ }
+}
+
+// -----
+
// CHECK-LABEL: @div_i32_folds
vm.module @div_i32_folds {
// CHECK-LABEL: @div_i32_0_y