Extend the XLA importer to also allow MHLO MLIR Text. (#6103)
* Extend the XLA importer to also allow MHLO MLIR Text.
* Previously was just allowing protos and XLA text format.
* Also adds some niceties that were missing since this importer was pretty bare bones.
* Add control flow legalization passes that were dropped in prior refactor.
* Add SCFToStandard conversion to iree-opt (needed it for debugging).
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index 18fdf5b..03f6647 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -104,7 +104,9 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
"@org_tensorflow//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_parser",
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
index b718455..b9bf77a 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
@@ -40,6 +40,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps",
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
index 4a71f1f..2719f23 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
@@ -19,6 +19,8 @@
#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -42,7 +44,9 @@
pm.addPass(mlir::createInlinerPass());
pm.addNestedPass<FuncOp>(mhlo::createControlFlowToScfPass());
pm.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
+ pm.addNestedPass<FuncOp>(mlir::createLowerToCFGPass());
pm.addPass(createFlattenTuplesInCFGPass());
+ pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
// Mostly delicate to the IREE side MHLO legalization pipeline, now that
// we have handled the weird that comes from legacy HLO clients.
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 07562db..dbff7f9 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -15,12 +15,15 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ToolOutputFile.h"
+#include "mlir-hlo/Dialect/mhlo/IR/register.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
@@ -37,6 +40,7 @@
binary_proto,
text_proto,
hlo_text,
+ mlir_text,
};
// Error collector that prints errors.
@@ -103,69 +107,40 @@
llvm::cl::init(""));
static llvm::cl::opt<XlaFormat> inputFormat(
"xla-format", cl::desc("XLA Format"),
- cl::values(clEnumVal(binary_proto, "Parse a binary protocol buffer"),
- clEnumVal(text_proto, "Parse a text protocol buffer"),
- clEnumVal(hlo_text,
- "Parse an HLO module in its native text format")));
+ cl::values(
+ clEnumVal(binary_proto, "Parse a binary protocol buffer"),
+ clEnumVal(text_proto, "Parse a text protocol buffer"),
+ clEnumVal(hlo_text, "Parse an HLO module in its native text format"),
+ clEnumVal(mlir_text, "Parse MLIR text containing MHLO ops")));
// Register any command line options.
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
+ registerPassManagerCLOptions();
registerDefaultTimingManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv);
+ auto openInputStream =
+ [&]() -> llvm::Optional<
+ std::pair<std::istream *, std::unique_ptr<std::ifstream>>> {
+ auto fileInputStream = std::make_unique<std::ifstream>();
+ std::istream *inputStream;
+ if (inputPath == "-") {
+ inputStream = &std::cin;
+ } else {
+ fileInputStream->open(inputPath, std::ios::in | std::ios::binary);
+ if (!fileInputStream->is_open()) {
+ llvm::errs() << "Unable to open input file " << inputPath << "\n";
+ return llvm::None;
+ }
+ inputStream = fileInputStream.get();
+ }
+ return std::make_pair(inputStream, std::move(fileInputStream));
+ };
+
DialectRegistry registry;
-
- // Read the protocol buffer.
- std::ifstream fileInputStream;
- std::istream *inputStream;
- if (inputPath == "-") {
- inputStream = &std::cin;
- } else {
- fileInputStream.open(inputPath, std::ios::in | std::ios::binary);
- if (!fileInputStream.is_open()) {
- llvm::errs() << "Unable to open input file " << inputPath << "\n";
- return 1;
- }
- inputStream = &fileInputStream;
- }
-
- xla::HloProto hloProto;
- switch (inputFormat) {
- case binary_proto: {
- if (!hloProto.mutable_hlo_module()->ParseFromIstream(inputStream)) {
- llvm::errs() << "Could not parse binary protocol buffer from "
- << inputPath << "\n";
- return 1;
- }
- break;
- }
- case text_proto: {
- tensorflow::protobuf::TextFormat::Parser parser;
- PrintErrorCollector collector(inputPath);
- IStreamCopyingInputStream copyingStream(inputStream);
- tensorflow::protobuf::io::CopyingInputStreamAdaptor streamAdaptor(
- ©ingStream);
- parser.RecordErrorsTo(&collector);
- parser.Parse(&streamAdaptor, hloProto.mutable_hlo_module());
- if (collector.hadError) {
- llvm::errs() << "Unable to parse text format protocol buffer\n";
- return 1;
- }
- break;
- }
- case hlo_text: {
- if (failed(ReadHloTextFormatFromStream(inputStream,
- hloProto.mutable_hlo_module()))) {
- return 1;
- }
- break;
- }
- default:
- llvm_unreachable("illegal XlaFormat");
- }
-
- // Convert the Module proto into MLIR.
+ mlir::mhlo::registerAllMhloDialects(registry);
+ registry.insert<mlir::StandardOpsDialect>();
MLIRContext context;
OwningModuleRef module = ModuleOp::create(mlir::UnknownLoc::get(&context));
context.appendDialectRegistry(registry);
@@ -174,12 +149,79 @@
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
- auto status =
- ConvertHloToMlirHlo(module.get(), hloProto.mutable_hlo_module());
- if (!status.ok()) {
- llvm::errs() << "Error converting HLO Module Proto to MLIR: "
- << status.ToString() << "\n";
- return 2;
+ auto loadHloProtoIntoModule = [&](xla::HloProto &hloProto) -> LogicalResult {
+ auto status =
+ ConvertHloToMlirHlo(module.get(), hloProto.mutable_hlo_module());
+ if (!status.ok()) {
+ llvm::errs() << "Error converting HLO Module Proto to MLIR: "
+ << status.ToString() << "\n";
+ return failure();
+ }
+ return success();
+ };
+
+ switch (inputFormat) {
+ case binary_proto: {
+ xla::HloProto hloProto;
+ auto input = openInputStream();
+ if (!input) {
+ return 1;
+ }
+ if (!hloProto.mutable_hlo_module()->ParseFromIstream(input->first)) {
+ llvm::errs() << "Could not parse binary protocol buffer from "
+ << inputPath << "\n";
+ return 1;
+ }
+ if (failed(loadHloProtoIntoModule(hloProto))) return 2;
+ break;
+ }
+ case text_proto: {
+ xla::HloProto hloProto;
+ auto input = openInputStream();
+ if (!input) {
+ return 1;
+ }
+ tensorflow::protobuf::TextFormat::Parser parser;
+ PrintErrorCollector collector(inputPath);
+ IStreamCopyingInputStream copyingStream(input->first);
+ tensorflow::protobuf::io::CopyingInputStreamAdaptor streamAdaptor(
+ ©ingStream);
+ parser.RecordErrorsTo(&collector);
+ parser.Parse(&streamAdaptor, hloProto.mutable_hlo_module());
+ if (collector.hadError) {
+ llvm::errs() << "Unable to parse text format protocol buffer\n";
+ return 1;
+ }
+ if (failed(loadHloProtoIntoModule(hloProto))) return 2;
+ break;
+ }
+ case hlo_text: {
+ xla::HloProto hloProto;
+ auto input = openInputStream();
+ if (!input) {
+ return 1;
+ }
+ if (failed(ReadHloTextFormatFromStream(input->first,
+ hloProto.mutable_hlo_module()))) {
+ return 1;
+ }
+ if (failed(loadHloProtoIntoModule(hloProto))) return 2;
+ break;
+ }
+ case mlir_text: {
+ std::string errorMessage;
+ auto file = openInputFile(inputPath, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+ sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+ module = parseSourceFile(sourceMgr, &context);
+ if (!module) return 2;
+ break;
+ }
+ default:
+ llvm_unreachable("illegal XlaFormat");
}
// Find the entry function and annotate it as exported.
diff --git a/iree/tools/init_mlir_passes.h b/iree/tools/init_mlir_passes.h
index e1e337e..8c076d0 100644
--- a/iree/tools/init_mlir_passes.h
+++ b/iree/tools/init_mlir_passes.h
@@ -64,6 +64,7 @@
// SCF
registerSCFParallelLoopFusionPass();
registerSCFParallelLoopTilingPass();
+ registerSCFToStandardPass();
// Quant
quant::registerQuantPasses();