Avoid assert failure with `index` as elem type. (#10814)
This is a more common occurrence than you'd think because of what's
discussed in #10813.
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
index d857352..aef4b9a 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
@@ -585,88 +585,90 @@
// Select the op to lower to and configure the emitter.
// Emit from the iree_ukernel_x32b_opcode_t table.
+ Type resultType = binaryOp->getResult(0).getType();
+ if (!resultType.isIntOrFloat()) return failure();
Optional<BinaryEmitter> emitter =
TypeSwitch<Operation *, Optional<BinaryEmitter>>(binaryOp)
.Case([&](arith::AddFOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "add");
}
return None;
})
.Case([&](arith::AddIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "add");
}
return None;
})
.Case([&](arith::AndIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "and");
}
return None;
})
.Case([&](arith::DivFOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "div");
}
return None;
})
.Case([&](arith::DivSIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "divs");
}
return None;
})
.Case([&](arith::DivUIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "divu");
}
return None;
})
.Case([&](arith::MulFOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "mul");
}
return None;
})
.Case([&](arith::MulIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "mul");
}
return None;
})
.Case([&](arith::OrIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "or");
}
return None;
})
.Case([&](arith::ShLIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "shl");
}
return None;
})
.Case([&](arith::ShRSIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "shrs");
}
return None;
})
.Case([&](arith::XOrIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "xor");
}
return None;
})
.Case([&](arith::SubFOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "sub");
}
return None;
})
.Case([&](arith::SubIOp op) -> Optional<BinaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericBinary(op, "sub");
}
return None;
@@ -735,52 +737,54 @@
// Select the op to lower to and configure the emitter.
// Emit from the iree_ukernel_x32b_opcode_t table.
+ Type resultType = unaryOp->getResult(0).getType();
+ if (!resultType.isIntOrFloat()) return failure();
Optional<UnaryEmitter> emitter =
TypeSwitch<Operation *, Optional<UnaryEmitter>>(unaryOp)
.Case([&](math::AbsFOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "abs");
}
return None;
})
.Case([&](math::CeilOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "ceil");
}
return None;
})
.Case([&](math::CountLeadingZerosOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "ctlz");
}
return None;
})
.Case([&](math::ExpOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "exp");
}
return None;
})
.Case([&](math::FloorOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "floor");
}
return None;
})
.Case([&](math::LogOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "log");
}
return None;
})
.Case([&](arith::NegFOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "neg");
}
return None;
})
.Case([&](math::RsqrtOp op) -> Optional<UnaryEmitter> {
- if (op.getResult().getType().getIntOrFloatBitWidth() == 32) {
+ if (resultType.getIntOrFloatBitWidth() == 32) {
return configureGenericUnary(op, "rsqrt");
}
return None;