// blob: dcbf2e7152cbe214fd9d6722dec3c8b4d65ef67c [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Main entry-point for the XLA proto importer frontend.
// This will read from XLA proto files and produce MLIR MHLO assembly.
#include <fstream>
#include <iostream>
#include "iree_tf_compiler/MHLO/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/core/platform/protobuf.h"
using namespace llvm;
using namespace mlir;
namespace {
// Input formats selectable through the --xla-format command line option.
// The enumerator names double as the option's accepted values.
enum XlaFormat {
  // Binary-serialized protocol buffer.
  binary_proto = 0,
  // Text-format protocol buffer.
  text_proto = 1,
  // HLO module in its native text format.
  hlo_text = 2,
  // MLIR text containing MHLO ops.
  mlir_text = 3,
};
// Error collector that prints errors.
class PrintErrorCollector : public tensorflow::protobuf::io::ErrorCollector {
public:
PrintErrorCollector(std::string filePath) : filePath(std::move(filePath)) {}
void AddError(int line, int column, const std::string &message) override {
llvm::errs() << "Text protobuf parse error(" << filePath << ":" << line
<< ":" << column << "): " << message << "\n";
hadError = true;
}
std::string filePath;
bool hadError = false;
};
// Adapts a std::istream to the protobuf CopyingInputStream interface so that
// it can be wrapped in a CopyingInputStreamAdaptor.
class IStreamCopyingInputStream
    : public tensorflow::protobuf::io::CopyingInputStream {
 public:
  // Stores `input` without taking ownership; it is dereferenced on every
  // Read() call and therefore must outlive this object.
  IStreamCopyingInputStream(std::istream *input) : input(input) {}

  // Reads up to `size` bytes from the stream into `buffer`.
  // Returns the number of bytes read, 0 once the end of the stream has been
  // reached, or -1 if the read failed without producing any data.
  int Read(void *buffer, int size) override {
    input->read(static_cast<char *>(buffer), size);
    const int bytesRead = static_cast<int>(input->gcount());
    // std::istream::read sets failbit together with eofbit whenever fewer
    // than `size` bytes were available, so failbit alone does not indicate an
    // error: the final (short) chunk must still be handed to the caller.
    // Only report an error when nothing was read and EOF was not the cause.
    if (bytesRead == 0 && input->fail() && !input->eof()) return -1;
    return bytesRead;
  }

 private:
  std::istream *input;
};
// Reads the remainder of `in` as HLO text, parses it with
// xla::ParseAndReturnUnverifiedModule and stores the module's proto form in
// `moduleProto`.
//
// Returns failure() after printing a message to llvm::errs() if the stream
// reports a failure after reading or if the text does not parse; returns
// success() otherwise.
LogicalResult ReadHloTextFormatFromStream(std::istream *in,
                                          xla::HloModuleProto *moduleProto) {
  // The parser takes a string, so slurp the whole stream up front.
  std::istreambuf_iterator<char> streamBegin(*in);
  std::istreambuf_iterator<char> streamEnd;
  std::string hloText(streamBegin, streamEnd);
  if (in->fail()) {
    llvm::errs() << "Error reading input stream\n";
    return failure();
  }

  auto parsedOr = xla::ParseAndReturnUnverifiedModule(hloText);
  if (parsedOr.ok()) {
    auto hloModule = std::move(*parsedOr);
    *moduleProto = hloModule->ToProto();
    return success();
  }

  llvm::errs() << "XLA failed to parse a text format HloModule:\n"
               << parsedOr.status().ToString() << "\n";
  return failure();
}
} // namespace
// Tool entry point: parses the command line, loads the input into an MLIR
// module according to --xla-format, marks the 'main' function public, runs
// the MHLO import pass pipeline and writes the resulting IR to the output
// file ("-" by default).
//
// Exit codes, as returned below:
//   0  - success.
//   1  - the input could not be opened, read or parsed.
//   2  - HLO proto to MLIR conversion, MLIR parsing or the pass pipeline
//        failed.
//   3  - the module has no 'main' function, or writing the final output
//        failed.
//   10 - writing one of the --save-temp-* files failed.
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
// Tool-specific command line options. They are declared before
// ParseCommandLineOptions is called below.
static cl::opt<std::string> inputPath(
cl::Positional, cl::desc("<XLA Protocol Buffer Path>"), cl::Required);
static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"));
static llvm::cl::opt<std::string> saveTempMhloInput(
"save-temp-mhlo-input",
llvm::cl::desc("Save the MHLO pipeline input IR to this file"),
llvm::cl::init(""));
static llvm::cl::opt<std::string> saveTempIreeImport(
"save-temp-iree-input",
llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
"intermediate in a pipeline)"),
llvm::cl::init(""));
// NOTE(review): no cl::init is given here, so the format used when the flag
// is omitted is whatever cl::opt defaults to — presumably the first
// enumerator (binary_proto); confirm.
static llvm::cl::opt<XlaFormat> inputFormat(
"xla-format", cl::desc("XLA Format"),
cl::values(
clEnumVal(binary_proto, "Parse a binary protocol buffer"),
clEnumVal(text_proto, "Parse a text protocol buffer"),
clEnumVal(hlo_text, "Parse an HLO module in its native text format"),
clEnumVal(mlir_text, "Parse MLIR text containing MHLO ops")));
// Register any command line options.
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
registerPassManagerCLOptions();
registerDefaultTimingManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv);
// Opens the input for reading. On success returns the stream to read from
// paired with the std::ifstream object created here; when inputPath is "-"
// the stream is std::cin and the ifstream is left unopened. Returns
// llvm::None (after printing a message) if the file cannot be opened.
auto openInputStream =
[&]() -> llvm::Optional<
std::pair<std::istream *, std::unique_ptr<std::ifstream>>> {
auto fileInputStream = std::make_unique<std::ifstream>();
std::istream *inputStream;
if (inputPath == "-") {
inputStream = &std::cin;
} else {
fileInputStream->open(inputPath, std::ios::in | std::ios::binary);
if (!fileInputStream->is_open()) {
llvm::errs() << "Unable to open input file " << inputPath << "\n";
return llvm::None;
}
inputStream = fileInputStream.get();
}
return std::make_pair(inputStream, std::move(fileInputStream));
};
// Register the MHLO, Standard, Arithmetic and Math dialects and load them
// into the context.
DialectRegistry registry;
mlir::mhlo::registerAllMhloDialects(registry);
registry.insert<mlir::StandardOpsDialect, mlir::arith::ArithmeticDialect,
mlir::math::MathDialect>();
MLIRContext context;
// Module that the proto-based formats convert into; the mlir_text format
// replaces it with the parsed module instead.
OwningModuleRef module = ModuleOp::create(mlir::UnknownLoc::get(&context));
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
// Attach a SourceMgrDiagnosticHandler for `sourceMgr` to the context.
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
// Converts the HLO module held by `hloProto` into `module`, printing the
// conversion status and returning failure() if it is not ok.
auto loadHloProtoIntoModule = [&](xla::HloProto &hloProto) -> LogicalResult {
auto status =
ConvertHloToMlirHlo(module.get(), hloProto.mutable_hlo_module());
if (!status.ok()) {
llvm::errs() << "Error converting HLO Module Proto to MLIR: "
<< status.ToString() << "\n";
return failure();
}
return success();
};
// Populate `module` from the input according to the requested format.
switch (inputFormat) {
case binary_proto: {
// The input is parsed directly into the hlo_module field of an HloProto.
xla::HloProto hloProto;
auto input = openInputStream();
if (!input) {
return 1;
}
if (!hloProto.mutable_hlo_module()->ParseFromIstream(input->first)) {
llvm::errs() << "Could not parse binary protocol buffer from "
<< inputPath << "\n";
return 1;
}
if (failed(loadHloProtoIntoModule(hloProto))) return 2;
break;
}
case text_proto: {
xla::HloProto hloProto;
auto input = openInputStream();
if (!input) {
return 1;
}
// Feed the std::istream to the protobuf text parser through the copying
// stream adaptor; parse errors are reported via `collector`.
tensorflow::protobuf::TextFormat::Parser parser;
PrintErrorCollector collector(inputPath);
IStreamCopyingInputStream copyingStream(input->first);
tensorflow::protobuf::io::CopyingInputStreamAdaptor streamAdaptor(
&copyingStream);
parser.RecordErrorsTo(&collector);
// NOTE(review): the return value of Parse is ignored; failure is detected
// only through collector.hadError.
parser.Parse(&streamAdaptor, hloProto.mutable_hlo_module());
if (collector.hadError) {
llvm::errs() << "Unable to parse text format protocol buffer\n";
return 1;
}
if (failed(loadHloProtoIntoModule(hloProto))) return 2;
break;
}
case hlo_text: {
xla::HloProto hloProto;
auto input = openInputStream();
if (!input) {
return 1;
}
if (failed(ReadHloTextFormatFromStream(input->first,
hloProto.mutable_hlo_module()))) {
return 1;
}
if (failed(loadHloProtoIntoModule(hloProto))) return 2;
break;
}
case mlir_text: {
// MLIR text is opened with openInputFile rather than openInputStream and
// handed to the source manager; the parsed module replaces `module`.
std::string errorMessage;
auto file = openInputFile(inputPath, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return 1;
}
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
module = parseSourceFile(sourceMgr, &context);
if (!module) return 2;
break;
}
default:
llvm_unreachable("illegal XlaFormat");
}
// Find the entry function and annotate it as exported.
// Note that the XLA importer always produced an MLIR module with a @main
// function.
std::string entryName = "main";
SymbolTable symbolTable(module.get());
auto mainFunc = symbolTable.lookup<FuncOp>(entryName);
if (!mainFunc) {
llvm::errs() << "Unable to find main function '" << entryName
<< "' in converted module.\n";
return 3;
}
mainFunc.setPublic();
// Save.
// Prints the current state of `module` to `savePath`, followed by a newline.
// Returns failure() after printing a message if the file cannot be opened.
auto saveToFile = [&](llvm::StringRef savePath) -> LogicalResult {
auto outputFile = openOutputFile(savePath);
if (!outputFile) {
llvm::errs() << "Could not open output file: " << savePath << "\n";
return failure();
}
OpPrintingFlags printFlags;
module->print(outputFile->os(), printFlags);
outputFile->os() << "\n";
// Presumably required so that the output file is retained once
// `outputFile` goes out of scope — see llvm::ToolOutputFile; confirm.
outputFile->keep();
return success();
};
// Save temp output.
// This snapshot is taken before any passes run, i.e. it is the pipeline input.
if (!saveTempMhloInput.empty()) {
if (failed(saveToFile(saveTempMhloInput))) return 10;
}
// Run passes.
PassManager pm(&context, PassManager::Nesting::Implicit);
applyPassManagerCLOptions(pm);
applyDefaultTimingPassManagerCLOptions(pm);
iree_integrations::MHLO::buildMHLOImportPassPipeline(pm);
// Note that we emit the ABI last since any needed function-level
// transformations (i.e. de-tupling, etc) should have been done.
pm.addNestedPass<FuncOp>(
iree_integrations::MHLO::createEmitDefaultIREEABIPass());
if (failed(pm.run(*module))) {
llvm::errs() << "Running iree-xla-import MHLO import pass pipeline failed "
"(see diagnostics)\n";
return 2;
}
// Save temp output.
// This snapshot is taken after the pass pipeline has run.
if (!saveTempIreeImport.empty()) {
if (failed(saveToFile(saveTempIreeImport))) return 10;
}
// Save output.
if (failed(saveToFile(outputFilename))) return 3;
return 0;
}