Adding FMA folding in the VM. (#12382)
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index f088f3d..20b2551 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -607,6 +607,38 @@ return {}; } +// %0 = vm.mul.f32 %a, %b : f32 +// %1 = vm.add.f32 %0, %c : f32 +// -> +// %1 = vm.fma.f32 %a, %b, %c : f32 +template <class MulOp, class AddOp, class FMAOp> +struct FuseFMAOp : public OpRewritePattern<AddOp> { + using OpRewritePattern<AddOp>::OpRewritePattern; + LogicalResult matchAndRewrite(AddOp addOp, + PatternRewriter &rewriter) const override { + auto fuse = [&](MulOp mulOp, Value a, Value b, Value c) { + if (!mulOp->hasOneUse() || + mulOp->isUsedOutsideOfBlock(mulOp->getBlock())) { + return failure(); + } + rewriter.replaceOp( + addOp, + rewriter + .create<FMAOp>(rewriter.getFusedLoc({a.getLoc(), c.getLoc()}), + a.getType(), a, b, c) + .getResult()); + return success(); + }; + if (auto mulOp = dyn_cast_or_null<MulOp>(addOp.getLhs().getDefiningOp())) { + return fuse(mulOp, mulOp.getLhs(), mulOp.getRhs(), addOp.getRhs()); + } else if (auto mulOp = + dyn_cast_or_null<MulOp>(addOp.getRhs().getDefiningOp())) { + return fuse(mulOp, mulOp.getLhs(), mulOp.getRhs(), addOp.getLhs()); + } + return failure(); + } +}; + template <class AttrElementT, typename ADD, typename SUB, class ElementValueT = typename AttrElementT::ValueType> static OpFoldResult foldAddOp(ADD op, Attribute lhs, Attribute rhs) { @@ -634,11 +666,21 @@ operands.getRhs()); } +void AddI32Op::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert<FuseFMAOp<MulI32Op, AddI32Op, FMAI32Op>>(context); +} + OpFoldResult AddI64Op::fold(FoldAdaptor operands) { return foldAddOp<IntegerAttr, AddI64Op, SubI64Op>(*this, operands.getLhs(), operands.getRhs()); } +void AddI64Op::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert<FuseFMAOp<MulI64Op, AddI64Op, FMAI64Op>>(context); +} + template <class AttrElementT, typename SUB, typename ADD, class ElementValueT = typename AttrElementT::ValueType> static OpFoldResult foldSubOp(SUB op, Attribute lhs, Attribute rhs) { @@ -922,11 +964,23 @@ operands.getRhs()); } +void AddF32Op::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert<FoldConstantMulOperand<FloatAttr, MulF32Op, ConstF32Op>>( + context); + results.insert<FuseFMAOp<MulF32Op, AddF32Op, FMAF32Op>>(context); +} + OpFoldResult AddF64Op::fold(FoldAdaptor operands) { return foldAddOp<FloatAttr, AddF64Op, SubF64Op>(*this, operands.getLhs(), operands.getRhs()); } +void AddF64Op::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert<FuseFMAOp<MulF64Op, AddF64Op, FMAF64Op>>(context); +} + OpFoldResult SubF32Op::fold(FoldAdaptor operands) { return foldSubOp<FloatAttr, SubF32Op, AddF32Op>(*this, operands.getLhs(), operands.getRhs());
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index 0928a0a..8276dc5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -2105,12 +2105,14 @@ VM_BinaryArithmeticOp<I32, "add.i32", VM_OPC_AddI32, [Commutative]> { let summary = [{integer add operation}]; let hasFolder = 1; + let hasCanonicalizer = 1; } def VM_AddI64Op : VM_BinaryArithmeticOp<I64, "add.i64", VM_OPC_AddI64, [Commutative]> { let summary = [{integer add operation}]; let hasFolder = 1; + let hasCanonicalizer = 1; } def VM_SubI32Op : @@ -2222,6 +2224,7 @@ [VM_ExtF32, Commutative]> { let summary = [{floating-point add operation}]; let hasFolder = 1; + let hasCanonicalizer = 1; } def VM_AddF64Op : @@ -2229,6 +2232,7 @@ [VM_ExtF64, Commutative]> { let summary = [{floating-point add operation}]; let hasFolder = 1; + let hasCanonicalizer = 1; } def VM_SubF32Op :
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir index 708c441..35080c1 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
@@ -29,6 +29,55 @@ %0 = vm.add.i32 %c1, %c4 : i32 vm.return %0 : i32 } + + // CHECK-LABEL: @mul_add_i32_lhs + vm.func @mul_add_i32_lhs(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { + %0 = vm.mul.i32 %arg0, %arg1 : i32 + // CHECK: %[[RET:.+]] = vm.fma.i32 %arg0, %arg1, %arg2 : i32 + %1 = vm.add.i32 %0, %arg2 : i32 + // CHECK: return %[[RET]] + vm.return %1 : i32 + } + + // CHECK-LABEL: @mul_add_i32_rhs + vm.func @mul_add_i32_rhs(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { + %0 = vm.mul.i32 %arg0, %arg1 : i32 + // CHECK: %[[RET:.+]] = vm.fma.i32 %arg0, %arg1, %arg2 : i32 + %1 = vm.add.i32 %arg2, %0 : i32 + // CHECK: return %[[RET]] + vm.return %1 : i32 + } + + // Expect this not to fold: + // CHECK-LABEL: @mul_add_i32_multiple_users + vm.func @mul_add_i32_multiple_users(%arg0: i32, %arg1: i32, %arg2: i32) -> (i32, i32) { + // CHECK: vm.mul.i32 + %0 = vm.mul.i32 %arg0, %arg1 : i32 + // CHECK-NOT: vm.fma.i32 + // CHECK-NEXT: vm.add.i32 + %1 = vm.add.i32 %0, %arg2 : i32 + // CHECK-NEXT: vm.add.i32 + %2 = vm.add.i32 %0, %arg1 : i32 + vm.return %1, %2 : i32, i32 + } + + // Expect this not to fold: + // CHECK-LABEL: @mul_add_i32_dont_sink + vm.func @mul_add_i32_dont_sink(%arg0: i32, %arg1: i32, %arg2: i32, %cond: i32) -> i32 { + // CHECK: vm.mul.i32 + %0 = vm.mul.i32 %arg0, %arg1 : i32 + vm.cond_br %cond, ^bb1, ^bb2 + ^bb1: + // CHECK: vm.add.i32 + %1 = vm.add.i32 %0, %arg2 : i32 + vm.return %1 : i32 + ^bb2: + // CHECK: vm.add.i32 + %2 = vm.add.i32 %0, %arg1 : i32 + // CHECK: vm.div.i32.s + %3 = vm.div.i32.s %2, %arg1 : i32 + vm.return %3 : i32 + } } // -----