Enable SourceMgrDiagnosticHandler in importers. (#5600)
* Enable SourceMgrDiagnosticHandler in importers.
* This is a strict improvement over not having a diagnostic handler.
* Also introduces a stub pass for cleaning up TensorFlow locations but I have not yet found an algorithm I like so just leaving it as a stub for later.
* Progress on #5295
diff --git a/integrations/tensorflow/build_tools/testdata/generate_errors_module.py b/integrations/tensorflow/build_tools/testdata/generate_errors_module.py
new file mode 100644
index 0000000..692edd7
--- /dev/null
+++ b/integrations/tensorflow/build_tools/testdata/generate_errors_module.py
@@ -0,0 +1,55 @@
+# Lint as: python3
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Generates sample models for excercising various function signatures.
+
+Usage:
+ generate_errors_module.py /tmp/errors.sm
+
+This can then be fed into iree-tf-import to process it:
+
+Fully convert to IREE input (run all import passes):
+ iree-tf-import /tmp/errors.sm
+
+Import only (useful for crafting test cases for the import pipeline):
+ iree-tf-import -o /dev/null -save-temp-tf-input=- /tmp/errors.sm
+
+Can be further lightly pre-processed via:
+ | iree-tf-opt --tf-standard-pipeline
+"""
+
+import sys
+
+import tensorflow as tf
+
+
+class ErrorsModule(tf.Module):
+
+ @tf.function(input_signature=[tf.TensorSpec([16], tf.float32)])
+ def string_op(self, a):
+ tf.print(a)
+ return a
+
+
+try:
+ file_name = sys.argv[1]
+except IndexError:
+ print("Expected output file name")
+ sys.exit(1)
+
+m = ErrorsModule()
+tf.saved_model.save(m,
+ file_name,
+ options=tf.saved_model.SaveOptions(save_debug_info=True))
+print(f"Saved to {file_name}")
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index 2342967..c7a9002 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -26,6 +26,7 @@
"LowerExportedFunctions.cpp",
"LowerGlobalTensors.cpp",
"Passes.cpp",
+ "PrettifyDebugInfo.cpp",
"PropagateResourceCasts.cpp",
"SavedModelToIreeABI.cpp",
"StripAsserts.cpp",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
index aa94ac3..a34f563 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
@@ -65,6 +65,10 @@
// functions for any saved model exported functions.
std::unique_ptr<OperationPass<ModuleOp>> createSavedModelToIREEABIPass();
+// Simplifies TensorFlow debug info for the purposes of making it easier to
+// look at.
+std::unique_ptr<OperationPass<ModuleOp>> createPrettifyDebugInfoPass();
+
// Push resource casts forward to better propagate resource related shapes.
std::unique_ptr<OperationPass<ModuleOp>> createPropagateResourceCastsPass();
@@ -92,6 +96,7 @@
createFlattenTuplesInCFGPass();
createLowerGlobalTensorsPass();
createLowerExportedFunctionsPass();
+ createPrettifyDebugInfoPass();
createPropagateResourceCastsPass();
createSavedModelToIREEABIPass();
createStripAssertsPass();
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/PrettifyDebugInfo.cpp b/integrations/tensorflow/iree_tf_compiler/TF/PrettifyDebugInfo.cpp
new file mode 100644
index 0000000..71e7374
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TF/PrettifyDebugInfo.cpp
@@ -0,0 +1,47 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TF/Passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TF {
+
+class PrettifyDebugInfoPass
+ : public PassWrapper<PrettifyDebugInfoPass, OperationPass<ModuleOp>> {
+ public:
+ void runOnOperation() override {
+ // TODO: Finish algorithm for simplifying TF debug info.
+ // auto moduleOp = getOperation();
+ // moduleOp.walk([&](Operation *op) {
+ // Location loc = op->getLoc();
+ // if (auto callSite = loc.dyn_cast<CallSiteLoc>()) {
+ // callSite.getCallee().dump();
+ // }
+ // });
+ }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createPrettifyDebugInfoPass() {
+ return std::make_unique<PrettifyDebugInfoPass>();
+}
+
+static PassRegistration<PrettifyDebugInfoPass> modulePass(
+ "iree-tf-prettify-debug-info",
+ "Simplifies TF debug info to make it easier to look at");
+
+} // namespace TF
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
index 6887dd7..7e65d68 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
@@ -76,6 +76,9 @@
MLIRContext context(registry);
context.loadAllAvailableDialects();
+ llvm::SourceMgr sourceMgr;
+ mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+
// Load input buffer.
std::string errorMessage;
auto inputFile = openInputFile(inputPath, &errorMessage);
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 10b672a..e48ed4d 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -178,6 +178,9 @@
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
+ llvm::SourceMgr sourceMgr;
+ mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+
auto status =
ConvertHloToMlirHlo(module.get(), hloProto.mutable_hlo_module());
if (!status.ok()) {
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp
index 5ac7224..4e8e90b 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp
@@ -24,6 +24,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
@@ -160,6 +161,11 @@
llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
"intermediate in a pipeline)"),
llvm::cl::init(""));
+ static llvm::cl::opt<bool> prettifyTfDebugInfo(
+ "prettify-tf-debug-info",
+ llvm::cl::desc("Prettifies TF debug information to make it easier "
+ "to look at"),
+ llvm::cl::init(true));
// Register any command line options.
registerAsmPrinterCLOptions();
@@ -172,6 +178,10 @@
MLIRContext context(registry);
context.loadAllAvailableDialects();
+
+ llvm::SourceMgr sourceMgr;
+ mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+
OwningModuleRef module;
auto saveToFile = [&](llvm::StringRef savePath) -> LogicalResult {
@@ -210,6 +220,10 @@
PassManager pm(&context, PassManager::Nesting::Implicit);
applyPassManagerCLOptions(pm);
+ if (prettifyTfDebugInfo) {
+ pm.addPass(iree_integrations::TF::createPrettifyDebugInfoPass());
+ }
+
iree_integrations::TF::buildTFImportPassPipeline(pm);
if (failed(pm.run(*module))) {
llvm::errs()