Plumb MLIR diagnostics through to python exceptions.

PiperOrigin-RevId: 277325865
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler.cc
index dccad81..d95a496 100644
--- a/bindings/python/pyiree/compiler.cc
+++ b/bindings/python/pyiree/compiler.cc
@@ -59,17 +59,43 @@
 
 }  // namespace
 
-CompilerContextBundle::CompilerContextBundle() {
-  // Setup a diagnostic handler.
-  mlir_context()->getDiagEngine().registerHandler(
-      [this](mlir::Diagnostic& d) { diagnostics_.push_back(std::move(d)); });
+DiagnosticCapture::DiagnosticCapture(mlir::MLIRContext* mlir_context,
+                                     DiagnosticCapture* parent)
+    : mlir_context_(mlir_context), parent_(parent) {
+  handler_id_ = mlir_context_->getDiagEngine().registerHandler(
+      [&](Diagnostic& d) -> LogicalResult {
+        diagnostics_.push_back(std::move(d));
+        return success();
+      });
 }
-CompilerContextBundle::~CompilerContextBundle() = default;
+DiagnosticCapture::~DiagnosticCapture() {
+  if (mlir_context_) {
+    mlir_context_->getDiagEngine().eraseHandler(handler_id_);
+    if (parent_) {
+      for (auto& d : diagnostics_) {
+        parent_->diagnostics_.push_back(std::move(d));
+      }
+    }
+  }
+}
 
-std::string CompilerContextBundle::ConsumeDiagnosticsAsString() {
+DiagnosticCapture::DiagnosticCapture(DiagnosticCapture&& other) {
+  mlir_context_ = other.mlir_context_;
+  parent_ = other.parent_;
+  diagnostics_.swap(other.diagnostics_);
+  handler_id_ = other.handler_id_;
+  other.mlir_context_ = nullptr;
+}
+
+std::string DiagnosticCapture::ConsumeDiagnosticsAsString(
+    const char* error_message) {
   std::string s;
   llvm::raw_string_ostream sout(s);
   bool first = true;
+  if (error_message) {
+    sout << error_message;
+    first = false;
+  }
   for (auto& d : diagnostics_) {
     if (!first) {
       sout << "\n\n";
@@ -104,7 +130,11 @@
   return sout.str();
 }
 
-void CompilerContextBundle::ClearDiagnostics() { diagnostics_.clear(); }
+void DiagnosticCapture::ClearDiagnostics() { diagnostics_.clear(); }
+
+CompilerContextBundle::CompilerContextBundle()
+    : default_capture_(&mlir_context_, nullptr) {}
+CompilerContextBundle::~CompilerContextBundle() = default;
 
 CompilerModuleBundle CompilerContextBundle::ParseAsm(
     const std::string& asm_text) {
@@ -113,9 +143,11 @@
   const char* asm_chars = asm_text.c_str();
   StringRef asm_sr(asm_chars, asm_text.size() + 1);
 
+  auto diag_capture = CaptureDiagnostics();
   auto module_ref = parseMLIRModuleFromString(asm_sr, mlir_context());
   if (!module_ref) {
-    throw RaiseValueError("Failed to parse MLIR asm");
+    throw RaiseValueError(
+        diag_capture.ConsumeDiagnosticsAsString("Error parsing ASM").c_str());
   }
   return CompilerModuleBundle(shared_from_this(), module_ref.release());
 }
@@ -130,10 +162,14 @@
 }
 
 std::shared_ptr<OpaqueBlob> CompilerModuleBundle::CompileToSequencerBlob() {
+  auto diag_capture = context_->CaptureDiagnostics();
   auto module_blob =
       mlir::iree_compiler::translateMlirToIreeSequencerModule(module_op());
   if (module_blob.empty()) {
-    throw std::runtime_error("Failed to translate MLIR module");
+    throw RaiseValueError(
+        diag_capture
+            .ConsumeDiagnosticsAsString("Failed to translate MLIR module")
+            .c_str());
   }
   return std::make_shared<OpaqueByteVectorBlob>(std::move(module_blob));
 }
@@ -152,9 +188,12 @@
   }
 
   // Run them.
+  auto diag_capture = context_->CaptureDiagnostics();
   if (failed(pm.run(module_op_))) {
-    throw RaisePyError(PyExc_RuntimeError,
-                       "Error running pass pipelines (see diagnostics)");
+    throw RaisePyError(
+        PyExc_RuntimeError,
+        diag_capture.ConsumeDiagnosticsAsString("Error running pass pipelines:")
+            .c_str());
   }
 }
 
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler.h
index d40c057..b7daca4 100644
--- a/bindings/python/pyiree/compiler.h
+++ b/bindings/python/pyiree/compiler.h
@@ -49,6 +49,28 @@
   mlir::ModuleOp module_op_;
 };
 
+// Registers to receive diagnostics for a scope.
+// When this goes out of scope, any remaining diagnostics will be added to
+// the parent.
+class DiagnosticCapture {
+ public:
+  DiagnosticCapture(mlir::MLIRContext* mlir_context, DiagnosticCapture* parent);
+  ~DiagnosticCapture();
+  DiagnosticCapture(DiagnosticCapture&& other);
+
+  std::vector<mlir::Diagnostic>& diagnostics() { return diagnostics_; }
+
+  // Consumes/clears diagnostics.
+  std::string ConsumeDiagnosticsAsString(const char* error_message);
+  void ClearDiagnostics();
+
+ private:
+  mlir::MLIRContext* mlir_context_;
+  DiagnosticCapture* parent_;
+  std::vector<mlir::Diagnostic> diagnostics_;
+  mlir::DiagnosticEngine::HandlerID handler_id_;
+};
+
 // Bundle of MLIRContext related things that facilitates interop with
 // Python.
 class CompilerContextBundle
@@ -61,13 +83,24 @@
 
   CompilerModuleBundle ParseAsm(const std::string& asm_text);
 
+  // Gets the default diagnostic capture.
+  DiagnosticCapture& DefaultDiagnosticCapture() { return default_capture_; }
+
+  // Creates a new diagnostic region.
+  // Note that this only supports one deep at present.
+  DiagnosticCapture CaptureDiagnostics() {
+    return DiagnosticCapture(&mlir_context_, &default_capture_);
+  }
+
   // Consumes/clears diagnostics.
-  std::string ConsumeDiagnosticsAsString();
-  void ClearDiagnostics();
+  std::string ConsumeDiagnosticsAsString() {
+    return default_capture_.ConsumeDiagnosticsAsString(nullptr);
+  }
+  void ClearDiagnostics() { default_capture_.ClearDiagnostics(); }
 
  private:
   mlir::MLIRContext mlir_context_;
-  std::vector<mlir::Diagnostic> diagnostics_;
+  DiagnosticCapture default_capture_;
 };
 
 void SetupCompilerBindings(pybind11::module m);
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py
index b90b2c4..4c9eb34 100644
--- a/bindings/python/pyiree/compiler_test.py
+++ b/bindings/python/pyiree/compiler_test.py
@@ -25,10 +25,8 @@
 
   def testParseError(self):
     ctx = binding.compiler.CompilerContext()
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"):
       ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""")
-    diag_str = ctx.get_diagnostics()
-    self.assertRegex(diag_str, "custom op 'FOOBAR' is unknown")
 
   def testParseAndCompileToSequencer(self):
     ctx = binding.compiler.CompilerContext()