Reintroduce RematerializeParallelOpsPass as preprocessing (#14521)
The pass was removed in #14453. This change reintroduces it as a
preprocessing option so that users can apply it to pre-formed dispatches.
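An example invocation through the preprocessing pipeline (a sketch; the
--iree-preprocessing-pass-pipeline flag spelling and the file names are
assumed for illustration):

  iree-compile \
    --iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-preprocessing-rematerialize-parallel-ops))" \
    model.mlir -o model.vmfb

The pass can also be run standalone, as in the new lit test:

  iree-opt -iree-preprocessing-rematerialize-parallel-ops input.mlir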
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 29e4ae8..fd1d90b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -452,9 +452,6 @@
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
- // This is a temporary solution for handling aggressive fusion heuristics.
- // This rematerializes parallel ops into the consumers to avoid stack
- // allocation.
SmallVector<int64_t> allFusableLevels(tilingConfig.getFusableLevels());
// Apply tile and fuse to all the non-distribution fusable levels. Skip
// distribution level as that level has been fused already.
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index ed33f10..c7a6702 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -34,6 +34,7 @@
"PadLinalgOps.cpp",
"PassDetail.h",
"Passes.cpp",
+ "RematerializeParallelOps.cpp",
],
hdrs = [
"Passes.h",
@@ -41,11 +42,14 @@
],
deps = [
":PassesIncGen",
+ "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
+ "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorUtils",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 4cb8e2a..b0ec7b5 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -30,16 +30,20 @@
"PadLinalgOps.cpp"
"PassDetail.h"
"Passes.cpp"
+ "RematerializeParallelOps.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRArithDialect
+ MLIRFuncDialect
MLIRIR
MLIRLinalgDialect
+ MLIRLinalgTransforms
MLIRPass
MLIRTensorDialect
MLIRTensorUtils
MLIRTransforms
+ iree::compiler::Dialect::Flow::Transforms
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
index 6a9fdf4..01fe504 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
#define IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
index 5b4f9f5..bd5d07b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
@@ -9,6 +9,7 @@
#include <functional>
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -23,6 +24,10 @@
// A pass to pad linalg ops to the next integer multiple of `paddingSize`.
std::unique_ptr<Pass> createPadLinalgOpsToIntegerMultiplePass();
+/// Pass to rematerialize and merge parallel ops on pre-formed dispatches.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRematerializeParallelOpsPass();
+
//===----------------------------------------------------------------------===//
// Register all Passes
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index ea40299..e063ef7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -26,4 +26,10 @@
];
}
-#endif // IREE_PREPROCESSING_COMMON_PASSES
\ No newline at end of file
+def RematerializeParallelOps :
+ Pass<"iree-preprocessing-rematerialize-parallel-ops", "func::FuncOp"> {
+ let summary = "Pass to rematerialize and merge parallel ops on pre-formed dispatches.";
+ let constructor = "mlir::iree_compiler::IREE::createRematerializeParallelOpsPass()";
+}
+
+#endif // IREE_PREPROCESSING_COMMON_PASSES
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/RematerializeParallelOps.cpp b/compiler/src/iree/compiler/Preprocessing/Common/RematerializeParallelOps.cpp
new file mode 100644
index 0000000..be08355
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/RematerializeParallelOps.cpp
@@ -0,0 +1,98 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Preprocessing/Common/PassDetail.h"
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-preprocessing-rematerialize-parallel-ops"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+namespace {
+
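+/// Returns true if `t` is a scalar (integer, index, or float) type or a
+/// statically shaped tensor holding exactly one element.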
+static bool isScalarOrTensorOfSizeOne(Type t) {
+ if (auto tensorType = dyn_cast<RankedTensorType>(t)) {
+ return tensorType.hasStaticShape() && tensorType.getNumElements() == 1;
+ }
+ return t.isIntOrIndexOrFloat();
+}
+
+/// Rematerializes parallel elementwise operations into their users within a
+/// `flow.dispatch.region`.
+struct RematerializeParallelOpsPattern
+ : public OpRewritePattern<linalg::GenericOp> {
+ using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ // Restrict to operations within pre-formed dispatches to avoid blanket
+ // rematerialization over the whole model.
+ if (Flow::isNonNullAndOutsideDispatch(genericOp))
+ return failure();
+
+    // Skip purely scalar ops; rematerializing them provides no benefit.
+ auto isScalarValue = [](Value v) {
+ return isScalarOrTensorOfSizeOne(v.getType());
+ };
+ if (llvm::all_of(genericOp.getOperands(), isScalarValue) &&
+ llvm::all_of(genericOp.getResults(), isScalarValue)) {
+ return failure();
+ }
+
+ // Find the first operand that is defined by another generic op on tensors.
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ if (!linalg::areElementwiseOpsFusable(&opOperand))
+ continue;
+
+ FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
+ linalg::fuseElementwiseOps(rewriter, &opOperand);
+ if (succeeded(fusionResult)) {
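+        // The trailing results of the fused op correspond to the results of
+        // the original consumer op; use them as replacements.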
+ auto replacements = fusionResult->fusedOp->getResults().take_back(
+ genericOp.getNumResults());
+ rewriter.replaceOp(genericOp, replacements);
+ return success();
+ }
+ }
+ return failure();
+ }
+};
+
+struct RematerializeParallelOpsPass
+ : public RematerializeParallelOpsBase<RematerializeParallelOpsPass> {
+ void runOnOperation() override {
+ func::FuncOp funcOp = getOperation();
+ RewritePatternSet fusionPatterns(funcOp.getContext());
+ fusionPatterns.insert<RematerializeParallelOpsPattern>(funcOp.getContext());
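+    // Clean up operands and results that become unused after fusion.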
+ linalg::populateEraseUnusedOperandsAndResultsPatterns(fusionPatterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRematerializeParallelOpsPass() {
+ return std::make_unique<RematerializeParallelOpsPass>();
+}
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 52528c5..9dd501c 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -18,6 +18,7 @@
[
"conv2d_to_img2col.mlir",
"pad_linalg_ops.mlir",
+ "rematerialize_parallel_ops.mlir",
],
include = ["*.mlir"],
),
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index 6c7ac7c..a754aea 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -16,6 +16,7 @@
SRCS
"conv2d_to_img2col.mlir"
"pad_linalg_ops.mlir"
+ "rematerialize_parallel_ops.mlir"
TOOLS
FileCheck
iree-opt
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/rematerialize_parallel_ops.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/rematerialize_parallel_ops.mlir
new file mode 100644
index 0000000..73ca751
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/rematerialize_parallel_ops.mlir
@@ -0,0 +1,182 @@
+// RUN: iree-opt -iree-preprocessing-rematerialize-parallel-ops %s | FileCheck %s
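+// Verifies that parallel elementwise producers inside flow.dispatch.region
+// ops are merged into their consumers, while scalar ops and ops outside
+// dispatches are left untouched.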
+
+func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>, %7: tensor<1xf32>)
+ -> tensor<1x40960xf32> {
+ %res = flow.dispatch.region -> (tensor<1x40960xf32>) {
+ %2 = tensor.empty() : tensor<1x40960xf32>
+ %cst = arith.constant -3.40282347E+38 : f32
+ %8 = linalg.generic
+ {indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %1 : tensor<1x40960xf32>, tensor<1xf32>)
+ outs(%2 : tensor<1x40960xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %10 = arith.subf %in, %in_2 : f32
+ %11 = math.exp %10 : f32
+ linalg.yield %11 : f32
+ } -> (tensor<1x40960xf32>)
+ %9 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%8, %7 : tensor<1x40960xf32>, tensor<1xf32>)
+ outs(%2 : tensor<1x40960xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %10 = arith.divf %cst, %in_2 : f32
+ %11 = arith.mulf %in, %10 : f32
+ linalg.yield %11 : f32
+ } -> tensor<1x40960xf32>
+ flow.return %9 : tensor<1x40960xf32>
+ }
+ return %res : tensor<1x40960xf32>
+}
+
+
+// CHECK-LABEL: func.func @merged_reduction_parallel
+// CHECK: %{{.+}} = linalg.generic
+// CHECK: arith.subf
+// CHECK-NEXT: math.exp
+// CHECK-NEXT: arith.divf
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: linalg.yield %{{.+}} : f32
+// CHECK: } -> tensor<1x40960xf32>
+
+// -----
+
+func.func @softmax(%7 : tensor<16x32x4096xf32>) -> tensor<16x32x4096xf32> {
+ %res = flow.dispatch.region -> (tensor<16x32x4096xf32>) {
+ %cst = arith.constant -3.40282347E+38 : f32
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %cst_1 = arith.constant 1.000000e+00 : f32
+ %8 = tensor.empty() : tensor<16x32xf32>
+ %6 = tensor.empty() : tensor<16x32x4096xf32>
+ %9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<16x32xf32>) -> tensor<16x32xf32>
+ %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%7 : tensor<16x32x4096xf32>) outs(%9 : tensor<16x32xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %16 = arith.maxf %in, %out : f32
+ linalg.yield %16 : f32
+ } -> tensor<16x32xf32>
+ %11 = tensor.empty() : tensor<16x32x4096xf32>
+ %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %10 : tensor<16x32x4096xf32>, tensor<16x32xf32>) outs(%11 : tensor<16x32x4096xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %16 = arith.subf %in, %in_2 : f32
+ %17 = math.exp %16 : f32
+ linalg.yield %17 : f32
+ } -> tensor<16x32x4096xf32>
+ %13 = linalg.fill ins(%cst_0 : f32) outs(%8 : tensor<16x32xf32>) -> tensor<16x32xf32>
+ %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12 : tensor<16x32x4096xf32>) outs(%13 : tensor<16x32xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %16 = arith.addf %in, %out : f32
+ linalg.yield %16 : f32
+ } -> tensor<16x32xf32>
+ %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12, %14 : tensor<16x32x4096xf32>, tensor<16x32xf32>) outs(%6 : tensor<16x32x4096xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %16 = arith.divf %cst_1, %in_2 : f32
+ %17 = arith.mulf %in, %16 : f32
+ linalg.yield %17 : f32
+ } -> tensor<16x32x4096xf32>
+ flow.return %15 : tensor<16x32x4096xf32>
+ }
+ return %res : tensor<16x32x4096xf32>
+}
+// CHECK: func @softmax(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<16x32x4096xf32>)
+// CHECK-DAG: %[[CST0:.+]] = arith.constant 0.0
+// CHECK: %[[MAXF:.+]] = linalg.generic
+// CHECK-SAME: ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK: %[[FILL0:.+]] = linalg.fill
+// CHECK-SAME: ins(%[[CST0]] :
+// CHECK: %[[EXPF:.+]] = linalg.generic
+// CHECK-SAME: ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[ARG0]], %[[MAXF]] :
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]], %[[MAXF]], %[[EXPF]] :
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @no_rematerialize_scalar_ops(%arg0 : tensor<f32>) -> tensor<f32> {
+ %res = flow.dispatch.region -> (tensor<f32>) {
+ %0 = tensor.empty() : tensor<f32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+ iterator_types = []}
+ ins(%arg0: tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %2 = arith.addf %b0, %b0: f32
+ linalg.yield %2: f32
+ } -> tensor<f32>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+ iterator_types = []}
+ ins(%1: tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %4 = arith.mulf %b0, %b0: f32
+ linalg.yield %4: f32
+ } -> tensor<f32>
+ %5 = linalg.generic {
+ indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>],
+ iterator_types = []}
+ ins(%1, %3 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %6 = arith.addf %b0, %b1: f32
+ linalg.yield %6: f32
+ } -> tensor<f32>
+ flow.return %5 : tensor<f32>
+ }
+ return %res : tensor<f32>
+}
+// CHECK-LABEL: func @no_rematerialize_scalar_ops(
+// CHECK: linalg.generic
+// CHECK: linalg.generic
+// CHECK: linalg.generic
+
+// -----
+
+func.func @no_rematerialize_non_dispatch(%0: tensor<1x40960xf32>, %1: tensor<1xf32>, %7: tensor<1xf32>)
+ -> tensor<1x40960xf32> {
+ %2 = tensor.empty() : tensor<1x40960xf32>
+ %cst = arith.constant -3.40282347E+38 : f32
+ %8 = linalg.generic
+ {indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %1 : tensor<1x40960xf32>, tensor<1xf32>)
+ outs(%2 : tensor<1x40960xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %10 = arith.subf %in, %in_2 : f32
+ %11 = math.exp %10 : f32
+ linalg.yield %11 : f32
+ } -> (tensor<1x40960xf32>)
+ %9 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%8, %7 : tensor<1x40960xf32>, tensor<1xf32>)
+ outs(%2 : tensor<1x40960xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %10 = arith.divf %cst, %in_2 : f32
+ %11 = arith.mulf %in, %10 : f32
+ linalg.yield %11 : f32
+ } -> tensor<1x40960xf32>
+ return %9 : tensor<1x40960xf32>
+}
+
+
+// CHECK-LABEL: func.func @no_rematerialize_non_dispatch
+// CHECK: linalg.generic
+// CHECK: linalg.generic