Legalize arith.constant of `tensor<...xi1>` inlined into dispatch to `tensor<...xi8>` (#9867)
This should be a fix for #8491.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 913b9dd..f3d4900 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -100,6 +100,46 @@
: OpConversionPattern<T>(typeConverter, context, 100) {}
};
+/// Type conversion for arith.constant results (legalizes the element type).
+struct ConstantOpTypeConversion
+ : public TypePropagationPattern<arith::ConstantOp> {
+ using TypePropagationPattern<arith::ConstantOp>::TypePropagationPattern;
+
+ LogicalResult matchAndRewrite(
+ arith::ConstantOp constantOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ auto attr = constantOp.getValue().cast<DenseElementsAttr>();
+ auto attrType = attr.getType().dyn_cast<ShapedType>();
+ if (!attrType) {
+ return rewriter.notifyMatchFailure(
+ constantOp, "expected attribute type to be shaped type");
+ }
+ Optional<Type> legalizedElementType =
+ getLegalizedElementType(attrType.getElementType());
+ if (!legalizedElementType) {
+ return rewriter.notifyMatchFailure(constantOp,
+ "cannot legalize elementType");
+ }
+ if (!legalizedElementType->isIntOrFloat()) {
+ return rewriter.notifyMatchFailure(
+ constantOp, "expected legalized type to be integer or float type");
+ }
+ SmallVector<APInt> legalizedValues;
+ unsigned numElements = attr.isSplat() ? 1 : attr.getNumElements();
+ legalizedValues.reserve(numElements);
+ unsigned bitWidth = legalizedElementType->getIntOrFloatBitWidth();
+ for (auto value : attr.getValues<APInt>()) {
+ legalizedValues.emplace_back(bitWidth, value.getZExtValue());
+ }
+ auto newAttrType = RankedTensorType::get(attrType.getShape(),
+ legalizedElementType.getValue());
+ auto newAttr = DenseElementsAttr::get(newAttrType, legalizedValues);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constantOp, newAttr,
+ newAttrType);
+ return success();
+ }
+};
+
/// Propagates the type for `linalg.generic` operation.
/// - Convert operands whose type has changed.
/// - Convert corresponding basic block argument type and introduce element
@@ -349,7 +389,7 @@
TypePropagationTypeConverter typeConverter;
patterns
- .insert<ForwardSourceType<arith::ExtUIOp>,
+ .insert<ConstantOpTypeConversion, ForwardSourceType<arith::ExtUIOp>,
ForwardSourceType<arith::TruncIOp>, GenericOpTypePropagation,
LinalgFillTypePropagation, LegalizeResultElementType>(
typeConverter, context);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
index 2881896..1909fe2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
@@ -239,3 +239,61 @@
// CHECK-SAME: ins(%[[EXT_SCALAR]] :
// CHECK-SAME: outs(%[[INIT]] :
// CHECK: flow.dispatch.tensor.store %[[FILL]], %[[OUT]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @constant_op() {
+ %a = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+ %b = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+ %c = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4xi32>
+ %at = flow.dispatch.tensor.load %a, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+ %bt = flow.dispatch.tensor.load %b, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+ %select = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+ %init = linalg.init_tensor [4] : tensor<4xi32>
+ %result = linalg.generic {
+ indexing_maps = [#map, #map, #map, #map],
+ iterator_types = ["parallel"]}
+ ins(%select, %at, %bt : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+ outs(%init : tensor<4xi32>) {
+ ^bb0(%b0 : i1, %b1 : i32, %b2 : i32, %b3 : i32) :
+ %0 = arith.select %b0, %b1, %b2 : i32
+ linalg.yield %0 : i32
+ } -> tensor<4xi32>
+ flow.dispatch.tensor.store %result, %c, offsets = [0], sizes = [4], strides = [1] : tensor<4xi32> -> !flow.dispatch.tensor<writeonly:4xi32>
+ return
+}
+// CHECK-LABEL: func.func @constant_op()
+// CHECK: %[[CONST:.+]] = arith.constant dense<[1, 0, 1, 0]> : tensor<4xi8>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[CONST]]
+// CHECK-NEXT: ^bb0
+// CHECK-SAME: %[[B0:.+]]: i8
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[B0]] : i8 to i1
+// CHECK: arith.select %[[TRUNC]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @constant_splat_op() {
+ %a = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+ %b = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4xi32>
+ %c = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4xi32>
+ %at = flow.dispatch.tensor.load %a, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+ %bt = flow.dispatch.tensor.load %b, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<4xi32>
+ %select = arith.constant dense<true> : tensor<4xi1>
+ %init = linalg.init_tensor [4] : tensor<4xi32>
+ %result = linalg.generic {
+ indexing_maps = [#map, #map, #map, #map],
+ iterator_types = ["parallel"]}
+ ins(%select, %at, %bt : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+ outs(%init : tensor<4xi32>) {
+ ^bb0(%b0 : i1, %b1 : i32, %b2 : i32, %b3 : i32) :
+ %0 = arith.select %b0, %b1, %b2 : i32
+ linalg.yield %0 : i32
+ } -> tensor<4xi32>
+ flow.dispatch.tensor.store %result, %c, offsets = [0], sizes = [4], strides = [1] : tensor<4xi32> -> !flow.dispatch.tensor<writeonly:4xi32>
+ return
+}
+// CHECK-LABEL: func.func @constant_splat_op()
+// CHECK: arith.constant dense<1> : tensor<4xi8>
diff --git a/tests/e2e/regression/BUILD b/tests/e2e/regression/BUILD
index dfea378..22f14a6 100644
--- a/tests/e2e/regression/BUILD
+++ b/tests/e2e/regression/BUILD
@@ -27,6 +27,7 @@
"dynamic_torch_index_select_negative.mlir",
"dynamic_torch_index_select_scalar.mlir",
"dynamic_torch_index_select_vector.mlir",
+ "i1_inlined_constant.mlir",
"linalg_ops.mlir",
"strided_slice.mlir",
]
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index ca40775..5ff8cde 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -43,6 +43,7 @@
"dynamic_torch_index_select_negative.mlir"
"dynamic_torch_index_select_scalar.mlir"
"dynamic_torch_index_select_vector.mlir"
+ "i1_inlined_constant.mlir"
"linalg_ops.mlir"
"lowering_config.mlir"
"strided_slice.mlir"
@@ -79,6 +80,7 @@
"dynamic_torch_index_select_negative.mlir"
"dynamic_torch_index_select_scalar.mlir"
"dynamic_torch_index_select_vector.mlir"
+ "i1_inlined_constant.mlir"
"linalg_ops.mlir"
"strided_slice.mlir"
TARGET_BACKEND
@@ -101,6 +103,7 @@
"dynamic_torch_index_select_negative.mlir"
"dynamic_torch_index_select_scalar.mlir"
"dynamic_torch_index_select_vector.mlir"
+ "i1_inlined_constant.mlir"
"linalg_ops.mlir"
"strided_slice.mlir"
TARGET_BACKEND
@@ -123,6 +126,7 @@
"dynamic_torch_index_select_negative.mlir"
"dynamic_torch_index_select_scalar.mlir"
"dynamic_torch_index_select_vector.mlir"
+ "i1_inlined_constant.mlir"
"large_reduction.mlir"
"linalg_ops.mlir"
"strided_slice.mlir"
diff --git a/tests/e2e/regression/i1_inlined_constant.mlir b/tests/e2e/regression/i1_inlined_constant.mlir
new file mode 100644
index 0000000..93e838e
--- /dev/null
+++ b/tests/e2e/regression/i1_inlined_constant.mlir
@@ -0,0 +1,18 @@
+func.func @select_with_binary() {
+ %control = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+ %a = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %b = arith.constant dense<[5, 6, 7, 8]> : tensor<4xi32>
+ %init = linalg.init_tensor [4] : tensor<4xi32>
+ %c = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%control, %a, %b : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+ outs(%init : tensor<4xi32>) {
+ ^bb0(%b1 : i1, %b2 : i32, %b3 : i32, %b4 : i32):
+ %0 = arith.select %b1, %b2, %b3 : i32
+ linalg.yield %0 : i32
+ } -> tensor<4xi32>
+ check.expect_eq_const(%c, dense<[1, 6, 3, 8]> : tensor<4xi32>) : tensor<4xi32>
+ return
+}