Remove PromoteTensorLoads pass, convert ExtractOp in TensorToFlow. (#6852)

Fixes https://github.com/google/iree/issues/6756 (the [`tosa if.mlir`](https://github.com/google/iree/blob/main/iree/test/e2e/tosa_ops/if.mlir) file compiles successfully using `-iree-flow-enable-linalg-detensorize` with this change)

The `PromoteTensorLoads` pass was converting `i1` loads to `i8` loads using `ZeroExtendIOp` and `TruncateIOp`. That was producing weird cycles during compilation when detensoring was applied, and `flow` ops should be fine with i1 types. We still need to handle `i1` types when going to the HAL (since storage is incompatible) on the outside (external interface) and inside (codegen).
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
index eecfa67..b02d07c 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
@@ -216,6 +216,22 @@
   }
 };
 
+struct ConvertTensorExtractPattern
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<IREE::Flow::TensorLoadOp>(
+        op, op.getResult().getType(), op.tensor(), op.indices());
+    return success();
+  }
+};
+
 struct ConvertTensorCastPattern : public OpRewritePattern<tensor::CastOp> {
   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
@@ -307,14 +323,19 @@
 
 }  // namespace
 
-void populateTensorToFlowPatterns(MLIRContext *context,
-                                  OwningRewritePatternList &patterns) {
+void populateTensorToFlowPatternsBeforeDispatchFormation(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
   patterns
       .insert<ConvertTensorInsertSlicePattern, ConvertTensorExtractSlicePattern,
               ConvertTensorCastPattern, ConvertTensorFromElementsPattern>(
           context);
 }
 
+void populateTensorToFlowPatternsAfterDispatchFormation(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<ConvertTensorExtractPattern>(context);
+}
+
 }  // namespace Flow
 }  // namespace IREE
 }  // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h
