Loosen requirements on iree_linalg_ext.scatter for complex types (#13055)
We support complex regions in the `iree_linalg_ext.scatter` just loosen
requirements so that we do not fail during successful cases.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index a571b53..eba1558 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -49,6 +49,12 @@
// Utils.
//===----------------------------------------------------------------------===//
+static Type getComplexElementTypeOrSelf(Type ty) {
+ if (auto complex = dyn_cast_or_null<ComplexType>(ty))
+ return complex.getElementType();
+ return ty;
+}
+
static void getEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
@@ -218,7 +224,8 @@
}
Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType();
- if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
+ if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
+ !getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
return op->emitOpError(
"expected region to have scalar argument of integer or float types");
}