Extend the XLA importer to also allow MHLO MLIR Text. (#6103)

* Extend the XLA importer to also allow MHLO MLIR Text.
* Previously was just allowing protos and XLA text format.
* Also adds some niceties that were missing since this importer was pretty bare bones.
* Add control flow legalization passes that were dropped in prior refactor.
* Add SCFToStandard conversion to iree-opt (needed it for debugging).
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index 18fdf5b..03f6647 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -104,7 +104,9 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
         "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
         "@org_tensorflow//tensorflow/compiler/xla/service:hlo_parser",
         "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
index b718455..b9bf77a 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
@@ -40,6 +40,7 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:Shape",
         "@llvm-project//mlir:ShapeTransforms",
         "@llvm-project//mlir:StandardOps",
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
index 4a71f1f..2719f23 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
@@ -19,6 +19,8 @@
 #include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -42,7 +44,9 @@
   pm.addPass(mlir::createInlinerPass());
   pm.addNestedPass<FuncOp>(mhlo::createControlFlowToScfPass());
   pm.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
+  pm.addNestedPass<FuncOp>(mlir::createLowerToCFGPass());
   pm.addPass(createFlattenTuplesInCFGPass());
+  pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
 
   // Mostly delicate to the IREE side MHLO legalization pipeline, now that
   // we have handled the weird that comes from legacy HLO clients.
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 07562db..dbff7f9 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -15,12 +15,15 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/ToolOutputFile.h"
+#include "mlir-hlo/Dialect/mhlo/IR/register.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Parser.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
 #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
@@ -37,6 +40,7 @@
   binary_proto,
   text_proto,
   hlo_text,
+  mlir_text,
 };
 
 // Error collector that prints errors.
@@ -103,69 +107,40 @@
       llvm::cl::init(""));
   static llvm::cl::opt<XlaFormat> inputFormat(
       "xla-format", cl::desc("XLA Format"),
-      cl::values(clEnumVal(binary_proto, "Parse a binary protocol buffer"),
-                 clEnumVal(text_proto, "Parse a text protocol buffer"),
-                 clEnumVal(hlo_text,
-                           "Parse an HLO module in its native text format")));
+      cl::values(
+          clEnumVal(binary_proto, "Parse a binary protocol buffer"),
+          clEnumVal(text_proto, "Parse a text protocol buffer"),
+          clEnumVal(hlo_text, "Parse an HLO module in its native text format"),
+          clEnumVal(mlir_text, "Parse MLIR text containing MHLO ops")));
 
   // Register any command line options.
   registerAsmPrinterCLOptions();
   registerMLIRContextCLOptions();
+  registerPassManagerCLOptions();
   registerDefaultTimingManagerCLOptions();
   cl::ParseCommandLineOptions(argc, argv);
 
+  auto openInputStream =
+      [&]() -> llvm::Optional<
+                std::pair<std::istream *, std::unique_ptr<std::ifstream>>> {
+    auto fileInputStream = std::make_unique<std::ifstream>();
+    std::istream *inputStream;
+    if (inputPath == "-") {
+      inputStream = &std::cin;
+    } else {
+      fileInputStream->open(inputPath, std::ios::in | std::ios::binary);
+      if (!fileInputStream->is_open()) {
+        llvm::errs() << "Unable to open input file " << inputPath << "\n";
+        return llvm::None;
+      }
+      inputStream = fileInputStream.get();
+    }
+    return std::make_pair(inputStream, std::move(fileInputStream));
+  };
+
   DialectRegistry registry;
