Plumb MLIR diagnostics through to python exceptions. PiperOrigin-RevId: 277325865
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler.cc index dccad81..d95a496 100644 --- a/bindings/python/pyiree/compiler.cc +++ b/bindings/python/pyiree/compiler.cc
@@ -59,17 +59,43 @@ } // namespace -CompilerContextBundle::CompilerContextBundle() { - // Setup a diagnostic handler. - mlir_context()->getDiagEngine().registerHandler( - [this](mlir::Diagnostic& d) { diagnostics_.push_back(std::move(d)); }); +DiagnosticCapture::DiagnosticCapture(mlir::MLIRContext* mlir_context, + DiagnosticCapture* parent) + : mlir_context_(mlir_context), parent_(parent) { + handler_id_ = mlir_context_->getDiagEngine().registerHandler( + [&](Diagnostic& d) -> LogicalResult { + diagnostics_.push_back(std::move(d)); + return success(); + }); } -CompilerContextBundle::~CompilerContextBundle() = default; +DiagnosticCapture::~DiagnosticCapture() { + if (mlir_context_) { + mlir_context_->getDiagEngine().eraseHandler(handler_id_); + if (parent_) { + for (auto& d : diagnostics_) { + parent_->diagnostics_.push_back(std::move(d)); + } + } + } +} -std::string CompilerContextBundle::ConsumeDiagnosticsAsString() { +DiagnosticCapture::DiagnosticCapture(DiagnosticCapture&& other) { + mlir_context_ = other.mlir_context_; + parent_ = other.parent_; + diagnostics_.swap(other.diagnostics_); + handler_id_ = other.handler_id_; + other.mlir_context_ = nullptr; +} + +std::string DiagnosticCapture::ConsumeDiagnosticsAsString( + const char* error_message) { std::string s; llvm::raw_string_ostream sout(s); bool first = true; + if (error_message) { + sout << error_message; + first = false; + } for (auto& d : diagnostics_) { if (!first) { sout << "\n\n"; @@ -104,7 +130,11 @@ return sout.str(); } -void CompilerContextBundle::ClearDiagnostics() { diagnostics_.clear(); } +void DiagnosticCapture::ClearDiagnostics() { diagnostics_.clear(); } + +CompilerContextBundle::CompilerContextBundle() + : default_capture_(&mlir_context_, nullptr) {} +CompilerContextBundle::~CompilerContextBundle() = default; CompilerModuleBundle CompilerContextBundle::ParseAsm( const std::string& asm_text) { @@ -113,9 +143,11 @@ const char* asm_chars = asm_text.c_str(); StringRef asm_sr(asm_chars, asm_text.size() + 1); + auto diag_capture = CaptureDiagnostics(); auto module_ref = parseMLIRModuleFromString(asm_sr, mlir_context()); if (!module_ref) { - throw RaiseValueError("Failed to parse MLIR asm"); + throw RaiseValueError( + diag_capture.ConsumeDiagnosticsAsString("Error parsing ASM").c_str()); } return CompilerModuleBundle(shared_from_this(), module_ref.release()); } @@ -130,10 +162,14 @@ } std::shared_ptr<OpaqueBlob> CompilerModuleBundle::CompileToSequencerBlob() { + auto diag_capture = context_->CaptureDiagnostics(); auto module_blob = mlir::iree_compiler::translateMlirToIreeSequencerModule(module_op()); if (module_blob.empty()) { - throw std::runtime_error("Failed to translate MLIR module"); + throw RaiseValueError( + diag_capture + .ConsumeDiagnosticsAsString("Failed to translate MLIR module") + .c_str()); } return std::make_shared<OpaqueByteVectorBlob>(std::move(module_blob)); } @@ -152,9 +188,12 @@ } // Run them. + auto diag_capture = context_->CaptureDiagnostics(); if (failed(pm.run(module_op_))) { - throw RaisePyError(PyExc_RuntimeError, - "Error running pass pipelines (see diagnostics)"); + throw RaisePyError( + PyExc_RuntimeError, + diag_capture.ConsumeDiagnosticsAsString("Error running pass pipelines:") + .c_str()); } }
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler.h index d40c057..b7daca4 100644 --- a/bindings/python/pyiree/compiler.h +++ b/bindings/python/pyiree/compiler.h
@@ -49,6 +49,28 @@ mlir::ModuleOp module_op_; }; +// Registers to receive diagnostics for a scope. +// When this goes out of scope, any remaining diagnostics will be added to +// the parent. +class DiagnosticCapture { + public: + DiagnosticCapture(mlir::MLIRContext* mlir_context, DiagnosticCapture* parent); + ~DiagnosticCapture(); + DiagnosticCapture(DiagnosticCapture&& other); + + std::vector<mlir::Diagnostic>& diagnostics() { return diagnostics_; } + + // Consumes/clears diagnostics. + std::string ConsumeDiagnosticsAsString(const char* error_message); + void ClearDiagnostics(); + + private: + mlir::MLIRContext* mlir_context_; + DiagnosticCapture* parent_; + std::vector<mlir::Diagnostic> diagnostics_; + mlir::DiagnosticEngine::HandlerID handler_id_; +}; + // Bundle of MLIRContext related things that facilitates interop with // Python. class CompilerContextBundle @@ -61,13 +83,24 @@ CompilerModuleBundle ParseAsm(const std::string& asm_text); + // Gets the default diagnostic capture. + DiagnosticCapture& DefaultDiagnosticCapture() { return default_capture_; } + + // Creates a new diagnostic region. + // Note that this only supports one deep at present. + DiagnosticCapture CaptureDiagnostics() { + return DiagnosticCapture(&mlir_context_, &default_capture_); + } + // Consumes/clears diagnostics. - std::string ConsumeDiagnosticsAsString(); - void ClearDiagnostics(); + std::string ConsumeDiagnosticsAsString() { + return default_capture_.ConsumeDiagnosticsAsString(nullptr); + } + void ClearDiagnostics() { default_capture_.ClearDiagnostics(); } private: mlir::MLIRContext mlir_context_; - std::vector<mlir::Diagnostic> diagnostics_; + DiagnosticCapture default_capture_; }; void SetupCompilerBindings(pybind11::module m);
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py index b90b2c4..4c9eb34 100644 --- a/bindings/python/pyiree/compiler_test.py +++ b/bindings/python/pyiree/compiler_test.py
@@ -25,10 +25,8 @@ def testParseError(self): ctx = binding.compiler.CompilerContext() - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"): ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""") - diag_str = ctx.get_diagnostics() - self.assertRegex(diag_str, "custom op 'FOOBAR' is unknown") def testParseAndCompileToSequencer(self): ctx = binding.compiler.CompilerContext()