Fix and re-enable JAX testing (#5183)
In https://github.com/google/iree/pull/5101 I accidentally removed the
bindings/python tests from the CI because I assumed they were already
covered outside the integrations build. It turns out that we have tests
depending on the TF integrations living under bindings/python
(gross; see https://github.com/google/iree/issues/5181). So when
https://github.com/google/iree/pull/5104 moved MHLO tuple and control
flow legalization out of core, we didn't catch the resulting test
failure for JAX.
This adds passes to convert MHLO, as produced by XLA->MHLO translation,
into the form accepted by IREE. As we move towards making linalg on
tensors our input and push more MHLO handling out of IREE core, this
pipeline will likely grow.
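
For reference, a minimal sketch (not part of the patch) of how the new
buildMHLOImportPassPipeline entry point is meant to be driven. It mirrors
the wiring added to iree-import-xla-main.cpp below; module import and
error handling are elided, and the includes/namespaces assume the layout
used by this change:

    #include "iree_tf_compiler/TF/Passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Support/LogicalResult.h"

    // Runs the MHLO import pipeline over a module produced by XLA->MHLO
    // translation, leaving it in the MHLO form accepted by IREE core.
    static mlir::LogicalResult runMHLOImport(mlir::MLIRContext &context,
                                             mlir::ModuleOp module) {
      // Implicit nesting matches how iree-import-xla builds its PassManager.
      mlir::PassManager pm(&context, mlir::PassManager::Nesting::Implicit);
      mlir::applyPassManagerCLOptions(pm);
      iree_integrations::TF::buildMHLOImportPassPipeline(pm);
      return pm.run(module);
    }

The same pipeline is also registered as -iree-mhlo-import-pipeline, so it
can be exercised from pass-pipeline command-line tooling as well.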
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
index ec73635..f9f898c 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
@@ -70,4 +70,6 @@
export CTEST_PARALLEL_LEVEL=${CTEST_PARALLEL_LEVEL:-$(nproc)}
echo "Testing with CTest"
-ctest --output-on-failure -L 'integrations/tensorflow' --label-exclude "^nokokoro$"
+ctest --output-on-failure \
+ --tests-regex "^integrations/tensorflow/|^bindings/python/" \
+ --label-exclude "^nokokoro$"
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh
index 2777107..b3cb298 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh
@@ -78,6 +78,6 @@
# as well.
echo "Testing with CTest"
ctest --output-on-failure \
- --tests-regex "^integrations/tensorflow/" \
+ --tests-regex "^integrations/tensorflow/|^bindings/python/" \
--label-regex "^driver=vulkan$|^driver=cuda$" \
--label-exclude "^nokokoro$"
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index 5392611..81293eb 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -93,6 +93,7 @@
"//iree_tf_compiler/TF",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@org_tensorflow//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_parser",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
index e26b7b4..bb76f08 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
@@ -127,6 +127,10 @@
pm.nest<ModuleOp>().addPass(createStripFunctionMetadataPass());
pm.addPass(createVerifyFullyConvertedPass());
+ buildMHLOImportPassPipeline(pm);
+}
+
+void buildMHLOImportPassPipeline(OpPassManager &pm) {
//----------------------------------------------------------------------------
// Convert control flow and flatten tuples (like tuple<tensor<...>, ...>)
//----------------------------------------------------------------------------
@@ -148,6 +152,15 @@
pm.addPass(iree_compiler::Shape::createConvertHLOToShapePass());
}
+void registerMHLOImportPassPipeline() {
+ mlir::PassPipelineRegistration<> pipeline(
+ "iree-mhlo-import-pipeline",
+ "Run IREE-specific passes for importing MHLO code into IREE",
+ [](OpPassManager &passManager) {
+ buildMHLOImportPassPipeline(passManager);
+ });
+}
+
void registerTFImportPassPipeline() {
mlir::PassPipelineRegistration<> pipeline(
"iree-tf-import-pipeline",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
index dda2e33..8c2f74b 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
@@ -33,8 +33,12 @@
// passes in the right order.
void buildTFImportPassPipeline(OpPassManager &pm);
+void buildMHLOImportPassPipeline(OpPassManager &pm);
+
void registerTFImportPassPipeline();
+void registerMHLOImportPassPipeline();
+
//===----------------------------------------------------------------------===//
// IREE-specific Passes For TensorFlow Import
//===----------------------------------------------------------------------===//
@@ -75,6 +79,7 @@
inline void registerAllPasses() {
registerTFImportPassPipeline();
+ registerMHLOImportPassPipeline();
createConvertToMHLOPass();
createFlattenTuplesInCFGPass();
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index ae979c7..f85d299 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -18,6 +18,7 @@
#include <fstream>
#include <iostream>
+#include "iree_tf_compiler/TF/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InitLLVM.h"
@@ -28,6 +29,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -218,6 +220,17 @@
if (failed(saveToFile(saveTempIreeImport))) return 10;
}
+ // Run passes.
+ PassManager pm(&context, PassManager::Nesting::Implicit);
+ applyPassManagerCLOptions(pm);
+
+ iree_integrations::TF::buildMHLOImportPassPipeline(pm);
+ if (failed(pm.run(*module))) {
+ llvm::errs()
+ << "Running iree-xla-import pass pipeline failed (see diagnostics)\n";
+ return 2;
+ }
+
// Save output.
if (failed(saveToFile(outputFilename))) return 3;
return 0;