Reintroduce RematerializeParallelOpsPass as preprocessing (#14521)

This pass was removed in #14453. This change reintroduces it as a
preprocessing option so that users can apply it to pre-formed dispatches.
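
For reference, the pass can be exercised standalone (mirroring the new lit
test added below):

  iree-opt -iree-preprocessing-rematerialize-parallel-ops input.mlir

It can also be slotted into a user-supplied preprocessing pipeline (e.g. via
iree-compile's preprocessing pass pipeline option); the exact flag spelling
is an assumption and not part of this change.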
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 29e4ae8..fd1d90b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -452,9 +452,6 @@
 
   OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
 
-  // This is a temporary solution for handling aggressive fusion heuristics.
-  // This rematerializes parallel ops into the consumers to avoid stack
-  // allocation.
   SmallVector<int64_t> allFusableLevels(tilingConfig.getFusableLevels());
   // Apply tile and fuse to all the non-distribution fusable levels. Skip
   // distribution level as that level has been fused already.
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index ed33f10..c7a6702 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -34,6 +34,7 @@
         "PadLinalgOps.cpp",
         "PassDetail.h",
         "Passes.cpp",
+        "RematerializeParallelOps.cpp",
     ],
     hdrs = [
         "Passes.h",
@@ -41,11 +42,14 @@
     ],
     deps = [
         ":PassesIncGen",
+        "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:DialectUtils",
+        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TensorUtils",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 4cb8e2a..b0ec7b5 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -30,16 +30,20 @@
     "PadLinalgOps.cpp"
     "PassDetail.h"
     "Passes.cpp"
+    "RematerializeParallelOps.cpp"
   DEPS
     ::PassesIncGen
     LLVMSupport
     MLIRArithDialect
+    MLIRFuncDialect
     MLIRIR
     MLIRLinalgDialect
+    MLIRLinalgTransforms
     MLIRPass
     MLIRTensorDialect
     MLIRTensorUtils
     MLIRTransforms
+    iree::compiler::Dialect::Flow::Transforms
   PUBLIC
 )
 
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
index 6a9fdf4..01fe504 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
 #define IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
index 5b4f9f5..bd5d07b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
@@ -9,6 +9,7 @@
 
 #include <functional>
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 
@@ -23,6 +24,10 @@
 // A pass to pad linalg ops to the next integer multiple of `paddingSize`.
 std::unique_ptr<Pass> createPadLinalgOpsToIntegerMultiplePass();
 
+/// Pass to rematerialize and merge parallel linalg ops in pre-formed dispatches.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRematerializeParallelOpsPass();
+
 //===----------------------------------------------------------------------===//
 // Register all Passes
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index ea40299..e063ef7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -26,4 +26,10 @@
   ];
 }
 
-#endif  // IREE_PREPROCESSING_COMMON_PASSES
\ No newline at end of file
+def RematerializeParallelOps :
+    Pass<"iree-preprocessing-rematerialize-parallel-ops", "func::FuncOp"> {
+  let summary = "Pass to rematerialize and merge parallel ops within pre-formed dispatches.";
+  let constructor = "mlir::iree_compiler::IREE::createRematerializeParallelOpsPass()";
+}
+
+#endif  // IREE_PREPROCESSING_COMMON_PASSES
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/RematerializeParallelOps.cpp b/compiler/src/iree/compiler/Preprocessing/Common/RematerializeParallelOps.cpp
new file mode 100644
index 0000000..be08355
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/RematerializeParallelOps.cpp
@@ -0,0 +1,93 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Preprocessing/Common/PassDetail.h"
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-preprocessing-rematerialize-parallel-ops"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+
+namespace {
+
+static bool isScalarOrTensorOfSizeOne(Type t) {
+  if (auto tensorType = dyn_cast<RankedTensorType>(t)) {
+    return tensorType.hasStaticShape() && tensorType.getNumElements() == 1;
+  }
+  return t.isIntOrIndexOrFloat();
+}
+
+/// Rematerializes parallel elementwise operations into their users within a
+/// `flow.dispatch.region`.
+struct RematerializeParallelOpsPattern
+    : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Restrict to operations within pre-formed dispatches to avoid blanket
+    // rematerialization over the whole model.
+    if (Flow::isNonNullAndOutsideDispatch(genericOp))
+      return failure();
+
+    // Skip operations whose operands and results are all scalars or size-1 tensors.
+    auto isScalarValue = [](Value v) {
+      return isScalarOrTensorOfSizeOne(v.getType());
+    };
+    if (llvm::all_of(genericOp.getOperands(), isScalarValue) &&
+        llvm::all_of(genericOp.getResults(), isScalarValue)) {
+      return failure();
+    }
+
+    // Fuse the producer of the first elementwise-fusable operand into this op.
+    for (OpOperand &opOperand : genericOp->getOpOperands()) {
+      if (!linalg::areElementwiseOpsFusable(&opOperand))
+        continue;
+
+      FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
+          linalg::fuseElementwiseOps(rewriter, &opOperand);
+      if (succeeded(fusionResult)) {
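+        // The fused op returns the producer's results followed by the
+        // consumer's; keep only the trailing results that replace this op.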
+        auto replacements = fusionResult->fusedOp->getResults().take_back(
+            genericOp.getNumResults());
+        rewriter.replaceOp(genericOp, replacements);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+struct RematerializeParallelOpsPass
+    : public RematerializeParallelOpsBase<RematerializeParallelOpsPass> {
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    RewritePatternSet fusionPatterns(funcOp.getContext());
+    fusionPatterns.insert<RematerializeParallelOpsPattern>(funcOp.getContext());
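+    // Fusion can leave unused operands/results on linalg ops; add patterns to
+    // erase them as part of the same rewrite.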
+    linalg::populateEraseUnusedOperandsAndResultsPatterns(fusionPatterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRematerializeParallelOpsPass() {
+  return std::make_unique<RematerializeParallelOpsPass>();
+}
+
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 52528c5..9dd501c 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -18,6 +18,7 @@
         [
             "conv2d_to_img2col.mlir",
             "pad_linalg_ops.mlir",
+            "rematerialize_parallel_ops.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index 6c7ac7c..a754aea 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -16,6 +16,7 @@
   SRCS
     "conv2d_to_img2col.mlir"
     "pad_linalg_ops.mlir"
+    "rematerialize_parallel_ops.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/rematerialize_parallel_ops.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/rematerialize_parallel_ops.mlir
new file mode 100644
index 0000000..73ca751
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/rematerialize_parallel_ops.mlir
@@ -0,0 +1,179 @@
+// RUN: iree-opt -iree-preprocessing-rematerialize-parallel-ops %s | FileCheck %s
+
+func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>, %7: tensor<1xf32>)
+  -> tensor<1x40960xf32> {
+   %res = flow.dispatch.region -> (tensor<1x40960xf32>) {
+     %2 = tensor.empty() : tensor<1x40960xf32>
+     %cst = arith.constant -3.40282347E+38 : f32
+     %8 = linalg.generic
+     {indexing_maps = [
+         affine_map<(d0, d1) -> (d0, d1)>,
+         affine_map<(d0, d1) -> (d0)>,
+         affine_map<(d0, d1) -> (d0, d1)>],
+         iterator_types = ["parallel", "parallel"]}
+         ins(%0, %1 : tensor<1x40960xf32>, tensor<1xf32>)
+         outs(%2 : tensor<1x40960xf32>) {
+       ^bb0(%in: f32, %in_2: f32, %out: f32):
+         %10 = arith.subf %in, %in_2 : f32
+         %11 = math.exp %10 : f32
+         linalg.yield %11 : f32
+       } -> (tensor<1x40960xf32>)
+     %9 = linalg.generic {
+         indexing_maps = [
+             affine_map<(d0, d1) -> (d0, d1)>,
+             affine_map<(d0, d1) -> (d0)>,
+             affine_map<(d0, d1) -> (d0, d1)>],
+             iterator_types = ["parallel", "parallel"]}
+             ins(%8, %7 : tensor<1x40960xf32>, tensor<1xf32>)
+             outs(%2 : tensor<1x40960xf32>) {
+       ^bb0(%in: f32, %in_2: f32, %out: f32):
+         %10 = arith.divf %cst, %in_2 : f32
+         %11 = arith.mulf %in, %10 : f32
+         linalg.yield %11 : f32
+       } -> tensor<1x40960xf32>
+     flow.return %9 : tensor<1x40960xf32>
+   }
+   return %res : tensor<1x40960xf32>
+}
+
+
+//   CHECK-LABEL: func.func @merged_reduction_parallel
+//         CHECK:   %{{.+}} = linalg.generic
+//         CHECK:     arith.subf
+//    CHECK-NEXT:     math.exp
+//    CHECK-NEXT:     arith.divf
+//    CHECK-NEXT:     arith.mulf
+//    CHECK-NEXT:     linalg.yield %{{.+}} : f32
+//         CHECK:   } -> tensor<1x40960xf32>
+
+// -----
+
+func.func @softmax(%7 : tensor<16x32x4096xf32>) -> tensor<16x32x4096xf32> {
+  %res = flow.dispatch.region -> (tensor<16x32x4096xf32>) {
+    %cst = arith.constant -3.40282347E+38 : f32
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    %cst_1 = arith.constant 1.000000e+00 : f32
+    %8 = tensor.empty() : tensor<16x32xf32>
+    %6 = tensor.empty() : tensor<16x32x4096xf32>
+    %9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<16x32xf32>) -> tensor<16x32xf32>
+    %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%7 : tensor<16x32x4096xf32>) outs(%9 : tensor<16x32xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %16 = arith.maxf %in, %out : f32
+      linalg.yield %16 : f32
+    } -> tensor<16x32xf32>
+    %11 = tensor.empty() : tensor<16x32x4096xf32>
+    %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %10 : tensor<16x32x4096xf32>, tensor<16x32xf32>) outs(%11 : tensor<16x32x4096xf32>) {
+    ^bb0(%in: f32, %in_2: f32, %out: f32):
+      %16 = arith.subf %in, %in_2 : f32
+      %17 = math.exp %16 : f32
+      linalg.yield %17 : f32
+    } -> tensor<16x32x4096xf32>
+    %13 = linalg.fill ins(%cst_0 : f32) outs(%8 : tensor<16x32xf32>) -> tensor<16x32xf32>
+    %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12 : tensor<16x32x4096xf32>) outs(%13 : tensor<16x32xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %16 = arith.addf %in, %out : f32
+      linalg.yield %16 : f32
+    } -> tensor<16x32xf32>
+    %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12, %14 : tensor<16x32x4096xf32>, tensor<16x32xf32>) outs(%6 : tensor<16x32x4096xf32>) {
+    ^bb0(%in: f32, %in_2: f32, %out: f32):
+      %16 = arith.divf %cst_1, %in_2 : f32
+      %17 = arith.mulf %in, %16 : f32
+      linalg.yield %17 : f32
+    } -> tensor<16x32x4096xf32>
+    flow.return %15 : tensor<16x32x4096xf32>
+  }
+  return %res : tensor<16x32x4096xf32>
+}
+//      CHECK: func @softmax(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<16x32x4096xf32>)
+//  CHECK-DAG:   %[[CST0:.+]] = arith.constant 0.0
+//      CHECK:   %[[MAXF:.+]] = linalg.generic
+// CHECK-SAME:       ["parallel", "parallel", "reduction"]
+// CHECK-SAME:       ins(%[[ARG0]] :
+//      CHECK:   %[[FILL0:.+]] = linalg.fill
+// CHECK-SAME:       ins(%[[CST0]] :
+//      CHECK:   %[[EXPF:.+]] = linalg.generic
+// CHECK-SAME:       ["parallel", "parallel", "reduction"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[MAXF]] :
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:       ["parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[MAXF]], %[[EXPF]] :
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @no_rematerialize_scalar_ops(%arg0 : tensor<f32>) -> tensor<f32> {
+  %res = flow.dispatch.region -> (tensor<f32>) {
+    %0 = tensor.empty() : tensor<f32>
+    %1 = linalg.generic {
+        indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+        iterator_types = []}
+        ins(%arg0: tensor<f32>) outs(%0 : tensor<f32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %2 = arith.addf %b0, %b0: f32
+        linalg.yield %2: f32
+    } -> tensor<f32>
+    %3 = linalg.generic {
+        indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+        iterator_types = []}
+        ins(%1: tensor<f32>) outs(%0 : tensor<f32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %4 = arith.mulf %b0, %b0: f32
+        linalg.yield %4: f32
+    } -> tensor<f32>
+    %5 = linalg.generic {
+        indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>],
+        iterator_types = []}
+        ins(%1, %3 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %6 = arith.addf %b0, %b1: f32
+        linalg.yield %6: f32
+    } -> tensor<f32>
+    flow.return %5 : tensor<f32>
+  }
+  return %res : tensor<f32>
+}
+// CHECK-LABEL: func @no_rematerialize_scalar_ops(
+//       CHECK:   linalg.generic
+//       CHECK:   linalg.generic
+//       CHECK:   linalg.generic
+
+// -----
+
+func.func @no_rematerialize_non_dispatch(%0: tensor<1x40960xf32>, %1: tensor<1xf32>, %7: tensor<1xf32>)
+  -> tensor<1x40960xf32> {
+   %2 = tensor.empty() : tensor<1x40960xf32>
+   %cst = arith.constant -3.40282347E+38 : f32
+   %8 = linalg.generic
+   {indexing_maps = [
+       affine_map<(d0, d1) -> (d0, d1)>,
+       affine_map<(d0, d1) -> (d0)>,
+       affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+       ins(%0, %1 : tensor<1x40960xf32>, tensor<1xf32>)
+       outs(%2 : tensor<1x40960xf32>) {
+     ^bb0(%in: f32, %in_2: f32, %out: f32):
+       %10 = arith.subf %in, %in_2 : f32
+       %11 = math.exp %10 : f32
+       linalg.yield %11 : f32
+     } -> (tensor<1x40960xf32>)
+   %9 = linalg.generic {
+       indexing_maps = [
+           affine_map<(d0, d1) -> (d0, d1)>,
+           affine_map<(d0, d1) -> (d0)>,
+           affine_map<(d0, d1) -> (d0, d1)>],
+           iterator_types = ["parallel", "parallel"]}
+           ins(%8, %7 : tensor<1x40960xf32>, tensor<1xf32>)
+           outs(%2 : tensor<1x40960xf32>) {
+     ^bb0(%in: f32, %in_2: f32, %out: f32):
+       %10 = arith.divf %cst, %in_2 : f32
+       %11 = arith.mulf %in, %10 : f32
+       linalg.yield %11 : f32
+     } -> tensor<1x40960xf32>
+   return %9 : tensor<1x40960xf32>
+}
+
+
+//   CHECK-LABEL: func.func @no_rematerialize_non_dispatch
+//         CHECK:   linalg.generic
+//         CHECK:   linalg.generic