blob: 7f207dff80a99a0ff4849612f2731c5ea4ddbd25 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bindings/python/pyiree/compiler.h"
#include <stdexcept>
#include "bindings/python/pyiree/binding.h"
#include "bindings/python/pyiree/initialize.h"
#include "bindings/python/pyiree/status_utils.h"
#include "iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h"
#include "iree/schemas/module_def_generated.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::iree_compiler;
using llvm::MemoryBuffer;
using llvm::MemoryBufferRef;
using llvm::StringRef;
namespace iree {
namespace python {
namespace {
// Parses `contents` as MLIR assembly into a new module belonging to `context`.
//
// If `contents` ends with a NUL byte, that byte is treated as the buffer
// terminator and excluded from the parsed text. The bytes are then referenced
// in place rather than copied, so the caller's storage must stay alive until
// this function returns. Otherwise (including for empty input) the text is
// copied into a NUL-terminated buffer named "EMBED".
//
// Returns a null OwningModuleRef on parse failure (callers test it with `!`).
// NOTE(review): parse errors presumably surface as diagnostics on `context`;
// confirm against parseSourceFile.
OwningModuleRef parseMLIRModuleFromString(StringRef contents,
                                          MLIRContext* context) {
  std::unique_ptr<MemoryBuffer> contents_buffer;
  // StringRef::back() must not be called on an empty ref, so check emptiness
  // first; empty input simply falls through to the copying branch.
  if (!contents.empty() && contents.back() == 0) {
    // Reference the caller's bytes directly. The trailing NUL remains in
    // memory just past the dropped range, which satisfies MemoryBuffer's
    // null-terminator requirement.
    contents_buffer = MemoryBuffer::getMemBuffer(contents.drop_back());
  } else {
    // No terminator available: make a NUL-terminated copy.
    contents_buffer = MemoryBuffer::getMemBufferCopy(contents, "EMBED");
  }
  llvm::SourceMgr source_mgr;
  source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc());
  return parseSourceFile(source_mgr, context);
}
} // namespace
// Builds a compiler context bundle.
//
// Extension initialization runs first. Afterwards every diagnostic emitted on
// mlir_context() is moved into diagnostics_, where it can later be rendered by
// ConsumeDiagnosticsAsString() or dropped by ClearDiagnostics().
CompilerContextBundle::CompilerContextBundle() {
  InitializeExtension({});

  auto& diag_engine = mlir_context()->getDiagEngine();
  // The handler captures `this`. NOTE(review): this assumes the MLIR context
  // (and thus the registered handler) does not outlive this bundle; confirm
  // the context is owned by the bundle.
  diag_engine.registerHandler([this](mlir::Diagnostic& diagnostic) {
    diagnostics_.push_back(std::move(diagnostic));
  });
}

CompilerContextBundle::~CompilerContextBundle() = default;
std::string CompilerContextBundle::ConsumeDiagnosticsAsString() {
std::string s;
llvm::raw_string_ostream sout(s);
bool first = true;
for (auto& d : diagnostics_) {
if (!first) {
sout << "\n\n";
} else {
first = false;
}
switch (d.getSeverity()) {
case DiagnosticSeverity::Note:
sout << "[NOTE]";
break;
case DiagnosticSeverity::Warning:
sout << "[WARNING]";
break;
case DiagnosticSeverity::Error:
sout << "[ERROR]";
break;
case DiagnosticSeverity::Remark:
sout << "[REMARK]";
break;
default:
sout << "[UNKNOWN]";
}
// Message.
sout << ": " << d << "\n\t";
// Location.
d.getLocation().print(sout);
}
diagnostics_.clear();
return sout.str();
}
void CompilerContextBundle::ClearDiagnostics() { diagnostics_.clear(); }
// Parses `asm_text` as MLIR assembly within this context.
//
// On success, returns a module bundle holding shared_from_this(). This
// requires the bundle to already be owned by a std::shared_ptr. On failure,
// throws the error produced by RaiseValueError.
CompilerModuleBundle CompilerContextBundle::ParseAsm(
    const std::string& asm_text) {
  // c_str() guarantees a trailing NUL. Widening the view by one byte to cover
  // it lets the parser reference the text in place instead of copying it.
  StringRef text_with_nul(asm_text.c_str(), asm_text.size() + 1);
  auto parsed = parseMLIRModuleFromString(text_with_nul, mlir_context());
  if (!parsed) throw RaiseValueError("Failed to parse MLIR asm");
  return CompilerModuleBundle(shared_from_this(), parsed.release());
}
std::string CompilerModuleBundle::ToAsm() {
// Print to asm.
std::string asm_output;
llvm::raw_string_ostream sout(asm_output);
OpPrintingFlags print_flags;
module_op().print(sout, print_flags);
return sout.str();
}
std::shared_ptr<OpaqueBlob> CompilerModuleBundle::CompileToSequencerBlob() {
auto module_blob =
mlir::iree_compiler::translateMlirToIreeSequencerModule(module_op());
if (module_blob.empty()) {
throw std::runtime_error("Failed to translate MLIR module");
}
return std::make_shared<OpaqueByteVectorBlob>(std::move(module_blob));
}
void CompilerModuleBundle::RunPassPipeline(
const std::vector<std::string>& pipelines) {
mlir::PassManager pm(context_->mlir_context());
// Parse the pass pipelines.
std::string error;
llvm::raw_string_ostream error_stream(error);
for (const auto& pipeline : pipelines) {
if (failed(mlir::parsePassPipeline(pipeline, pm, error_stream))) {
throw RaiseValueError(error_stream.str().c_str());
}
}
// Run them.
if (failed(pm.run(module_op_))) {
throw RaisePyError(PyExc_RuntimeError,
"Error running pass pipelines (see diagnostics)");
}
}
// Registers the compiler bindings on Python module `m`:
//   CompilerContext: parse_asm, get_diagnostics, clear_diagnostics
//   CompilerModule:  to_asm, compile_to_sequencer_blob, run_pass_pipeline
void SetupCompilerBindings(pybind11::module m) {
  using ContextClass =
      py::class_<CompilerContextBundle, std::shared_ptr<CompilerContextBundle>>;
  ContextClass context_class(m, "CompilerContext");
  // Construct through make_shared so the instance is shared_ptr-owned from
  // the start; ParseAsm depends on shared_from_this().
  context_class.def(py::init([]() {
    return std::make_shared<CompilerContextBundle>();
  }));
  context_class.def("parse_asm", &CompilerContextBundle::ParseAsm);
  context_class.def("get_diagnostics",
                    &CompilerContextBundle::ConsumeDiagnosticsAsString);
  context_class.def("clear_diagnostics",
                    &CompilerContextBundle::ClearDiagnostics);

  py::class_<CompilerModuleBundle> module_class(m, "CompilerModule");
  module_class.def("to_asm", &CompilerModuleBundle::ToAsm);
  module_class.def("compile_to_sequencer_blob",
                   &CompilerModuleBundle::CompileToSequencerBlob);
  module_class.def("run_pass_pipeline",
                   &CompilerModuleBundle::RunPassPipeline);
}
} // namespace python
} // namespace iree