Create new pass PromoteTensorLoads. (#6023)
* Create new pass PromoteTensorLoads.
* In the new input pipeline, the only part of
PrePartitioningConversionPass which survives is the
tensor.extract_element -> flow.tensor.load conversion.
* This conversion actually needs to be done at a couple of points during
lowering, first promoting any extracts that are introduced as part of
control flow (in the input pipeline), then allowing most of the program
to be loaded onto the device, and finally converting any remaining,
otherwise unrecognized extract elements.
* As such, I opted to make it a very specific pass that does exactly
what it says on the label. We may want to do something more
sophisticated later, and at least having one thing to see and replace
will help.
* Once the new input pipeline lands, PrePostPartitioningConversion.cpp
will be deleted.
diff --git a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp
index d41a5eb..72b5d3c 100644
--- a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp
@@ -50,13 +50,13 @@
} // namespace
-void setupDirectStandardToFlowLegality(MLIRContext *context,
- ConversionTarget &conversionTarget) {
+void setupStandardToFlowTensorLoadLegality(MLIRContext *context,
+ ConversionTarget &conversionTarget) {
conversionTarget.addIllegalOp<tensor::ExtractOp>();
}
-void populateStandardToFlowPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
+void populateStandardToFlowTensorLoadPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<ExtractElementOpLowering>(context);
}
diff --git a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h
index 81bb4a1..8243091 100644
--- a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h
+++ b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h
@@ -13,16 +13,16 @@
namespace mlir {
namespace iree_compiler {
-// Setup the |conversionTarget| op legality for early-phase direct-to-flow
-// conversion from the standard op dialect. This will make certain ops illegal
-// that we know we have good patterns for such that we can be sure we catch them
-// before they are outlined into dispatch regions.
-void setupDirectStandardToFlowLegality(MLIRContext *context,
- ConversionTarget &conversionTarget);
+// Setup the |conversionTarget| op legality for conversion of standard ops
+// which should be mapped to flow.tensor.load. This is maintained as a very
+// specific legalization because flow.tensor.load represents a kind of host
+// read-back and should be materialized at specific points.
+void setupStandardToFlowTensorLoadLegality(MLIRContext *context,
+ ConversionTarget &conversionTarget);
-// Appends all patterns for converting std ops to flow ops.
-void populateStandardToFlowPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+// Appends all patterns for converting to flow.tensor.load.
+void populateStandardToFlowTensorLoadPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns);
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 0a1541b..a1e2a10 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -47,6 +47,7 @@
"PassDetail.h",
"Passes.cpp",
"PrePostPartitioningConversion.cpp",
+ "PromoteTensorLoads.cpp",
"StripAndSplatConstantVariables.cpp",
"VerifyCompilerInputLegality.cpp",
],
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 8f278df..7c10c2f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -43,6 +43,7 @@
"PassDetail.h"
"Passes.cpp"
"PrePostPartitioningConversion.cpp"
+ "PromoteTensorLoads.cpp"
"StripAndSplatConstantVariables.cpp"
"VerifyCompilerInputLegality.cpp"
DEPS
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 4eca5b0..53397bb 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -77,6 +77,13 @@
// benefit. Other ops are left unmodified and will be outlined later on.
std::unique_ptr<OperationPass<FuncOp>> createPrePartitioningConversionPass();
+// Converts standard ops which match to flow.tensor.load (typically causing a
+// read-back).
+// Note that there are typically very specific phase ordering issues with
+// performing such a conversion, so even though it is of fine granularity,
+// this is maintained separately.
+std::unique_ptr<OperationPass<FuncOp>> createPromoteTensorLoadsPass();
+
// Expands dynamic !shapex.ranked_shape dimensions in variables.
std::unique_ptr<OperationPass<ModuleOp>> createExpandVariableDynamicDimsPass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index c09351a..ba9dbd0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -82,6 +82,12 @@
let constructor = "mlir::iree_compiler::IREE::Flow::createOutlineLargeConstantsPass(25)";
}
+def PromoteTensorLoads :
+ Pass<"iree-flow-promote-tensor-loads", "FuncOp"> {
+ let summary = "Converts standard ops which match to flow.tensor.load (typically causing a read-back)";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createPromoteTensorLoadsPass()";
+}
+
def PrePartitioningConversion :
Pass<"iree-flow-pre-partitioning-conversion", "FuncOp"> {
let summary = "Dialect conversion prior to partitioning";
diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
index 8df4e44..2b443cf 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
@@ -71,12 +71,12 @@
// For example, DynamicUpdateSlice should end up as a stream operation.
setupDirectHLOToFlowLegality(context, conversionTarget);
populateHLOToFlowPatterns(context, conversionPatterns);
- setupDirectStandardToFlowLegality(context, conversionTarget);
+ setupStandardToFlowTensorLoadLegality(context, conversionTarget);
conversionTarget.addLegalOp<linalg::GenericOp, linalg::IndexedGenericOp>();
conversionTarget
.markOpRecursivelyLegal<linalg::GenericOp, linalg::IndexedGenericOp>();
- populateStandardToFlowPatterns(context, conversionPatterns);
+ populateStandardToFlowTensorLoadPatterns(context, conversionPatterns);
if (failed(applyPartialConversion(getOperation(), conversionTarget,
std::move(conversionPatterns)))) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/PromoteTensorLoads.cpp b/iree/compiler/Dialect/Flow/Transforms/PromoteTensorLoads.cpp
new file mode 100644
index 0000000..e8c0f00
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/PromoteTensorLoads.cpp
@@ -0,0 +1,51 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+class PromoteTensorLoadsPass
+ : public PromoteTensorLoadsBase<PromoteTensorLoadsPass> {
+ public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<FlowDialect, StandardOpsDialect, tensor::TensorDialect>();
+ }
+
+ void runOnOperation() override {
+ auto *context = &getContext();
+ ConversionTarget conversionTarget(*context);
+ OwningRewritePatternList conversionPatterns(&getContext());
+
+ conversionTarget.addLegalDialect<IREE::Flow::FlowDialect>();
+ conversionTarget.addLegalDialect<StandardOpsDialect>();
+ setupStandardToFlowTensorLoadLegality(context, conversionTarget);
+ populateStandardToFlowTensorLoadPatterns(context, conversionPatterns);
+
+ if (failed(applyPartialConversion(getOperation(), conversionTarget,
+ std::move(conversionPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<FuncOp>> createPromoteTensorLoadsPass() {
+ return std::make_unique<PromoteTensorLoadsPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index c07d81b..faa51da 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -34,6 +34,7 @@
"outline_dispatch_regions.mlir",
"outline_large_constants.mlir",
"pre_partitioning_conversion.mlir",
+ "promote_tensor_loads.mlir",
"strip_and_splat_constant_variables.mlir",
"transformation.mlir",
"verify_compiler_input_legality.mlir",
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index f799916..9a82563 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -31,6 +31,7 @@
"outline_dispatch_regions.mlir"
"outline_large_constants.mlir"
"pre_partitioning_conversion.mlir"
+ "promote_tensor_loads.mlir"
"strip_and_splat_constant_variables.mlir"
"transformation.mlir"
"verify_compiler_input_legality.mlir"
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/promote_tensor_loads.mlir b/iree/compiler/Dialect/Flow/Transforms/test/promote_tensor_loads.mlir
new file mode 100644
index 0000000..983c28a
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/promote_tensor_loads.mlir
@@ -0,0 +1,18 @@
+// RUN: iree-opt -split-input-file -iree-flow-promote-tensor-loads %s | IreeFileCheck %s
+
+func @tensor_extract(%arg0 : tensor<1xi32>, %arg1 : index) -> i32 {
+ // CHECK: %[[RESULT:.*]] = flow.tensor.load %arg0[%arg1]
+ // CHECK: return %[[RESULT]]
+ %extract = tensor.extract %arg0[%arg1] : tensor<1xi32>
+ return %extract : i32
+}
+
+// -----
+func @tensor_extract_i1(%arg0 : tensor<1xi1>, %arg1 : index) -> i1 {
+ // CHECK: %[[ZEXT:.*]] = zexti %arg0 : tensor<1xi1> to tensor<1xi8>
+ // CHECK: %[[LOADED:.*]] = flow.tensor.load %[[ZEXT]][%arg1] : tensor<1xi8>
+ // CHECK: %[[RESULT:.*]] = trunci %[[LOADED]] : i8 to i1
+ // CHECK: return %[[RESULT]]
+ %extract = tensor.extract %arg0[%arg1] : tensor<1xi1>
+ return %extract : i1
+}