-
-  // Read the protocol buffer.
-  std::ifstream fileInputStream;
-  std::istream *inputStream;
-  if (inputPath == "-") {
-    inputStream = &std::cin;
-  } else {
-    fileInputStream.open(inputPath, std::ios::in | std::ios::binary);
-    if (!fileInputStream.is_open()) {
-      llvm::errs() << "Unable to open input file " << inputPath << "\n";
-      return 1;
-    }
-    inputStream = &fileInputStream;
-  }
-
-  xla::HloProto hloProto;
-  switch (inputFormat) {
-    case binary_proto: {
-      if (!hloProto.mutable_hlo_module()->ParseFromIstream(inputStream)) {
-        llvm::errs() << "Could not parse binary protocol buffer from "
-                     << inputPath << "\n";
-        return 1;
-      }
-      break;
-    }
-    case text_proto: {
-      tensorflow::protobuf::TextFormat::Parser parser;
-      PrintErrorCollector collector(inputPath);
-      IStreamCopyingInputStream copyingStream(inputStream);
-      tensorflow::protobuf::io::CopyingInputStreamAdaptor streamAdaptor(
-          &copyingStream);
-      parser.RecordErrorsTo(&collector);
-      parser.Parse(&streamAdaptor, hloProto.mutable_hlo_module());
-      if (collector.hadError) {
-        llvm::errs() << "Unable to parse text format protocol buffer\n";
-        return 1;
-      }
-      break;
-    }
-    case hlo_text: {
-      if (failed(ReadHloTextFormatFromStream(inputStream,
-                                             hloProto.mutable_hlo_module()))) {
-        return 1;
-      }
-      break;
-    }
-    default:
-      llvm_unreachable("illegal XlaFormat");
-  }
-
-  // Convert the Module proto into MLIR.
+  mlir::mhlo::registerAllMhloDialects(registry);
+  registry.insert<mlir::StandardOpsDialect>();
   MLIRContext context;
   OwningModuleRef module = ModuleOp::create(mlir::UnknownLoc::get(&context));
   context.appendDialectRegistry(registry);
@@ -174,12 +149,79 @@
   llvm::SourceMgr sourceMgr;
   mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
 
-  auto status =
-      ConvertHloToMlirHlo(module.get(), hloProto.mutable_hlo_module());
-  if (!status.ok()) {
-    llvm::errs() << "Error converting HLO Module Proto to MLIR: "
-                 << status.ToString() << "\n";
-    return 2;
+  auto loadHloProtoIntoModule = [&](xla::HloProto &hloProto) -> LogicalResult {
+    auto status =
+        ConvertHloToMlirHlo(module.get(), hloProto.mutable_hlo_module());
+    if (!status.ok()) {
+      llvm::errs() << "Error converting HLO Module Proto to MLIR: "
+                   << status.ToString() << "\n";
+      return failure();
+    }
+    return success();
+  };
+
+  switch (inputFormat) {
+    case binary_proto: {
+      xla::HloProto hloProto;
+      auto input = openInputStream();
+      if (!input) {
+        return 1;
+      }
+      if (!hloProto.mutable_hlo_module()->ParseFromIstream(input->first)) {
+        llvm::errs() << "Could not parse binary protocol buffer from "
+                     << inputPath << "\n";
+        return 1;
+      }
+      if (failed(loadHloProtoIntoModule(hloProto))) return 2;
+      break;
+    }
+    case text_proto: {
+      xla::HloProto hloProto;
+      auto input = openInputStream();
+      if (!input) {
+        return 1;
+      }
+      tensorflow::protobuf::TextFormat::Parser parser;
+      PrintErrorCollector collector(inputPath);
+      IStreamCopyingInputStream copyingStream(input->first);
+      tensorflow::protobuf::io::CopyingInputStreamAdaptor streamAdaptor(
+          &copyingStream);
+      parser.RecordErrorsTo(&collector);
+      parser.Parse(&streamAdaptor, hloProto.mutable_hlo_module());
+      if (collector.hadError) {
+        llvm::errs() << "Unable to parse text format protocol buffer\n";
+        return 1;
+      }
+      if (failed(loadHloProtoIntoModule(hloProto))) return 2;
+      break;
+    }
+    case hlo_text: {
+      xla::HloProto hloProto;
+      auto input = openInputStream();
+      if (!input) {
+        return 1;
+      }
+      if (failed(ReadHloTextFormatFromStream(input->first,
+                                             hloProto.mutable_hlo_module()))) {
+        return 1;
+      }
+      if (failed(loadHloProtoIntoModule(hloProto))) return 2;
+      break;
+    }
+    case mlir_text: {
+      std::string errorMessage;
+      auto file = openInputFile(inputPath, &errorMessage);
+      if (!file) {
+        llvm::errs() << errorMessage << "\n";
+        return 1;
+      }
+      sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+      module = parseSourceFile(sourceMgr, &context);
+      if (!module) return 2;
+      break;
+    }
+    default:
+      llvm_unreachable("illegal XlaFormat");
   }
 
   // Find the entry function and annotate it as exported.
diff --git a/iree/tools/init_mlir_passes.h b/iree/tools/init_mlir_passes.h
index e1e337e..8c076d0 100644
--- a/iree/tools/init_mlir_passes.h
+++ b/iree/tools/init_mlir_passes.h
@@ -64,6 +64,7 @@
   // SCF
   registerSCFParallelLoopFusionPass();
   registerSCFParallelLoopTilingPass();
+  registerSCFToStandardPass();
 
   // Quant
   quant::registerQuantPasses();