index e5b4505..8a56667 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h
@@ -15,9 +15,13 @@
 namespace IREE {
 namespace Flow {
 
-// Populates rewrite patterns for Tensor->Flow.
-void populateTensorToFlowPatterns(MLIRContext *context,
-                                  OwningRewritePatternList &patterns);
+// Adds patterns for Tensor->Flow, for running before dispatch region formation.
+void populateTensorToFlowPatternsBeforeDispatchFormation(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
+// Adds patterns for Tensor->Flow, for running after dispatch region formation.
+void populateTensorToFlowPatternsAfterDispatchFormation(
+    MLIRContext *context, OwningRewritePatternList &patterns);
 
 }  // namespace Flow
 }  // namespace IREE
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD
index 1d1e9da..e606a1d 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD
@@ -18,6 +18,7 @@
     srcs = enforce_glob(
         [
             "cast.mlir",
+            "extract.mlir",
             "extract_slice.mlir",
             "from_elements.mlir",
             "insert_slice.mlir",
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt
index d515c1e..ad22390 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "cast.mlir"
+    "extract.mlir"
     "extract_slice.mlir"
     "from_elements.mlir"
     "insert_slice.mlir"
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/extract.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/extract.mlir
new file mode 100644
index 0000000..a17899a
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/extract.mlir
@@ -0,0 +1,17 @@
+// RUN: iree-opt -split-input-file -iree-flow-convert-to-flow-after-dispatch-formation %s | IreeFileCheck %s
+
+func @tensor_extract(%arg0 : tensor<1xi32>, %arg1 : index) -> i32 {
+  // CHECK: %[[RESULT:.*]] = flow.tensor.load %arg0[%arg1] : tensor<1xi32>
+  // CHECK: return %[[RESULT]]
+  %extract = tensor.extract %arg0[%arg1] : tensor<1xi32>
+  return %extract : i32
+}
+
+// -----
+
+func @tensor_extract_i1(%arg0 : tensor<1xi1>, %arg1 : index) -> i1 {
+  // CHECK: %[[RESULT:.*]] = flow.tensor.load %arg0[%arg1] : tensor<1xi1>
+  // CHECK: return %[[RESULT]]
+  %extract = tensor.extract %arg0[%arg1] : tensor<1xi1>
+  return %extract : i1
+}
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 41c6aae..0c6d78c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -53,7 +53,6 @@
         "PassDetail.h",
         "Passes.cpp",
         "PromoteI1ToI8Pass.cpp",
-        "PromoteTensorLoads.cpp",
         "StripAndSplatConstantVariables.cpp",
         "TypeConverter.cpp",
         "VerifyInputLegality.cpp",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 4152bbf..b82b3d6 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -50,7 +50,6 @@
     "PassDetail.h"
     "Passes.cpp"
     "PromoteI1ToI8Pass.cpp"
-    "PromoteTensorLoads.cpp"
     "StripAndSplatConstantVariables.cpp"
     "TypeConverter.cpp"
     "VerifyInputLegality.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
index c8da8d0..95671d6 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
@@ -108,7 +108,7 @@
         LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
         LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>>(
         context);
-    populateTensorToFlowPatterns(context, patterns);
+    populateTensorToFlowPatternsBeforeDispatchFormation(context, patterns);
     IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
@@ -131,6 +131,7 @@
     RewritePatternSet patterns(&getContext());
 
     patterns.insert<LinalgFillToFlowTensorSplat>(context);
+    populateTensorToFlowPatternsAfterDispatchFormation(context, patterns);
     IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 415f0f3..174beef 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -167,12 +167,6 @@
   // an argument if two executables differ only in that one dimension).
   passManager.addPass(IREE::Flow::createDeduplicateExecutablesPass());
 
-  // TODO: Prune and rename this pass. This runs after sending everything
-  // possible to the device and then legalizes any remaining h<->d loads,
-  // typically coming from top level flow control.
-  passManager.addNestedPass<mlir::FuncOp>(
-      IREE::Flow::createPromoteTensorLoadsPass());
-
   // Create one function per remaining flow.executable that can be used with
   // iree-benchmark-module to benchmark each dispatch individually, as well as
   // exporting all original model entry points.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 91e65ba..2bfe73f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -80,13 +80,6 @@
 // Promote I1 tensor constants to I8 tensors to match later operations.
 std::unique_ptr<OperationPass<mlir::FuncOp>> createPromoteI1ToI8Pass();
 
-// Converts standard ops which match to flow.tensor.load (typically causing a
-// read-back).
-// Note that there are typically very specific phase ordering issues with
-// performing such a conversion, so even though it is of fine granularity,
-// this is maintained separately.
-std::unique_ptr<OperationPass<mlir::FuncOp>> createPromoteTensorLoadsPass();
-
 // Expands dynamic !shapex.ranked_shape dimensions in variables.
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createExpandGlobalDynamicDimsPass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index e8b6956..ca70028 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -131,12 +131,6 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createPromoteI1ToI8Pass()";
 }
 
-def PromoteTensorLoads :
-    Pass<"iree-flow-promote-tensor-loads", "mlir::FuncOp"> {
-  let summary = "Converts standard ops which match to flow.tensor.load (typically causing a read-back)";
-  let constructor = "mlir::iree_compiler::IREE::Flow::createPromoteTensorLoadsPass()";
-}
-
 def StripAndSplatConstantVariables :
     Pass<"iree-flow-strip-and-splat-constant-variables", "mlir::ModuleOp"> {
   let summary = "Strips constant util.globals and replaces them with splats.";
diff --git a/iree/compiler/Dialect/Flow/Transforms/PromoteTensorLoads.cpp b/iree/compiler/Dialect/Flow/Transforms/PromoteTensorLoads.cpp
deleted file mode 100644
index 2e6d158..0000000
--- a/iree/compiler/Dialect/Flow/Transforms/PromoteTensorLoads.cpp
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace Flow {
-
-namespace {
-
-/// tensor::ExtractOp will be lowered to IREE::Flow::TensorLoadOp. If the type
-/// is i1, it's not valid to load. In this case, we need to cast it to i8 before
-/// the load, and truncate the value after the load.
-struct ExtractElementOpLowering
-    : public OpConversionPattern<tensor::ExtractOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      tensor::ExtractOp op, ArrayRef<Value> args,
-      ConversionPatternRewriter &rewriter) const override {
-    // tensor<i1> is not valid to load, it needs to be converted to i8 or
-    // something else instead.
-    auto tensorType = op.tensor().getType().cast<TensorType>();
-    if (tensorType.getElementType().isInteger(1)) {
-      auto i1Type = rewriter.getI1Type();
-      auto i8Type = rewriter.getIntegerType(8);
-      auto convertedOperand = rewriter.createOrFold<ZeroExtendIOp>(
-          op.getLoc(), args[0],
-          RankedTensorType::get(tensorType.getShape(), i8Type));
-      auto i8Value = rewriter.createOrFold<IREE::Flow::TensorLoadOp>(
-          op.getLoc(), i8Type, convertedOperand, op.indices());
-      rewriter.replaceOpWithNewOp<TruncateIOp>(op, i1Type, i8Value);
-    } else {
-      rewriter.replaceOpWithNewOp<IREE::Flow::TensorLoadOp>(
-          op, tensorType.getElementType(), op.tensor(), op.indices());
-    }
-    return success();
-  }
-};
-
-void setupStandardToFlowTensorLoadLegality(MLIRContext *context,
-                                           ConversionTarget &conversionTarget) {
-  conversionTarget.addIllegalOp<tensor::ExtractOp>();
-}
-
-void populateStandardToFlowTensorLoadPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<ExtractElementOpLowering>(context);
-}
-
-}  // namespace
-
-class PromoteTensorLoadsPass
-    : public PromoteTensorLoadsBase<PromoteTensorLoadsPass> {
- public:
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<FlowDialect, StandardOpsDialect, tensor::TensorDialect>();
-  }
-
-  void runOnOperation() override {
-    auto *context = &getContext();
-    ConversionTarget conversionTarget(*context);
-    OwningRewritePatternList conversionPatterns(&getContext());
-
-    conversionTarget.addLegalDialect<IREE::Flow::FlowDialect>();
-    conversionTarget.addLegalDialect<StandardOpsDialect>();
-    setupStandardToFlowTensorLoadLegality(context, conversionTarget);
-    populateStandardToFlowTensorLoadPatterns(context, conversionPatterns);
-
-    if (failed(applyPartialConversion(getOperation(), conversionTarget,
-                                      std::move(conversionPatterns)))) {
-      return signalPassFailure();
-    }
-  }
-};
-
-std::unique_ptr<OperationPass<mlir::FuncOp>> createPromoteTensorLoadsPass() {
-  return std::make_unique<PromoteTensorLoadsPass>();
-}
-
-}  // namespace Flow
-}  // namespace IREE
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index 7872f95..1290d4a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -38,7 +38,6 @@
             "pad_linalg_ops.mlir",
             "pad_tensor_to_tensor.mlir",
             "promote_i1_to_i8.mlir",
