Handle `tensor.extract` of i1 types in type legalization. (#11274)

Also remove unnecessary legality checks within the conversion patterns. The
dialect conversion infrastructure already filters out legal ops: by
construction, patterns are invoked only on illegal ops, so in-pattern checks
that the op is illegal are redundant.

Fixes #11267 
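
For reference, this is why those in-pattern checks never fire: the pass builds a
`ConversionTarget` whose dynamic legality callback asks the type converter whether
every operand and result type is already legal, and the conversion driver only
applies patterns to ops that fail that callback. Below is a minimal sketch of that
wiring; the surrounding function and helper names are illustrative, not the pass's
actual code.

```cpp
// Sketch: dialect conversion only invokes patterns on ops the target reports
// as illegal, so patterns need not re-check legality themselves.
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

static LogicalResult runTypePropagation(Operation *root,
                                        TypeConverter &typeConverter,
                                        RewritePatternSet &&patterns) {
  ConversionTarget target(*root->getContext());

  // An op is legal iff the type converter leaves all of its operand and
  // result types unchanged. The conversion driver only hands ops that fail
  // this check to the patterns, so per-pattern "op already legal" bailouts
  // are dead code and can be dropped.
  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
    auto isLegalType = [&](Type t) { return typeConverter.isLegal(t); };
    return llvm::all_of(op->getOperandTypes(), isLegalType) &&
           llvm::all_of(op->getResultTypes(), isLegalType);
  });

  return applyFullConversion(root, target, std::move(patterns));
}
```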
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 78f0570..e59a47e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -170,9 +170,8 @@
     }
 
     // 2. If there are no operands modified, just return failure.
-    if (modifiedOperandIndex.empty()) {
-      return rewriter.notifyMatchFailure(genericOp, "all types legal");
-    }
+    assert(!modifiedOperandIndex.empty() &&
+           "unexpected all types legal within conversion pattern");
 
     // 3. Create a clone of the operation without cloning its regions.
     auto linalgOp = cast<linalg::LinalgOp>(genericOp.getOperation());
@@ -274,11 +273,6 @@
   LogicalResult matchAndRewrite(
       linalg::FillOp fillOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const final {
-    auto outputType = fillOp.output().getType();
-    auto legalizedOutputType = this->typeConverter->convertType(outputType);
-    if (outputType == legalizedOutputType) {
-      return rewriter.notifyMatchFailure(fillOp, "op already legal");
-    }
     Value value = adaptor.getInputs().front();
     Optional<Type> legalizedElementType =
         getLegalizedElementType(value.getType());
@@ -293,6 +287,24 @@
   }
 };
 
+/// Pattern to legalize `tensor.extract` operations.
+struct TensorExtractTypePropagation
+    : public TypePropagationPattern<tensor::ExtractOp> {
+  using TypePropagationPattern<tensor::ExtractOp>::TypePropagationPattern;
+
+  LogicalResult matchAndRewrite(
+      tensor::ExtractOp extractOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const final {
+    Location loc = extractOp.getLoc();
+    Value newExtract = rewriter.create<tensor::ExtractOp>(
+        loc, adaptor.getTensor(), adaptor.getIndices());
+    Value replacement = convertElementType(
+        rewriter, loc, extractOp.getResult().getType(), newExtract);
+    rewriter.replaceOp(extractOp, replacement);
+    return success();
+  }
+};
+
 /// Simple rewrite pattern that just forwards the source as the result if the
 /// result type is not legal (but source type is)
 template <typename OpTy>
@@ -306,14 +318,7 @@
       return rewriter.notifyMatchFailure(
           op, "unhandled op with multiple operands/results");
     }
-    Type outputType = op->getResult(0).getType();
-    Type legalizedOutputType = this->typeConverter->convertType(outputType);
     Value input = adaptor.getOperands()[0];
-    Value originalInput = op->getOperand(0);
-    if (outputType == legalizedOutputType &&
-        input.getType() == originalInput.getType()) {
-      return rewriter.notifyMatchFailure(op, "op is legal");
-    }
     rewriter.replaceOp(op, input);
     return success();
   }
@@ -334,19 +339,10 @@
       return rewriter.notifyMatchFailure(op, "unhandled ops with successors");
     }
     Location loc = op->getLoc();
-    bool illegalOp = llvm::any_of(
-        llvm::zip(op->getOperands(), convertedOperands),
-        [](std::tuple<Value, Value> tuple) {
-          return std::get<0>(tuple).getType() != std::get<1>(tuple).getType();
-        });
     SmallVector<Type> resultTypes;
     for (Type resultType : op->getResultTypes()) {
       Type legalizedType = this->typeConverter->convertType(resultType);
       resultTypes.push_back(legalizedType);
-      illegalOp |= legalizedType != resultType;
-    }
-    if (!illegalOp) {
-      return rewriter.notifyMatchFailure(op, "op is already legal");
     }
     OperationState state(loc, op->getName(), convertedOperands, resultTypes,
                          op->getAttrs());
@@ -389,11 +385,11 @@
     RewritePatternSet patterns(context);
 
     TypePropagationTypeConverter typeConverter;
-    patterns
-        .insert<ConstantOpTypeConversion, ForwardSourceType<arith::ExtUIOp>,
-                ForwardSourceType<arith::TruncIOp>, GenericOpTypePropagation,
-                LinalgFillTypePropagation, LegalizeResultElementType>(
-            typeConverter, context);
+    patterns.insert<ConstantOpTypeConversion, ForwardSourceType<arith::ExtUIOp>,
+                    ForwardSourceType<arith::TruncIOp>,
+                    GenericOpTypePropagation, LinalgFillTypePropagation,
+                    LegalizeResultElementType, TensorExtractTypePropagation>(
+        typeConverter, context);
 
     ConversionTarget target(*context);
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
index 5f9ab12..61d3da3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
@@ -297,3 +297,37 @@
 }
 // CHECK-LABEL: func.func @constant_splat_op()
 //       CHECK:   arith.constant dense<1> : tensor<4xi8>
+
+// -----
+
+func.func @tensor_extract() {
+  %c0 = arith.constant 0 : index
+  %c13 = arith.constant 13 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64)
+      : !flow.dispatch.tensor<readonly:tensor<14xi8>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64)
+      : !flow.dispatch.tensor<writeonly:tensor<14xi8>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [14], strides = [1]
+      : !flow.dispatch.tensor<readonly:tensor<14xi8>> -> tensor<14xi8>
+  %3 = arith.trunci %2 : tensor<14xi8> to tensor<14xi1>
+  %4 = tensor.empty() : tensor<14xi1>
+  %5 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
+      outs(%4 : tensor<14xi1>) {
+  ^bb0(%out: i1):
+    %7 = linalg.index 0 : index
+    %8 = arith.subi %c13, %7 : index
+    %extracted = tensor.extract %3[%8] : tensor<14xi1>
+    linalg.yield %extracted : i1
+  } -> tensor<14xi1>
+  %6 = arith.extui %5 : tensor<14xi1> to tensor<14xi8>
+  flow.dispatch.tensor.store %6, %1, offsets = [0], sizes = [14], strides = [1]
+      : tensor<14xi8> -> !flow.dispatch.tensor<writeonly:tensor<14xi8>>
+  return
+}
+// CHECK-LABEL: func @tensor_extract()
+//       CHECK:   %[[BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<14xi8>>
+//       CHECK:   %[[LOAD:.+]] = flow.dispatch.tensor.load %[[BINDING]]
+//       CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[LOAD]]
+//       CHECK:   arith.trunci %[[EXTRACTED]] : i8 to i1