Adding FMA folding in the VM. (#12382)
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index f088f3d..20b2551 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -607,6 +607,38 @@
return {};
}
+// %0 = vm.mul.f32 %a, %b : f32
+// %1 = vm.add.f32 %0, %c : f32
+// ->
+// %1 = vm.fma.f32 %a, %b, %c : f32
+template <class MulOp, class AddOp, class FMAOp>
+struct FuseFMAOp : public OpRewritePattern<AddOp> {
+ using OpRewritePattern<AddOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(AddOp addOp,
+ PatternRewriter &rewriter) const override {
+ auto fuse = [&](MulOp mulOp, Value a, Value b, Value c) {
+ if (!mulOp->hasOneUse() ||
+ mulOp->isUsedOutsideOfBlock(mulOp->getBlock())) {
+ return failure();
+ }
+ rewriter.replaceOp(
+ addOp,
+ rewriter
+ .create<FMAOp>(rewriter.getFusedLoc({a.getLoc(), c.getLoc()}),
+ a.getType(), a, b, c)
+ .getResult());
+ return success();
+ };
+ if (auto mulOp = dyn_cast_or_null<MulOp>(addOp.getLhs().getDefiningOp())) {
+ return fuse(mulOp, mulOp.getLhs(), mulOp.getRhs(), addOp.getRhs());
+ } else if (auto mulOp =
+ dyn_cast_or_null<MulOp>(addOp.getRhs().getDefiningOp())) {
+ return fuse(mulOp, mulOp.getLhs(), mulOp.getRhs(), addOp.getLhs());
+ }
+ return failure();
+ }
+};
+
template <class AttrElementT, typename ADD, typename SUB,
class ElementValueT = typename AttrElementT::ValueType>
static OpFoldResult foldAddOp(ADD op, Attribute lhs, Attribute rhs) {
@@ -634,11 +666,21 @@
operands.getRhs());
}
+void AddI32Op::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FuseFMAOp<MulI32Op, AddI32Op, FMAI32Op>>(context);
+}
+
OpFoldResult AddI64Op::fold(FoldAdaptor operands) {
return foldAddOp<IntegerAttr, AddI64Op, SubI64Op>(*this, operands.getLhs(),
operands.getRhs());
}
+void AddI64Op::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FuseFMAOp<MulI64Op, AddI64Op, FMAI64Op>>(context);
+}
+
template <class AttrElementT, typename SUB, typename ADD,
class ElementValueT = typename AttrElementT::ValueType>
static OpFoldResult foldSubOp(SUB op, Attribute lhs, Attribute rhs) {
@@ -922,11 +964,23 @@
operands.getRhs());
}
+void AddF32Op::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldConstantMulOperand<FloatAttr, MulF32Op, ConstF32Op>>(
+ context);
+ results.insert<FuseFMAOp<MulF32Op, AddF32Op, FMAF32Op>>(context);
+}
+
OpFoldResult AddF64Op::fold(FoldAdaptor operands) {
return foldAddOp<FloatAttr, AddF64Op, SubF64Op>(*this, operands.getLhs(),
operands.getRhs());
}
+void AddF64Op::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FuseFMAOp<MulF64Op, AddF64Op, FMAF64Op>>(context);
+}
+
OpFoldResult SubF32Op::fold(FoldAdaptor operands) {
return foldSubOp<FloatAttr, SubF32Op, AddF32Op>(*this, operands.getLhs(),
operands.getRhs());
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index 0928a0a..8276dc5 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -2105,12 +2105,14 @@
VM_BinaryArithmeticOp<I32, "add.i32", VM_OPC_AddI32, [Commutative]> {
let summary = [{integer add operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_AddI64Op :
VM_BinaryArithmeticOp<I64, "add.i64", VM_OPC_AddI64, [Commutative]> {
let summary = [{integer add operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_SubI32Op :
@@ -2222,6 +2224,7 @@
[VM_ExtF32, Commutative]> {
let summary = [{floating-point add operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_AddF64Op :
@@ -2229,6 +2232,7 @@
[VM_ExtF64, Commutative]> {
let summary = [{floating-point add operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_SubF32Op :
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
index 708c441..35080c1 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
@@ -29,6 +29,55 @@
%0 = vm.add.i32 %c1, %c4 : i32
vm.return %0 : i32
}
+
+ // CHECK-LABEL: @mul_add_i32_lhs
+ vm.func @mul_add_i32_lhs(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ %0 = vm.mul.i32 %arg0, %arg1 : i32
+ // CHECK: %[[RET:.+]] = vm.fma.i32 %arg0, %arg1, %arg2 : i32
+ %1 = vm.add.i32 %0, %arg2 : i32
+ // CHECK: return %[[RET]]
+ vm.return %1 : i32
+ }
+
+ // CHECK-LABEL: @mul_add_i32_rhs
+ vm.func @mul_add_i32_rhs(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ %0 = vm.mul.i32 %arg0, %arg1 : i32
+ // CHECK: %[[RET:.+]] = vm.fma.i32 %arg0, %arg1, %arg2 : i32
+ %1 = vm.add.i32 %arg2, %0 : i32
+ // CHECK: return %[[RET]]
+ vm.return %1 : i32
+ }
+
+ // Expect this not to fold:
+ // CHECK-LABEL: @mul_add_i32_multiple_users
+ vm.func @mul_add_i32_multiple_users(%arg0: i32, %arg1: i32, %arg2: i32) -> (i32, i32) {
+ // CHECK: vm.mul.i32
+ %0 = vm.mul.i32 %arg0, %arg1 : i32
+ // CHECK-NOT: vm.fma.i32
+ // CHECK-NEXT: vm.add.i32
+ %1 = vm.add.i32 %0, %arg2 : i32
+ // CHECK-NEXT: vm.add.i32
+ %2 = vm.add.i32 %0, %arg1 : i32
+ vm.return %1, %2 : i32, i32
+ }
+
+ // Expect this not to fold:
+ // CHECK-LABEL: @mul_add_i32_dont_sink
+ vm.func @mul_add_i32_dont_sink(%arg0: i32, %arg1: i32, %arg2: i32, %cond: i32) -> i32 {
+ // CHECK: vm.mul.i32
+ %0 = vm.mul.i32 %arg0, %arg1 : i32
+ vm.cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ // CHECK: vm.add.i32
+ %1 = vm.add.i32 %0, %arg2 : i32
+ vm.return %1 : i32
+ ^bb2:
+ // CHECK: vm.add.i32
+ %2 = vm.add.i32 %0, %arg1 : i32
+ // CHECK: vm.div.i32.s
+ %3 = vm.div.i32.s %2, %arg1 : i32
+ vm.return %3 : i32
+ }
}
// -----