Add `chlo` dialect to auto input conversion pipeline (#15474)
`chlo` operations should be checked for with the `stablehlo` part of the
auto input conversion pipeline.
diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
index 9a9209b..83cc8b1 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
@@ -82,11 +82,12 @@
}
}
-static void populateFeatures(Operation *op, const Dialect *stablehloDialect,
+static void populateFeatures(Operation *op, const Dialect *chloDialect,
+ const Dialect *stablehloDialect,
const Dialect *tosaDialect,
InputFeatures &features) {
Dialect *d = op->getDialect();
- if (d == stablehloDialect) {
+ if (d == stablehloDialect || d == chloDialect) {
features.hasStableHLO = true;
return populateHloFeatures(op, features);
}
@@ -101,14 +102,15 @@
MLIRContext *ctxt = &getContext();
InputFeatures features;
+ const Dialect *chloDialect = ctxt->getLoadedDialect("chlo");
const Dialect *stablehloDialect = ctxt->getLoadedDialect("stablehlo");
const Dialect *tosaDialect = ctxt->getLoadedDialect("tosa");
- if (!stablehloDialect && !tosaDialect) {
+ if (!chloDialect && !stablehloDialect && !tosaDialect) {
return;
}
auto res = module.walk([&](Operation *op) {
- populateFeatures(op, stablehloDialect, tosaDialect, features);
+ populateFeatures(op, chloDialect, stablehloDialect, tosaDialect, features);
if (features.hasStableHLO && features.hasTOSA) {
module.emitError("not yet implemented mixture of *HLO and TOSA");
return WalkResult::interrupt();