[i1] Do not emit `arith.trunci` cast from i1 to i1 (#19176)
`arith.trunci` does not allow cast to same type
Signed-off-by: Alan Li <me@alanli.org>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 42307b7..2f85e23 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -229,9 +229,12 @@
Value maskVal = args[0];
// TODO: Replace bool mask condition once treated as i1 (instead of i8)
- if (maskVal.getType().isInteger()) {
- maskVal =
- b.create<arith::TruncIOp>(loc, builder.getI1Type(), maskVal);
+ auto maskValType = maskVal.getType();
+ if (maskValType.isInteger()) {
+ if (maskValType.getIntOrFloatBitWidth() != 1) {
+ maskVal =
+ b.create<arith::TruncIOp>(loc, builder.getI1Type(), maskVal);
+ }
maskVal = b.create<arith::SelectOp>(loc, maskVal, zero, negInf);
} else {
maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(),