-            "promote_tensor_loads.mlir",
             "strip_and_splat_constant_variables.mlir",
             "transformation.mlir",
             "verify_input_ir.mlir",
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index fe71d4c..eb0988c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -35,7 +35,6 @@
     "pad_linalg_ops.mlir"
     "pad_tensor_to_tensor.mlir"
     "promote_i1_to_i8.mlir"
-    "promote_tensor_loads.mlir"
     "strip_and_splat_constant_variables.mlir"
     "transformation.mlir"
     "verify_input_ir.mlir"
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
index 12aa304..a2a906d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
@@ -24,7 +24,7 @@
 //  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9]+]]: index
 //   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
 //   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//       CHECK:   %[[VAL:.+]] = tensor.extract %[[ARG1]][]
+//       CHECK:   %[[VAL:.+]] = flow.tensor.load %[[ARG1]] : tensor<f32>
 //   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 //   CHECK-DAG:   %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/promote_tensor_loads.mlir b/iree/compiler/Dialect/Flow/Transforms/test/promote_tensor_loads.mlir
deleted file mode 100644
index 983c28a..0000000
--- a/iree/compiler/Dialect/Flow/Transforms/test/promote_tensor_loads.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-flow-promote-tensor-loads %s | IreeFileCheck %s
-
-func @tensor_extract(%arg0 : tensor<1xi32>, %arg1 : index) -> i32 {
-  // CHECK: %[[RESULT:.*]] = flow.tensor.load %arg0[%arg1]
-  // CHECK: return %[[RESULT]]
-  %extract = tensor.extract %arg0[%arg1] : tensor<1xi32>
-  return %extract : i32
-}
-
-// -----
-func @tensor_extract_i1(%arg0 : tensor<1xi1>, %arg1 : index) -> i1 {
-  // CHECK: %[[ZEXT:.*]] = zexti %arg0 : tensor<1xi1> to tensor<1xi8>
-  // CHECK: %[[LOADED:.*]] = flow.tensor.load %[[ZEXT]][%arg1] : tensor<1xi8>
-  // CHECK: %[[RESULT:.*]] = trunci %[[LOADED]] : i8 to i1
-  // CHECK: return %[[RESULT]]
-  %extract = tensor.extract %arg0[%arg1] : tensor<1xi1>
-  return %extract : i1
-}
diff --git a/iree/compiler/InputConversion/MHLO/Passes.cpp b/iree/compiler/InputConversion/MHLO/Passes.cpp
index 2b589cb..c11ed17 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.cpp
+++ b/iree/compiler/InputConversion/MHLO/Passes.cpp
@@ -6,10 +6,10 @@
 
 #include "iree/compiler/InputConversion/MHLO/Passes.h"
 
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/InputConversion/Common/Passes.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassOptions.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/Passes.h"
@@ -41,11 +41,6 @@
   passManager.addPass(createConvertShapeToStandardPass());
   passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
 
-  // Now that control flow has been lowered, promote and extract_element
-  // to tensor loads. This will be done again later once everything that can
-  // be is lowered to device.
-  passManager.addNestedPass<FuncOp>(IREE::Flow::createPromoteTensorLoadsPass());
-
   // We also don't handle calls well on the old codepath; until we remove the
   // use of the CFG we can continue inlining.
   passManager.addPass(mlir::createInlinerPass());
diff --git a/iree/compiler/InputConversion/TOSA/Passes.cpp b/iree/compiler/InputConversion/TOSA/Passes.cpp
index b3951d3..fcee9b9 100644
--- a/iree/compiler/InputConversion/TOSA/Passes.cpp
+++ b/iree/compiler/InputConversion/TOSA/Passes.cpp
@@ -36,11 +36,6 @@
   passManager.addNestedPass<FuncOp>(tosa::createTosaToSCF());
   passManager.addNestedPass<FuncOp>(createTopLevelSCFToCFGPass());
 
-  // Now that control flow has been lowered, promote and extract_element
-  // to tensor loads. This will be done again later once everything that can
-  // be is lowered to device.
-  passManager.addNestedPass<FuncOp>(IREE::Flow::createPromoteTensorLoadsPass());
-
   // We also don't handle calls well on the old codepath; until we remove the
   // use of the CFG we can continue inlining.
   passManager.addPass(mlir::createInlinerPass());