Add LinalgExt TypePropagation pattern that handles i1 inputs/outputs (#13936)

This pattern functions under several assumptions:
* Input ("Updates") element type matches output element type.
* There is only one input/output -- although listed as variadic,
LinalgExt doesn't actually support variadic inputs/outputs. The input
list must always be length 2 (updates + indices) and the output list
must always be of length 1.
* We are only handling illegal types for updates and outputs. This is
because an input/output matrix of type i1 (for example) is a reasonable
situation, but an index matrix of an illegal type is an indication that
something else has gone wrong (at least, I personally can't think of a
legitimate example of that case).
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 1d82847..2b10f5a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -24,6 +24,8 @@
 //
 //===---------------------------------------------------------------------===//
 
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Codegen/Common/CommonPasses.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -320,6 +322,90 @@
   }
 };
 
+/// Pattern to legalize `iree_linalg_ext.scatter` ops with illegal (e.g. i1)
+struct IREELinalgExtScatterTypePropagation
+    : TypePropagationPattern<IREE::LinalgExt::ScatterOp> {
+  using TypePropagationPattern<
+      IREE::LinalgExt::ScatterOp>::TypePropagationPattern;
+  LogicalResult matchAndRewrite(
+      IREE::LinalgExt::ScatterOp scatterOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const final {
+    auto opOperands = scatterOp->getOpOperands();
+    Type inputType = opOperands[0].get().getType();
+    Type legalizedInputType = this->getTypeConverter()->convertType(inputType);
+
+    if (inputType == legalizedInputType) {
+      return scatterOp.emitOpError(
+          "unexpected all types legal within conversion pattern");
+    }
+
+    Type resultType = opOperands[2].get().getType();
+    Type legalizedResultType =
+        this->getTypeConverter()->convertType(resultType);
+
+    // Clone the op without regions, using the legalized result type/operands.
+    auto modifiedOp =
+        cast<IREE::LinalgExt::ScatterOp>(mlir::cloneWithoutRegions(
+            rewriter, scatterOp, {legalizedResultType}, adaptor.getOperands()));
+
+    // Move the combiner region from the original op into the cloned op.
+    rewriter.inlineRegionBefore(scatterOp->getRegions().front(),
+                                modifiedOp->getRegions().front(),
+                                modifiedOp->getRegions().front().begin());
+    Region &modifiedOpRegion = modifiedOp->getRegions().front();
+
+    // Rewrite both region arguments (update scalar and output scalar) to the
+    // legalized storage element type.
+    TypeConverter::SignatureConversion signatureConverter(
+        modifiedOpRegion.getNumArguments());
+    Type argType = modifiedOpRegion.getArguments()[0].getType();
+    std::optional<Type> legalizedArgType = legalizeStorageElementType(argType);
+    if (!legalizedArgType) {
+      return scatterOp.emitOpError("failed to get legalized type for argument");
+    }
+    signatureConverter.addInputs(0, legalizedArgType.value());
+    signatureConverter.addInputs(1, legalizedArgType.value());
+    rewriter.applySignatureConversion(&modifiedOpRegion, signatureConverter);
+
+    {
+      // Insert scalar conversions at the block start so the original body
+      // still computes on the original (illegal) scalar type.
+      OpBuilder::InsertionGuard g(rewriter);
+      Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front();
+      BlockArgument inputArg = entryBlock->getArgument(0);
+      BlockArgument outputArg = entryBlock->getArgument(1);
+
+      auto destType = getElementTypeOrSelf(inputType);
+      rewriter.setInsertionPointToStart(entryBlock);
+
+      Value replacementInput =
+          convertElementType(rewriter, inputArg.getLoc(), destType, inputArg);
+      rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(0),
+                                          replacementInput);
+      Value replacementOutput =
+          convertElementType(rewriter, outputArg.getLoc(), destType, outputArg);
+      rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(1),
+                                          replacementOutput);
+
+      // Convert the yielded scalar back to the legalized storage type so it
+      // matches the new (legal) output element type.
+      auto yieldOp = entryBlock->getTerminator();
+
+      rewriter.setInsertionPoint(yieldOp);
+      OpOperand *modifiedOpOperand = &yieldOp->getOpOperand(0);
+
+      auto yieldOperand = convertElementType(rewriter, yieldOp->getLoc(),
+                                             legalizedArgType.value(),
+                                             modifiedOpOperand->get());
+
+      rewriter.replaceOpWithNewOp<IREE::LinalgExt::YieldOp>(yieldOp,
+                                                            yieldOperand);
+    }
+    rewriter.replaceOp(scatterOp, modifiedOp->getResults());
+    return success();
+  }
+};
+
 /// Simple rewrite pattern that just forwards the source as the result if the
 /// result type is not legal (but source type is)
 template <typename OpTy>
