Add a pass to canonicalize ops used as loop carried values by scf.for (#3616)
This is an ad hoc pass to clean up some unnecessary operations added
during progressive lowering. MLIR canonicalization currently cannot
clean them up, so this is a temporary solution that allows those ops
to be removed on the IREE side.
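
As a sketch of the rewrite (mirroring the example documented in the new
ForOpCanonicalization.cpp), a loop that round-trips its carried value
through casts that cancel each other out:

  %0 = vector.shape_cast %init : vector<4xf32> to vector<1x4xf32>
  %1 = scf.for %i = %c0 to %n step %c1 iter_args(%arg = %0) -> vector<1x4xf32> {
    %2 = vector.shape_cast %arg : vector<1x4xf32> to vector<4xf32>
    // ... compute %3 from %2 as a vector<4xf32> ...
    %4 = vector.shape_cast %3 : vector<4xf32> to vector<1x4xf32>
    scf.yield %4 : vector<1x4xf32>
  }
  %5 = vector.shape_cast %1 : vector<1x4xf32> to vector<4xf32>

is rewritten so that the loop carries the vector<4xf32> value directly and
the cancelling casts are folded away (the SSA names and bounds above are
illustrative).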
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index 067383d..4b58a54 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -23,12 +23,14 @@
cc_library(
name = "CodegenUtils",
srcs = [
+ "ForOpCanonicalization.cpp",
"FunctionUtils.cpp",
"GetNumWorkgroups.cpp",
"MarkerUtils.cpp",
"MatmulCodegenStrategy.cpp",
],
hdrs = [
+ "ForOpCanonicalization.h",
"FunctionUtils.h",
"GetNumWorkgroups.h",
"MarkerUtils.h",
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index 86fe0c7..bf8ddd7 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -18,11 +18,13 @@
NAME
CodegenUtils
HDRS
+ "ForOpCanonicalization.h"
"FunctionUtils.h"
"GetNumWorkgroups.h"
"MarkerUtils.h"
"MatmulCodegenStrategy.h"
SRCS
+ "ForOpCanonicalization.cpp"
"FunctionUtils.cpp"
"GetNumWorkgroups.cpp"
"MarkerUtils.cpp"
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
new file mode 100644
index 0000000..c7f2134
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
@@ -0,0 +1,155 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+// Pass to combine instructions across scf.for boundaries. It is common when
+// doing incremental lowering to generate transient ops that cancel each other
+// out. Canonicalization usually cleans up such operations, but when the value
+// is loop carried, MLIR canonicalization currently doesn't remove the
+// redundant operations.
+//
+// This pass works around that MLIR limitation and does ad hoc cleanup of
+// patterns found in IREE. Once a more general mechanism exists in MLIR, this
+// pass can be removed entirely.
+// This pass does this kind of transformation:
+// ```
+// %21 = vector.shape_cast %20 : vector<4xf32> to vector<1x4xf32>
+// %22 = scf.for %arg3 = %c0 to %c4096 step %c4 iter_args(%arg4 = %21)
+// -> vector<1x4xf32> {
+// [...]
+// %100 = vector.shape_cast %arg4 : vector<1x4xf32> to vector<4xf32>
+// [...]
+// %109 = vector.shape_cast %108 : vector<4xf32> to vector<1x4xf32>
+// scf.yield %109 : vector<1x4xf32>
+// }
+// %24 = vector.shape_cast %22 : vector<1x4xf32> to vector<4xf32>
+// ```
+// ->
+// ```
+// %22 = scf.for %arg3 = %c0 to %c4096 step %c4 iter_args(%arg4 = %20)
+// -> vector<4xf32> {
+// [...]
+// scf.yield %108 : vector<4xf32>
+// }
+// ```
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+class ForOpArgFolding final : public OpRewritePattern<scf::ForOp> {
+ public:
+ using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+ Value FoldCarryDep(scf::ForOp forOp, Operation* ivUser,
+ Operation* ivDef) const {
+ if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(ivUser)) {
+      if (auto sourceOp = dyn_cast<vector::ShapeCastOp>(ivDef)) {
+        if (shapeCast.getType() == sourceOp.source().getType())
+          return sourceOp.source();
+ }
+ } else if (auto extractOp = dyn_cast<vector::ExtractOp>(ivUser)) {
+ if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(ivDef)) {
+ if (extractOp.getType() == broadcastOp.getSourceType())
+ return broadcastOp.source();
+ }
+ }
+ return Value();
+ }
+
+ void transferBody(Block* source, Block* dest, ArrayRef<Value> results,
+ PatternRewriter& rewriter) const {
+ // Move all operations to the destination block.
+ rewriter.mergeBlocks(source, dest, dest->getArguments());
+ // Replace the yield op by one that returns only the used values.
+ auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
+ rewriter.updateRootInPlace(
+ yieldOp, [&]() { yieldOp.getOperation()->setOperands(results); });
+ }
+
+ LogicalResult matchAndRewrite(scf::ForOp forOp,
+ PatternRewriter& rewriter) const override {
+ SmallVector<unsigned, 8> iteratorFolded;
+ SmallVector<Operation*, 8> resultOps;
+ auto terminator = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ auto returnValues = llvm::to_vector<8>(terminator.getOperands());
+ auto initArgs = llvm::to_vector<8>(forOp.getIterOperands());
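+    // For each loop-carried value, look for a single in-body use
+    // (vector.shape_cast or vector.extract) that is undone by the op defining
+    // the corresponding yielded value. When found, carry the narrower value
+    // instead and apply the original op to the init operand outside the loop.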
+ for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
+ if (!it.value().hasOneUse()) continue;
+ Operation* op = it.value().use_begin()->getOwner();
+ if (!isa<vector::ShapeCastOp, vector::ExtractOp>(op)) continue;
+      Operation* returnValDef = returnValues[it.index()].getDefiningOp();
+      // The yielded value may be a block argument with no defining op.
+      if (!returnValDef) continue;
+      Value newReturn = FoldCarryDep(forOp, op, returnValDef);
+      if (!newReturn) continue;
+ iteratorFolded.push_back(it.index());
+ resultOps.push_back(returnValDef);
+ returnValues[it.index()] = newReturn;
+
+ BlockAndValueMapping mapping;
+ mapping.map(op->getOperand(0), initArgs[it.index()]);
+ initArgs[it.index()] = rewriter.clone(*op, mapping)->getResult(0);
+ }
+    if (iteratorFolded.empty()) return failure();
+ auto newLoop =
+ rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.lowerBound(),
+ forOp.upperBound(), forOp.step(), initArgs);
+ transferBody(forOp.getBody(), newLoop.getBody(), returnValues, rewriter);
+
+ // Replace the operation by the new one.
+ SmallVector<Value, 8> repResults(newLoop.getResults().begin(),
+ newLoop.getResults().end());
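+    // For each folded result, recreate the cancelled op after the loop so
+    // external users keep seeing the original type, and replace uses of the
+    // in-loop cast with the (already narrowed) new iter arg; the dead in-loop
+    // op is cleaned up by later canonicalization.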
+    for (auto en : llvm::enumerate(iteratorFolded)) {
+      BlockAndValueMapping mapping;
+      mapping.map(returnValues[en.value()], newLoop.getResult(en.value()));
+      repResults[en.value()] =
+          rewriter.clone(*resultOps[en.index()], mapping)->getResult(0);
+      Operation* oldOp =
+          newLoop.getRegionIterArgs()[en.value()].use_begin()->getOwner();
+      SmallVector<Value, 1> arg(1, newLoop.getRegionIterArgs()[en.value()]);
+      oldOp->replaceAllUsesWith(arg);
+    }
+ rewriter.replaceOp(forOp, repResults);
+ return success();
+ }
+};
+
+struct ForOpCanonicalizationPass
+ : PassWrapper<ForOpCanonicalizationPass, FunctionPass> {
+ void runOnFunction() override {
+ FuncOp fn = getFunction();
+ OwningRewritePatternList patterns;
+ patterns.insert<ForOpArgFolding>(fn.getContext());
+ applyPatternsAndFoldGreedily(fn, patterns);
+ }
+};
+} // namespace
+
+std::unique_ptr<FunctionPass> createForOpCanonicalizationPass() {
+ return std::make_unique<ForOpCanonicalizationPass>();
+}
+
+static PassRegistration<ForOpCanonicalizationPass> pass(
+ "iree-codegen-canonicalize-scf-for",
+ "An ad-hoc pass to canonicalize selected loop carried dependencies on "
+ "scf.for",
+ [] { return std::make_unique<ForOpCanonicalizationPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
new file mode 100644
index 0000000..ddda183
--- /dev/null
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
@@ -0,0 +1,31 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_FOROPCANONICALIZATION_H_
+#define IREE_COMPILER_CONVERSION_CODEGENUTILS_FOROPCANONICALIZATION_H_
+
+#include <memory>
+
+namespace mlir {
+class FunctionPass;
+namespace iree_compiler {
+
+/// An ad-hoc pass to canonicalize selected loop carried dependencies on
+/// scf.for.
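+///
+/// The pass is registered under the flag "iree-codegen-canonicalize-scf-for".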
+std::unique_ptr<FunctionPass> createForOpCanonicalizationPass();
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_CODEGENUTILS_FOROPCANONICALIZATION_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/forop_canonicalization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/forop_canonicalization.mlir
new file mode 100644
index 0000000..b44b89a
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/forop_canonicalization.mlir
@@ -0,0 +1,55 @@
+// RUN: iree-opt %s -iree-codegen-canonicalize-scf-for | FileCheck %s
+
+func @loop_carried_cast(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ %0 = vector.shape_cast %arg0 : vector<4xf32> to vector<1x4xf32>
+ %1 = vector.shape_cast %arg1 : vector<4xf32> to vector<1x4xf32>
+ %20:2 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %0, %arg5 = %1) -> (vector<1x4xf32>, vector<1x4xf32>) {
+ %a = vector.shape_cast %arg4 : vector<1x4xf32> to vector<4xf32>
+ %b = vector.shape_cast %arg5 : vector<1x4xf32> to vector<4xf32>
+ %c = addf %a, %b : vector<4xf32>
+ %d = mulf %a, %b : vector<4xf32>
+ %cc = vector.shape_cast %c : vector<4xf32> to vector<1x4xf32>
+ %dc = vector.shape_cast %d : vector<4xf32> to vector<1x4xf32>
+ scf.yield %cc, %dc : vector<1x4xf32>, vector<1x4xf32>
+ }
+ %21 = vector.shape_cast %20#0 : vector<1x4xf32> to vector<4xf32>
+ %22 = vector.shape_cast %20#1 : vector<1x4xf32> to vector<4xf32>
+ return %21, %22 : vector<4xf32>, vector<4xf32>
+}
+
+// CHECK-LABEL: func @loop_carried_cast
+// CHECK-NOT: vector.shape_cast
+// CHECK: scf.for {{.*}} -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-NOT: vector.shape_cast
+// CHECK: scf.yield {{.*}}, {{.*}} : vector<4xf32>, vector<4xf32>
+// CHECK: }
+// CHECK-NOT: vector.shape_cast
+// CHECK: return {{.*}}, {{.*}} : vector<4xf32>, vector<4xf32>
+
+func @loop_carried_extract(%arg0: f32) -> f32 {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
+ %20 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %0) -> (vector<4xf32>) {
+ %a = vector.extract %arg4[0] : vector<4xf32>
+ %c = addf %a, %a : f32
+ %bc = vector.broadcast %c : f32 to vector<4xf32>
+ scf.yield %bc : vector<4xf32>
+ }
+ %21 = vector.extract %20[0] : vector<4xf32>
+ return %21 : f32
+}
+
+// CHECK-LABEL: func @loop_carried_extract
+// CHECK-NOT: vector.broadcast
+// CHECK: scf.for {{.*}} -> (f32) {
+// CHECK-NOT: vector.extract
+// CHECK-NOT: vector.broadcast
+// CHECK: scf.yield {{.*}} : f32
+// CHECK: }
+// CHECK-NOT: vector.extract
+// CHECK: return {{.*}} : f32
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 12315db..4d4c53c 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -15,6 +15,7 @@
#ifndef IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
#define IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
+#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
@@ -50,6 +51,7 @@
createSplitDispatchFunctionPass();
createVectorToGPUPass();
createMatMulTileAndVectorizeGPUPass();
+ createForOpCanonicalizationPass();
createVectorizeMemref();
return true;
}();