Legalize arith.constant of `tensor<...xi1>` inlined into dispatch to `tensor<...xi8>` (#9867)

This should be a fix for #8491. It adds a `ConstantOpTypeConversion` pattern to `TypePropagationPass` that rewrites `arith.constant` ops holding dense `i1` element attributes (splat or non-splat) to the legalized `i8` element type, so that `tensor<...xi1>` constants inlined into a dispatch region no longer fail to legalize.
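To illustrate, here is a minimal sketch of the rewrite, adapted from the `type_propagation.mlir` test added below (the SSA names `%select_i8` and `%cond` are illustrative, not taken from the actual output): the inlined `i1` constant feeding a `linalg.generic` becomes an `i8` constant, and the generic's body truncates the `i8` block argument back to `i1` before use.

```mlir
// Before type propagation: the i1 constant is inlined into the dispatch.
%select = arith.constant dense<[true, false, true, false]> : tensor<4xi1>

// After type propagation (sketch): the constant is legalized to i8, and the
// linalg.generic body receives an i8 block argument that is truncated back
// to i1 before the select.
%select_i8 = arith.constant dense<[1, 0, 1, 0]> : tensor<4xi8>
// ...
^bb0(%b0: i8, %b1: i32, %b2: i32, %b3: i32):
  %cond = arith.trunci %b0 : i8 to i1
  %0 = arith.select %cond, %b1, %b2 : i32
```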
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 913b9dd..f3d4900 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -100,6 +100,46 @@
       : OpConversionPattern<T>(typeConverter, context, 100) {}
 };
 
+/// Legalizes the element type of `arith.constant` operations.
+struct ConstantOpTypeConversion
+    : public TypePropagationPattern<arith::ConstantOp> {
+  using TypePropagationPattern<arith::ConstantOp>::TypePropagationPattern;
+
+  LogicalResult matchAndRewrite(
+      arith::ConstantOp constantOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const final {
+    auto attr = constantOp.getValue().cast<DenseElementsAttr>();
+    auto attrType = attr.getType().dyn_cast<ShapedType>();
+    if (!attrType) {
+      return rewriter.notifyMatchFailure(
+          constantOp, "expected attribute type to be shaped type");
+    }
+    Optional<Type> legalizedElementType =
+        getLegalizedElementType(attrType.getElementType());
+    if (!legalizedElementType) {
+      return rewriter.notifyMatchFailure(constantOp,
+                                         "cannot legalize elementType");
+    }
+    if (!legalizedElementType->isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(
+          constantOp, "expected legalized type to be integer or float type");
+    }
+    SmallVector<APInt> legalizedValues;
+    unsigned numElements = attr.isSplat() ? 1 : attr.getNumElements();
+    legalizedValues.reserve(numElements);
+    unsigned bitWidth = legalizedElementType->getIntOrFloatBitWidth();
+    for (auto value : attr.getValues<APInt>()) {
+      legalizedValues.emplace_back(bitWidth, value.getZExtValue());
+    }
+    auto newAttrType = RankedTensorType::get(attrType.getShape(),
+                                             legalizedElementType.getValue());
+    auto newAttr = DenseElementsAttr::get(newAttrType, legalizedValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constantOp, newAttr,
+                                                   newAttrType);
+    return success();
+  }
+};
+
 /// Propagates the type for `linalg.generic` operation.
 /// - Convert operands whose type has changed.
 /// - Convert corresponding basic block argument type and introduce element
@@ -349,7 +389,7 @@
 
     TypePropagationTypeConverter typeConverter;
     patterns
-        .insert<ForwardSourceType<arith::ExtUIOp>,
+        .insert<ConstantOpTypeConversion, ForwardSourceType<arith::ExtUIOp>,
                 ForwardSourceType<arith::TruncIOp>, GenericOpTypePropagation,
                 LinalgFillTypePropagation, LegalizeResultElementType>(
             typeConverter, context);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
index 2881896..1909fe2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
@@ -239,3 +239,61 @@
 //  CHECK-SAME:       ins(%[[EXT_SCALAR]] :
 //  CHECK-SAME:       outs(%[[INIT]] :
 //       CHECK:   flow.dispatch.tensor.store %[[FILL]], %[[OUT]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @constant_op() {
+  %a = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+  %b = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+  %c = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4xi32>
+  %at = flow.dispatch.tensor.load %a, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+  %bt = flow.dispatch.tensor.load %b, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+  %select = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+  %init = linalg.init_tensor [4] : tensor<4xi32>
+  %result = linalg.generic {
+      indexing_maps = [#map, #map, #map, #map],
+      iterator_types = ["parallel"]}
+      ins(%select, %at, %bt : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+      outs(%init : tensor<4xi32>) {
+    ^bb0(%b0 : i1, %b1 : i32, %b2 : i32, %b3 : i32) :
+      %0 = arith.select %b0, %b1, %b2 : i32
+      linalg.yield %0 : i32
+  } -> tensor<4xi32>
+  flow.dispatch.tensor.store %result, %c, offsets = [0], sizes = [4], strides = [1] : tensor<4xi32> -> !flow.dispatch.tensor<writeonly:4xi32>
+  return
+}
+// CHECK-LABEL: func.func @constant_op()
+//       CHECK:   %[[CONST:.+]] = arith.constant dense<[1, 0, 1, 0]> : tensor<4xi8>
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       ins(%[[CONST]]
+//  CHECK-NEXT:   ^bb0
+//  CHECK-SAME:       %[[B0:.+]]: i8
+//       CHECK:     %[[TRUNC:.+]] = arith.trunci %[[B0]] : i8 to i1
+//       CHECK:     arith.select %[[TRUNC]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @constant_splat_op() {
+  %a = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+  %b = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+  %c = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4xi32>
+  %at = flow.dispatch.tensor.load %a, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+  %bt = flow.dispatch.tensor.load %b, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+  %select = arith.constant dense<true> : tensor<4xi1>
+  %init = linalg.init_tensor [4] : tensor<4xi32>
+  %result = linalg.generic {
+      indexing_maps = [#map, #map, #map, #map],
+      iterator_types = ["parallel"]}
+      ins(%select, %at, %bt : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+      outs(%init : tensor<4xi32>) {
+    ^bb0(%b0 : i1, %b1 : i32, %b2 : i32, %b3 : i32) :
+      %0 = arith.select %b0, %b1, %b2 : i32
+      linalg.yield %0 : i32
+  } -> tensor<4xi32>
+  flow.dispatch.tensor.store %result, %c, offsets = [0], sizes = [4], strides = [1] : tensor<4xi32> -> !flow.dispatch.tensor<writeonly:4xi32>
+  return
+}
+// CHECK-LABEL: func.func @constant_splat_op()
+//       CHECK:   arith.constant dense<1> : tensor<4xi8>
diff --git a/tests/e2e/regression/BUILD b/tests/e2e/regression/BUILD
index dfea378..22f14a6 100644
--- a/tests/e2e/regression/BUILD
+++ b/tests/e2e/regression/BUILD
@@ -27,6 +27,7 @@
     "dynamic_torch_index_select_negative.mlir",
     "dynamic_torch_index_select_scalar.mlir",
     "dynamic_torch_index_select_vector.mlir",
+    "i1_inlined_constant.mlir",
     "linalg_ops.mlir",
     "strided_slice.mlir",
 ]
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index ca40775..5ff8cde 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -43,6 +43,7 @@
     "dynamic_torch_index_select_negative.mlir"
     "dynamic_torch_index_select_scalar.mlir"
     "dynamic_torch_index_select_vector.mlir"
+    "i1_inlined_constant.mlir"
     "linalg_ops.mlir"
     "lowering_config.mlir"
     "strided_slice.mlir"
@@ -79,6 +80,7 @@
     "dynamic_torch_index_select_negative.mlir"
     "dynamic_torch_index_select_scalar.mlir"
     "dynamic_torch_index_select_vector.mlir"
+    "i1_inlined_constant.mlir"
     "linalg_ops.mlir"
     "strided_slice.mlir"
   TARGET_BACKEND
@@ -101,6 +103,7 @@
     "dynamic_torch_index_select_negative.mlir"
     "dynamic_torch_index_select_scalar.mlir"
     "dynamic_torch_index_select_vector.mlir"
+    "i1_inlined_constant.mlir"
     "linalg_ops.mlir"
     "strided_slice.mlir"
   TARGET_BACKEND
@@ -123,6 +126,7 @@
     "dynamic_torch_index_select_negative.mlir"
     "dynamic_torch_index_select_scalar.mlir"
     "dynamic_torch_index_select_vector.mlir"
+    "i1_inlined_constant.mlir"
     "large_reduction.mlir"
     "linalg_ops.mlir"
     "strided_slice.mlir"
diff --git a/tests/e2e/regression/i1_inlined_constant.mlir b/tests/e2e/regression/i1_inlined_constant.mlir
new file mode 100644
index 0000000..93e838e
--- /dev/null
+++ b/tests/e2e/regression/i1_inlined_constant.mlir
@@ -0,0 +1,18 @@
+func.func @select_with_binary() {
+  %control = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+  %a = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %b = arith.constant dense<[5, 6, 7, 8]> : tensor<4xi32>
+  %init = linalg.init_tensor [4] : tensor<4xi32>
+  %c = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>,
+                       affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%control, %a, %b : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+      outs(%init : tensor<4xi32>) {
+    ^bb0(%b1 : i1, %b2 : i32, %b3 : i32, %b4 : i32):
+      %0 = arith.select %b1, %b2, %b3 : i32
+      linalg.yield %0 : i32
+    } -> tensor<4xi32>
+  check.expect_eq_const(%c, dense<[1, 6, 3, 8]> : tensor<4xi32>) : tensor<4xi32>
+  return
+}