Preserving MHLO demotion behavior until #8745 is fixed.
diff --git a/iree/compiler/InputConversion/MHLO/BUILD b/iree/compiler/InputConversion/MHLO/BUILD
index e7bfa5a..1ebb8c4 100644
--- a/iree/compiler/InputConversion/MHLO/BUILD
+++ b/iree/compiler/InputConversion/MHLO/BUILD
@@ -64,6 +64,7 @@
":PassesIncGen",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/Util/IR",
+ "//iree/compiler/Dialect/Util/Transforms",
"//iree/compiler/InputConversion/Common",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
diff --git a/iree/compiler/InputConversion/MHLO/CMakeLists.txt b/iree/compiler/InputConversion/MHLO/CMakeLists.txt
index 317734a..630b81a 100644
--- a/iree/compiler/InputConversion/MHLO/CMakeLists.txt
+++ b/iree/compiler/InputConversion/MHLO/CMakeLists.txt
@@ -90,6 +90,7 @@
MhloToStandard
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Util::IR
+ iree::compiler::Dialect::Util::Transforms
iree::compiler::InputConversion::Common
tensorflow::external_mhlo_includes
PUBLIC
diff --git a/iree/compiler/InputConversion/MHLO/Passes.cpp b/iree/compiler/InputConversion/MHLO/Passes.cpp
index 4935e05..7dd22b2 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.cpp
+++ b/iree/compiler/InputConversion/MHLO/Passes.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/InputConversion/MHLO/Passes.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
@@ -22,6 +23,19 @@
namespace iree_compiler {
namespace MHLO {
+// TODO(#8745): remove these flags when the -iree-flow-demote-* flags can be
+// used without tripping upstream verifier issues.
+static llvm::cl::opt<bool> clDemoteI64ToI32(
+ "iree-mhlo-demote-i64-to-i32",
+ llvm::cl::desc(
+ "Converts all MHLO i64 ops and values into i32 counterparts."),
+ llvm::cl::init(true));
+static llvm::cl::opt<bool> clDemoteF64ToF32(
+ "iree-mhlo-demote-f64-to-f32",
+ llvm::cl::desc(
+ "Converts all MHLO f64 ops and values into f32 counterparts."),
+ llvm::cl::init(true));
+
void registerMHLOConversionPassPipeline() {
PassPipelineRegistration<> mhlo(
"iree-mhlo-input-transformation-pipeline",
@@ -58,6 +72,17 @@
// use of the CFG we can continue inlining.
passManager.addPass(mlir::createInlinerPass());
+ // Hacky type conversion to work around lack of type support lower in the
+ // stack. This is often required because of implicit i64 insertion by JAX/HLO
+ // that we don't want forcing 32-bit embedded devices to support.
+ // TODO(#8745): remove these and prefer the flow pipeline options instead.
+ if (clDemoteI64ToI32) {
+ passManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
+ }
+ if (clDemoteF64ToF32) {
+ passManager.addPass(IREE::Util::createDemoteF64ToF32Pass());
+ }
+
// Perform initial cleanup. createLegalizeInputTypes could rewrite types. In
// this context, some operations could be folded away.
passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
diff --git a/iree/test/e2e/models/unidirectional_lstm.mlir b/iree/test/e2e/models/unidirectional_lstm.mlir
index 6d68dda..96dab18 100644
--- a/iree/test/e2e/models/unidirectional_lstm.mlir
+++ b/iree/test/e2e/models/unidirectional_lstm.mlir
@@ -1,6 +1,5 @@
// An example LSTM exported from a python reference model with dummy weights.
-// RUN: iree-run-mlir %s --iree-input-type=mhlo -iree-hal-target-backends=vmvx -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s
// RUN: iree-run-mlir %s --iree-input-type=mhlo -iree-hal-target-backends=dylib-llvm-aot -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s
// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir %s --iree-input-type=mhlo -iree-hal-target-backends=vmvx -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s)
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir %s --iree-input-type=mhlo -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=[0,1,0,3,4]" -function-input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s)
diff --git a/llvm-external-projects/iree-compiler-api/pyproject.toml b/llvm-external-projects/iree-compiler-api/pyproject.toml
index a5ba7a4..43e846c 100644
--- a/llvm-external-projects/iree-compiler-api/pyproject.toml
+++ b/llvm-external-projects/iree-compiler-api/pyproject.toml
@@ -1,15 +1,15 @@
-[build-system]
-requires = [
- "setuptools>=42",
- "wheel",
- # There is no fundamental reason to pin this CMake version, beyond
- # build stability.
- "cmake==3.22.2",
- "ninja==1.10.2",
- # MLIR build depends.
- "numpy",
- # Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136
- "pybind11>=2.6.0,!=2.7.0",
- "PyYAML",
-]
-build-backend = "setuptools.build_meta"
+[build-system]
+requires = [
+ "setuptools>=42",
+ "wheel",
+ # There is no fundamental reason to pin this CMake version, beyond
+ # build stability.
+ "cmake==3.22.2",
+ "ninja==1.10.2",
+ # MLIR build depends.
+ "numpy",
+ # Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136
+ "pybind11>=2.6.0,!=2.7.0",
+ "PyYAML",
+]
+build-backend = "setuptools.build_meta"