Add a pass to canonicalize ops used as loop carried values by scf.for (#3616)

This is an ad hoc pass that cleans up unnecessary operations introduced
during progressive lowering. MLIR canonicalization currently cannot
remove these ops when they are used as loop carried values, so this is a
temporary solution that removes them on the IREE side.
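As a simplified sketch (adapted from the example documented in the new
pass's header comment; the shapes and value names are illustrative), the
pass rewrites IR like:

```mlir
// Before: the loop-carried value is round-tripped through vector.shape_cast.
%init = vector.shape_cast %v : vector<4xf32> to vector<1x4xf32>
%res = scf.for %i = %c0 to %c4096 step %c4 iter_args(%arg = %init) -> (vector<1x4xf32>) {
  %a = vector.shape_cast %arg : vector<1x4xf32> to vector<4xf32>
  // ... compute %b : vector<4xf32> from %a ...
  %y = vector.shape_cast %b : vector<4xf32> to vector<1x4xf32>
  scf.yield %y : vector<1x4xf32>
}
%out = vector.shape_cast %res : vector<1x4xf32> to vector<4xf32>

// After: the casts are folded away and the loop carries vector<4xf32> directly.
%res = scf.for %i = %c0 to %c4096 step %c4 iter_args(%arg = %v) -> (vector<4xf32>) {
  // ... compute %b : vector<4xf32> from %arg ...
  scf.yield %b : vector<4xf32>
}
```

The same folding applies to a vector.extract whose loop-carried operand is
produced by a matching vector.broadcast, as exercised by the new test.
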
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index 067383d..4b58a54 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -23,12 +23,14 @@
 cc_library(
     name = "CodegenUtils",
     srcs = [
+        "ForOpCanonicalization.cpp",
         "FunctionUtils.cpp",
         "GetNumWorkgroups.cpp",
         "MarkerUtils.cpp",
         "MatmulCodegenStrategy.cpp",
     ],
     hdrs = [
+        "ForOpCanonicalization.h",
         "FunctionUtils.h",
         "GetNumWorkgroups.h",
         "MarkerUtils.h",
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index 86fe0c7..bf8ddd7 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -18,11 +18,13 @@
   NAME
     CodegenUtils
   HDRS
+    "ForOpCanonicalization.h"
     "FunctionUtils.h"
     "GetNumWorkgroups.h"
     "MarkerUtils.h"
     "MatmulCodegenStrategy.h"
   SRCS
+    "ForOpCanonicalization.cpp"
     "FunctionUtils.cpp"
     "GetNumWorkgroups.cpp"
     "MarkerUtils.cpp"
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
new file mode 100644
index 0000000..c7f2134
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
@@ -0,0 +1,155 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+// Pass to fold redundant instructions across scf.for boundaries. When doing
+// incremental lowering it is common to generate transient ops that cancel
+// each other out. Canonicalization usually cleans up those operations, but
+// when the value is loop carried, MLIR canonicalization currently does not
+// remove the redundant operations.
+//
+// This pass works around that MLIR limitation and does ad hoc cleanup of
+// instruction patterns found in IREE. Once MLIR has a more general
+// mechanism, this pass can be removed entirely.
+// This pass does this kind of transformation:
+// ```
+// %21 = vector.shape_cast %20 : vector<4xf32> to vector<1x4xf32>
+// %22 = scf.for %arg3 = %c0 to %c4096 step %c4 iter_args(%arg4 = %21)
+//    -> vector<1x4xf32> {
+//    [...]
+//    %100 = vector.shape_cast %arg4 : vector<1x4xf32> to vector<4xf32>
+//    [...]
+//    %109 = vector.shape_cast %108 : vector<4xf32> to vector<1x4xf32>
+//    scf.yield %109 : vector<1x4xf32>
+//  }
+//  %24 = vector.shape_cast %22 : vector<1x4xf32> to vector<4xf32>
+// ```
+// ->
+// ```
+// %22 = scf.for %arg3 = %c0 to %c4096 step %c4 iter_args(%arg4 = %20)
+//    -> vector<4xf32> {
+//    [...]
+//    scf.yield %108 : vector<4xf32>
+//  }
+// ```
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+class ForOpArgFolding final : public OpRewritePattern<scf::ForOp> {
+ public:
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  Value FoldCarryDep(scf::ForOp forOp, Operation* ivUser,
+                     Operation* ivDef) const {
+    if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(ivUser)) {
+      if (auto sourceOp = dyn_cast_or_null<vector::ShapeCastOp>(ivDef)) {
+        if (shapeCast.getType() == sourceOp.source().getType())
+          return sourceOp.source();
+      }
+    } else if (auto extractOp = dyn_cast<vector::ExtractOp>(ivUser)) {
+      if (auto broadcastOp = dyn_cast_or_null<vector::BroadcastOp>(ivDef)) {
+        if (extractOp.getType() == broadcastOp.getSourceType())
+          return broadcastOp.source();
+      }
+    }
+    return Value();
+  }
+
+  void transferBody(Block* source, Block* dest, ArrayRef<Value> results,
+                    PatternRewriter& rewriter) const {
+    // Move all operations to the destination block.
+    rewriter.mergeBlocks(source, dest, dest->getArguments());
+    // Update the terminator in place so it yields the new return values.
+    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
+    rewriter.updateRootInPlace(
+        yieldOp, [&]() { yieldOp.getOperation()->setOperands(results); });
+  }
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter& rewriter) const override {
+    SmallVector<unsigned, 8> iteratorFolded;
+    SmallVector<Operation*, 8> resultOps;
+    auto terminator = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    auto returnValues = llvm::to_vector<8>(terminator.getOperands());
+    auto initArgs = llvm::to_vector<8>(forOp.getIterOperands());
+    for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
+      if (!it.value().hasOneUse()) continue;
+      Operation* op = it.value().use_begin()->getOwner();
+      if (!isa<vector::ShapeCastOp, vector::ExtractOp>(op)) continue;
+      Operation* returnValDef = returnValues[it.index()].getDefiningOp();
+      Value newReturn = FoldCarryDep(forOp, op, returnValDef);
+      if (!newReturn) continue;
+      iteratorFolded.push_back(it.index());
+      resultOps.push_back(returnValDef);
+      returnValues[it.index()] = newReturn;
+
+      BlockAndValueMapping mapping;
+      mapping.map(op->getOperand(0), initArgs[it.index()]);
+      initArgs[it.index()] = rewriter.clone(*op, mapping)->getResult(0);
+    }
+    if (iteratorFolded.empty()) return failure();
+    auto newLoop =
+        rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.lowerBound(),
+                                    forOp.upperBound(), forOp.step(), initArgs);
+    transferBody(forOp.getBody(), newLoop.getBody(), returnValues, rewriter);
+
+    // Replace the old loop's results, recreating the folded ops where needed.
+    SmallVector<Value, 8> repResults(newLoop.getResults().begin(),
+                                     newLoop.getResults().end());
+    for (auto en : llvm::enumerate(iteratorFolded)) {
+      BlockAndValueMapping mapping;
+      mapping.map(returnValues[en.value()], newLoop.getResult(en.value()));
+      repResults[en.value()] =
+          rewriter.clone(*resultOps[en.index()], mapping)->getResult(0);
+      Operation* oldOp =
+          newLoop.getRegionIterArgs()[en.value()].use_begin()->getOwner();
+      SmallVector<Value, 1> arg(1, newLoop.getRegionIterArgs()[en.value()]);
+      oldOp->replaceAllUsesWith(arg);
+    }
+    rewriter.replaceOp(forOp, repResults);
+    return success();
+  }
+};
+
+struct ForOpCanonicalizationPass
+    : PassWrapper<ForOpCanonicalizationPass, FunctionPass> {
+  void runOnFunction() override {
+    FuncOp fn = getFunction();
+    OwningRewritePatternList patterns;
+    patterns.insert<ForOpArgFolding>(fn.getContext());
+    applyPatternsAndFoldGreedily(fn, patterns);
+  }
+};
+}  // namespace
+
+std::unique_ptr<FunctionPass> createForOpCanonicalizationPass() {
+  return std::make_unique<ForOpCanonicalizationPass>();
+}
+
+static PassRegistration<ForOpCanonicalizationPass> pass(
+    "iree-codegen-canonicalize-scf-for",
+    "An ad-hoc pass to canonicalize selected loop carried dependencies on "
+    "scf.for",
+    [] { return std::make_unique<ForOpCanonicalizationPass>(); });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
new file mode 100644
index 0000000..ddda183
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
@@ -0,0 +1,31 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_FOROPCANONICALIZATION_H_
+#define IREE_COMPILER_CONVERSION_CODEGENUTILS_FOROPCANONICALIZATION_H_
+
+#include <memory>
+
+namespace mlir {
+class FunctionPass;
+namespace iree_compiler {
+
+/// An ad-hoc pass to canonicalize selected loop carried dependencies on
+/// scf.for.
+std::unique_ptr<FunctionPass> createForOpCanonicalizationPass();
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CONVERSION_CODEGENUTILS_FOROPCANONICALIZATION_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/forop_canonicalization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/forop_canonicalization.mlir
new file mode 100644
index 0000000..b44b89a
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/forop_canonicalization.mlir
@@ -0,0 +1,55 @@
+// RUN: iree-opt %s -iree-codegen-canonicalize-scf-for | FileCheck %s
+
+func @loop_carried_cast(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0 = vector.shape_cast %arg0 : vector<4xf32> to vector<1x4xf32>
+  %1 = vector.shape_cast %arg1 : vector<4xf32> to vector<1x4xf32>
+  %20:2 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %0, %arg5 = %1) -> (vector<1x4xf32>, vector<1x4xf32>) {
+    %a = vector.shape_cast %arg4 : vector<1x4xf32> to vector<4xf32>
+    %b = vector.shape_cast %arg5 : vector<1x4xf32> to vector<4xf32>
+    %c = addf %a, %b : vector<4xf32>
+    %d = mulf %a, %b : vector<4xf32>
+    %cc = vector.shape_cast %c : vector<4xf32> to vector<1x4xf32>
+    %dc = vector.shape_cast %d : vector<4xf32> to vector<1x4xf32>
+    scf.yield %cc, %dc : vector<1x4xf32>, vector<1x4xf32>
+  }
+  %21 = vector.shape_cast %20#0 : vector<1x4xf32> to vector<4xf32>
+  %22 = vector.shape_cast %20#1 : vector<1x4xf32> to vector<4xf32>
+  return %21, %22 : vector<4xf32>, vector<4xf32>
+}
+
+// CHECK-LABEL:   func @loop_carried_cast
+//   CHECK-NOT:     vector.shape_cast
+//       CHECK:     scf.for {{.*}} -> (vector<4xf32>, vector<4xf32>) {
+//   CHECK-NOT:       vector.shape_cast
+//       CHECK:       scf.yield {{.*}}, {{.*}} : vector<4xf32>, vector<4xf32>
+//       CHECK:     }
+//   CHECK-NOT:     vector.shape_cast
+//       CHECK:     return {{.*}}, {{.*}} : vector<4xf32>, vector<4xf32>
+
+func @loop_carried_extract(%arg0: f32) -> f32 {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
+  %20 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %0) -> (vector<4xf32>) {
+    %a = vector.extract %arg4[0] : vector<4xf32>
+    %c = addf %a, %a : f32
+    %bc = vector.broadcast %c : f32 to vector<4xf32>
+    scf.yield %bc : vector<4xf32>
+  }
+  %21 = vector.extract %20[0] : vector<4xf32>
+  return %21 : f32
+}
+
+// CHECK-LABEL:   func @loop_carried_extract
+//   CHECK-NOT:     vector.broadcast
+//       CHECK:     scf.for {{.*}} -> (f32) {
+//   CHECK-NOT:       vector.extract
+//   CHECK-NOT:       vector.broadcast
+//       CHECK:       scf.yield {{.*}} : f32
+//       CHECK:     }
+//   CHECK-NOT:     vector.extract
+//       CHECK:     return {{.*}} : f32
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 12315db..4d4c53c 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -15,6 +15,7 @@
 #ifndef IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
 #define IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
 
+#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
 #include "iree/compiler/Conversion/HLOToHLO/Passes.h"
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
@@ -50,6 +51,7 @@
     createSplitDispatchFunctionPass();
     createVectorToGPUPass();
     createMatMulTileAndVectorizeGPUPass();
+    createForOpCanonicalizationPass();
     createVectorizeMemref();
     return true;
   }();