Add pass to fuse Linalg ops on tensors.

Add a pass that fuses Linalg operations on tensors within dispatch regions, and
run it on the XLA->Linalg->SPIR-V path. The pass builds on the tensor-op fusion
utility recently added to MLIR.
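
For illustration, consider an elementwise mul feeding an elementwise add (a
minimal sketch in roughly the linalg.generic tensor syntax of the time; maps,
shapes, and SSA names are illustrative, not output from an actual run). When
the producer's only use is the consumer, the two ops collapse into one and the
producer is erased:

  #map = affine_map<(d0, d1) -> (d0, d1)>
  // Before fusion: %0 is a producer whose single use is the consumer %1.
  %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
                       indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]} %arg0, %arg1 {
  ^bb0(%a: i32, %b: i32):
    %m = muli %a, %b : i32
    linalg.yield %m : i32
  }: tensor<4x8xi32>, tensor<4x8xi32> -> tensor<4x8xi32>
  %1 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
                       indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]} %0, %arg2 {
  ^bb0(%c: i32, %d: i32):
    %s = addi %c, %d : i32
    linalg.yield %s : i32
  }: tensor<4x8xi32>, tensor<4x8xi32> -> tensor<4x8xi32>

  // After fusion: a single generic computes both ops; the producer is gone.
  %1 = linalg.generic {args_in = 3 : i64, args_out = 1 : i64,
                       indexing_maps = [#map, #map, #map, #map],
                       iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
  ^bb0(%a: i32, %b: i32, %c: i32):
    %m = muli %a, %b : i32
    %s = addi %m, %c : i32
    linalg.yield %s : i32
  }: tensor<4x8xi32>, tensor<4x8xi32>, tensor<4x8xi32> -> tensor<4x8xi32>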

PiperOrigin-RevId: 295223174
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
index d787c2e..9677f61 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
@@ -20,12 +20,15 @@
 cc_library(
     name = "LinalgToSPIRV",
     srcs = [
+        "LinalgFusion.cpp",
         "LowerToSPIRV.cpp",
     ],
     hdrs = [
         "LowerToSPIRV.h",
+        "Passes.h",
     ],
     deps = [
+        "//iree/compiler/Dialect/IREE/IR",
         "//iree/compiler/Translation/XLAToLinalg:IREELinalgTensorToBuffer",
         "//iree/compiler/Utils",
         "@llvm-project//llvm:support",
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
index 00f9f3a..f420484 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
@@ -21,9 +21,12 @@
     LinalgToSPIRV
   HDRS
     "LowerToSPIRV.h"
+    "Passes.h"
   SRCS
+    "LinalgFusion.cpp"
     "LowerToSPIRV.cpp"
   DEPS
+    iree::compiler::Dialect::IREE::IR
     iree::compiler::Translation::XLAToLinalg::IREELinalgTensorToBuffer
     iree::compiler::Utils
     LLVMSupport
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgFusion.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgFusion.cpp
new file mode 100644
index 0000000..795e3ab
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgFusion.cpp
@@ -0,0 +1,85 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LinalgFusion.cpp - Fuse Linalg operations within a dispatch region--===//
+//
+// Fuses all Linalg operations within a dispatch region into a single Linalg
+// operation.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Pattern that implements the fusion. An op is fused with its producer only
+/// if the producer has a single use (this op).
+// TODO(ravishankarm): Generalize this to handle more valid fusion cases.
+struct IREEFuseGenericTensorOps : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+  PatternMatchResult matchAndRewrite(linalg::GenericOp op,
+                                     PatternRewriter &rewriter) const override;
+};
+
+/// Fuses Linalg operations on tensors within a dispatch function. For now only
+/// producer-consumer fusion is performed.
+struct IREELinalgFusionPass : public FunctionPass<IREELinalgFusionPass> {
+  void runOnFunction() override;
+};
+}  // namespace
+
+PatternMatchResult IREEFuseGenericTensorOps::matchAndRewrite(
+    linalg::GenericOp op, PatternRewriter &rewriter) const {
+  if (!op.hasTensorSemantics()) return matchFailure();
+  for (unsigned i = 0, e = op.getOperation()->getNumOperands(); i != e; ++i) {
+    auto producerOp = dyn_cast_or_null<linalg::LinalgOp>(
+        op.getOperation()->getOperand(i).getDefiningOp());
+    if (!producerOp || producerOp.getOperation()->getNumResults() != 1)
+      continue;
+    bool isDeadIfFused = producerOp.getOperation()->getResult(0).hasOneUse();
+    if (Optional<linalg::LinalgOp> fusedOp = linalg::fuseTensorOps(
+            rewriter, producerOp, cast<linalg::LinalgOp>(op.getOperation()),
+            i)) {
+      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
+      if (isDeadIfFused) rewriter.eraseOp(producerOp);
+      return matchSuccess();
+    }
+  }
+  return matchFailure();
+}
+
+void IREELinalgFusionPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  Operation *op = getOperation();
+  patterns.insert<IREEFuseGenericTensorOps>(op->getContext());
+  applyPatternsGreedily(op->getRegions(), patterns);
+}
+
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgFusionPass() {
+  return std::make_unique<IREELinalgFusionPass>();
+}
+
+static PassRegistration<IREELinalgFusionPass> pass(
+    "iree-linalg-fusion", "Fuse Linalg operations within a dispatch region");
+}  // namespace iree_compiler
+}  // namespace mlir
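
A note on how chains collapse: each application of IREEFuseGenericTensorOps
folds exactly one producer into its consumer; it is the greedy driver iterating
to a fixed point that flattens longer chains, as the pw_fusion_three test below
checks. In pseudo-IR, with attribute dictionaries and region details elided:

  // Each rewrite folds one single-use producer into its consumer...
  %t0 = linalg.generic {...} %a, %b  { muli }
  %t1 = linalg.generic {...} %t0, %c { addi }
  %t2 = linalg.generic {...} %t1, %d { subi }
  // ...until a fixed point is reached with one op computing the whole chain:
  %t2 = linalg.generic {...} %a, %b, %c, %d {
  ^bb0(%x: i32, %y: i32, %z: i32, %w: i32):
    %0 = muli %x, %y : i32
    %1 = addi %0, %z : i32
    %2 = subi %1, %w : i32
    linalg.yield %2 : i32
  }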
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
index ad38411..dfe6090 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
@@ -18,6 +18,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h"
 #include "iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h"
 #include "iree/compiler/Utils/IREECodegenUtils.h"
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
@@ -332,7 +333,7 @@
         workGroupSize[1] = 2;
         workGroupSize[2] = 2;
       }
-      // TODO(ravishankarm) : The current code-generation will "serialize" all
+      // TODO(ravishankarm): The current code-generation will "serialize" all
       // the inner loops that are more than 3 deep. We can potentially "fold"
       // all the parallel loops so that they all executed on different
       // workitems.
@@ -376,6 +377,7 @@
 
 void addLowerToSPIRVPasses(OpPassManager &pm, ArrayRef<int64_t> workGroupSize) {
   pm.addPass(xla_hlo::createLegalizeHloToLinalgPass());
+  pm.addPass(createLinalgFusionPass());
   pm.addPass(createLinalgTensorToBufferConversionPass());
   pm.addPass(std::make_unique<UpdateWorkGroupSizePass>(workGroupSize));
   addLinalgToSPIRVPasses(pm);
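
The position of the new pass matters: the fusion pattern bails out unless the
ops have tensor semantics (see the hasTensorSemantics() check above), so it
must run after HLO-to-Linalg legalization and before the tensor-to-buffer
conversion. The same ordering can be reproduced standalone with iree-opt, as
the fusion tests below do:

  iree-opt -split-input-file -hlo-legalize-to-linalg -iree-linalg-fusion <file>.mlir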
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
new file mode 100644
index 0000000..bd265d0
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h
@@ -0,0 +1,40 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Passes.h - IREE specific passes used in Linalg To SPIRV conversion--===//
+//
+// IREE specific passes used in the XLA -> Linalg -> SPIRV Conversion.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_PASSES_H
+#define IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_PASSES_H
+
+#include <memory>
+
+namespace mlir {
+
+class FuncOp;
+template <typename OpTy>
+class OpPassBase;
+
+namespace iree_compiler {
+
+/// Fuses Linalg operations on tensors within a dispatch function. For now only
+/// producer-consumer fusion is performed.
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgFusionPass();
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_PASSES_H
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
index e5f8e17..8e5e9d3 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
@@ -25,6 +25,8 @@
     name = "lit",
     srcs = glob(
         ["*.mlir"],
+        # TODO(b/149270254): Enable test after diagnosing flakiness.
+        exclude = ["pw_add_mul.mlir"],
     ),
     data = [
         "//iree/tools:IreeFileCheck",
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add_mul.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add_mul.mlir
new file mode 100644
index 0000000..803c8ce
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_add_mul.mlir
@@ -0,0 +1,11 @@
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv -iree-use-linalg-to-spirv-path -input-value="4x8xi32=[[1 2 3 4 5 6 7 8][9 10 11 12 13 14 15 16][17 18 19 20 21 22 23 24][25 26 27 28 29 30 31 32]]" -input-value="4x8xi32=[[2 4 6 8 10 12 14 16][18 20 22 24 26 28 30 32][34 36 38 40 42 44 46 48][50 52 54 56 58 60 62 64]]" -input-value="4x8xi32=[[3 6 9 12 15 18 21 24][27 30 33 36 39 42 45 48][51 54 57 60 63 66 69 72][75 78 81 84 87 90 93 96]]" %s | IreeFileCheck %s)
+
+// CHECK-LABEL: EXEC @pw_add_mul
+// CHECK: 4x8xi32=[5 14 27 44 65 90 119 152][189 230 275 324 377 434 495 560][629 702 779 860 945 1034 1127 1224][1325 1430 1539 1652 1769 1890 2015 2144]
+module {
+  func @pw_add_mul(%arg0: tensor<4x8xi32>, %arg1: tensor<4x8xi32>, %arg2 : tensor<4x8xi32>) -> tensor<4x8xi32> {
+    %0 = "xla_hlo.mul"(%arg0, %arg1) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+    %1 = "xla_hlo.add"(%0, %arg2) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+    return %1 : tensor<4x8xi32>
+  }
+}
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_linalg_fusion.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_linalg_fusion.mlir
new file mode 100644
index 0000000..7cd6de6
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/pw_linalg_fusion.mlir
@@ -0,0 +1,98 @@
+// RUN: iree-opt -split-input-file -hlo-legalize-to-linalg -iree-linalg-fusion %s | IreeFileCheck %s
+
+// CHECK-LABEL: @pw_fusion_two
+func @pw_fusion_two(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>, %arg2 : memref<4x8xi32>, %arg3: memref<4x8xi32>)
+attributes { iree.executable.export, iree.executable.workgroup_size = dense<[32, 8, 1]> : tensor<3xi32>, iree.executable.workload = dense<[8, 4, 1]> : tensor<3xi32>} {
+  %0 = iree.load_input(%arg0 : memref<4x8xi32>) : tensor<4x8xi32>
+  %1 = iree.load_input(%arg1 : memref<4x8xi32>) : tensor<4x8xi32>
+  %2 = iree.load_input(%arg2 : memref<4x8xi32>) : tensor<4x8xi32>
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-zA-Z0-9_]*}}
+  // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: i32
+  // CHECK: [[TEMP:%[a-zA-Z0-9_]*]] = muli [[ARG0]], [[ARG1]]
+  // CHECK: addi [[TEMP]], [[ARG2]]
+  // CHECK-NOT: linalg.generic
+  %4 = "xla_hlo.mul"(%0, %1) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  %5 = "xla_hlo.add"(%4, %2) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  iree.store_output(%5 : tensor<4x8xi32>, %arg3 : memref<4x8xi32>)
+  iree.return
+}
+
+// -----
+
+// CHECK-LABEL: @pw_fusion_three
+func @pw_fusion_three(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>, %arg2 : memref<4x8xi32>, %arg3: memref<4x8xi32>, %arg4: memref<4x8xi32>)
+attributes { iree.executable.export, iree.executable.workgroup_size = dense<[32, 8, 1]> : tensor<3xi32>, iree.executable.workload = dense<[8, 4, 1]> : tensor<3xi32>} {
+  %0 = iree.load_input(%arg0 : memref<4x8xi32>) : tensor<4x8xi32>
+  %1 = iree.load_input(%arg1 : memref<4x8xi32>) : tensor<4x8xi32>
+  %2 = iree.load_input(%arg2 : memref<4x8xi32>) : tensor<4x8xi32>
+  %3 = iree.load_input(%arg3 : memref<4x8xi32>) : tensor<4x8xi32>
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-zA-Z0-9_]*}}
+  // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG3:%[a-zA-Z0-9_]*]]: i32
+  // CHECK: [[TEMP1:%[a-zA-Z0-9_]*]] = muli [[ARG0]], [[ARG1]]
+  // CHECK: [[TEMP2:%[a-zA-Z0-9_]*]] = addi [[TEMP1]], [[ARG2]]
+  // CHECK: subi [[TEMP2]], [[ARG3]]
+  // CHECK-NOT: linalg.generic
+  %4 = "xla_hlo.mul"(%0, %1) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  %5 = "xla_hlo.add"(%4, %2) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  %6 = "xla_hlo.sub"(%5, %3) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  iree.store_output(%6 : tensor<4x8xi32>, %arg4 : memref<4x8xi32>)
+  iree.return
+}
+
+// -----
+
+// CHECK-LABEL: @pw_fusion_dag
+func @pw_fusion_dag(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>, %arg2 : memref<4x8xi32>, %arg3: memref<4x8xi32>, %arg4: memref<4x8xi32>)
+attributes { iree.executable.export, iree.executable.workgroup_size = dense<[32, 8, 1]> : tensor<3xi32>, iree.executable.workload = dense<[8, 4, 1]> : tensor<3xi32>} {
+  %0 = iree.load_input(%arg0 : memref<4x8xi32>) : tensor<4x8xi32>
+  %1 = iree.load_input(%arg1 : memref<4x8xi32>) : tensor<4x8xi32>
+  %2 = iree.load_input(%arg2 : memref<4x8xi32>) : tensor<4x8xi32>
+  %3 = iree.load_input(%arg3 : memref<4x8xi32>) : tensor<4x8xi32>
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-zA-Z0-9_]*}}
+  // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG3:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-DAG: [[TEMP1:%[a-zA-Z0-9_]*]] = muli [[ARG0]], [[ARG1]]
+  // CHECK-DAG: [[TEMP2:%[a-zA-Z0-9_]*]] = addi [[ARG2]], [[ARG3]]
+  // CHECK: subi [[TEMP1]], [[TEMP2]]
+  // CHECK-NOT: linalg.generic
+  %4 = "xla_hlo.mul"(%0, %1) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  %5 = "xla_hlo.add"(%2, %3) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  %6 = "xla_hlo.sub"(%4, %5) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  iree.store_output(%6 : tensor<4x8xi32>, %arg4 : memref<4x8xi32>)
+  iree.return
+}
+
+// -----
+
+// CHECK-LABEL: @pw_fusion_dag2
+func @pw_fusion_dag2(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>, %arg2 : memref<4x8xi32>, %arg3: memref<4x8xi32>)
+attributes { iree.executable.export, iree.executable.workgroup_size = dense<[32, 8, 1]> : tensor<3xi32>, iree.executable.workload = dense<[8, 4, 1]> : tensor<3xi32>} {
+  %0 = iree.load_input(%arg0 : memref<4x8xi32>) : tensor<4x8xi32>
+  %1 = iree.load_input(%arg1 : memref<4x8xi32>) : tensor<4x8xi32>
+  %2 = iree.load_input(%arg2 : memref<4x8xi32>) : tensor<4x8xi32>
+  // CHECK: linalg.generic
+  // CHECK: ^{{[a-zA-Z0-9_]*}}
+  // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-SAME: [[ARG3:%[a-zA-Z0-9_]*]]: i32
+  // CHECK-DAG: [[TEMP1:%[a-zA-Z0-9_]*]] = muli [[ARG0]], [[ARG1]]
+  // CHECK-DAG: [[TEMP2:%[a-zA-Z0-9_]*]] = addi [[ARG2]], [[ARG3]]
+  // CHECK: subi [[TEMP1]], [[TEMP2]]
+  // CHECK-NOT: linalg.generic
+  %3 = "xla_hlo.mul"(%0, %1) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  %4 = "xla_hlo.add"(%0, %2) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  %5 = "xla_hlo.sub"(%3, %4) : (tensor<4x8xi32>, tensor<4x8xi32>) -> tensor<4x8xi32>
+  iree.store_output(%5 : tensor<4x8xi32>, %arg3 : memref<4x8xi32>)
+  iree.return
+}