@@ -407,8 +493,8 @@
         NamedOpTypePropagation<linalg::BatchMatmulOp>,
         NamedOpTypePropagation<linalg::MatmulOp>,
         NamedOpTypePropagation<linalg::MatvecOp>,
-        NamedOpTypePropagation<linalg::DotOp>, TensorExtractTypePropagation>(
-        typeConverter, context);
+        NamedOpTypePropagation<linalg::DotOp>, TensorExtractTypePropagation,
+        IREELinalgExtScatterTypePropagation>(typeConverter, context);
 
     ConversionTarget target(*context);
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
index 312a493..46d1c3a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
@@ -359,3 +359,43 @@
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[FILL]] :
 //      CHECK:   return %[[GEMM]]
+
+// -----
+
+func.func @scatter() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8xi8>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8x1xi32>>
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<3xi8>>
+  %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [8], strides = [1] : !flow.dispatch.tensor<readonly:tensor<8xi8>> -> tensor<8xi8>
+  %4 = arith.trunci %3 : tensor<8xi8> to tensor<8xi1>
+  %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [8, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x1xi32>> -> tensor<8x1xi32>
+  %6 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [3], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<3xi8>> -> tensor<3xi8>
+  %7 = arith.trunci %6 : tensor<3xi8> to tensor<3xi1>
+  %8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%4, %5 : tensor<8xi1>, tensor<8x1xi32>) outs(%7 : tensor<3xi1>) {
+  ^bb0(%arg0: i1, %arg1: i1):
+    %10 = arith.minui %arg1, %arg0 : i1
+    iree_linalg_ext.yield %10 : i1
+  } -> tensor<3xi1>
+  %9 = arith.extui %8 : tensor<3xi1> to tensor<3xi8>
+  flow.dispatch.tensor.store %9, %2, offsets = [0], sizes = [3], strides = [1] : tensor<3xi8> -> !flow.dispatch.tensor<readwrite:tensor<3xi8>>
+  return
+}
+
+// CHECK-LABEL: func.func @scatter()
+//   CHECK-DAG:   %[[UPDATES:.+]] = hal.interface.binding.subspan set(0) binding(0)
+//   CHECK-DAG:   %[[INDICES:.+]] = hal.interface.binding.subspan set(0) binding(1)
+//   CHECK-DAG:   %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(2)
+//   CHECK-DAG:   %[[UPDATES_TENSOR:.+]] = flow.dispatch.tensor.load %[[UPDATES]]
+//   CHECK-DAG:   %[[INDICES_TENSOR:.+]] = flow.dispatch.tensor.load %[[INDICES]]
+//   CHECK-DAG:   %[[OUT_TENSOR:.+]] = flow.dispatch.tensor.load %[[OUT]]
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false)
+//  CHECK-SAME:       ins(%[[UPDATES_TENSOR]], %[[INDICES_TENSOR]] : tensor<8xi8>, tensor<8x1xi32>)
+//  CHECK-SAME:       outs(%[[OUT_TENSOR]] : tensor<3xi8>)
+//  CHECK-NEXT:     ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: i8, %[[ARG1:[a-zA-Z0-9]+]]: i8)
+//   CHECK-DAG:       %[[TRUNC0:.+]] = arith.trunci %[[ARG0]] : i8 to i1
+//   CHECK-DAG:       %[[TRUNC1:.+]] = arith.trunci %[[ARG1]] : i8 to i1
+//   CHECK-DAG:       %[[MIN:.+]] = arith.minui %[[TRUNC1]], %[[TRUNC0]] : i1
+//   CHECK-DAG:       %[[EXTUI:.+]] = arith.extui %[[MIN]] : i1 to i8
+//       CHECK:       iree_linalg_ext.yield %[[EXTUI]]
+//       CHECK:   flow.dispatch.tensor.store %[[SCATTER]], %[[OUT]]