blob: b2e09b09f6a5a5f2cea420b5782219bf2d006d6b [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bindings/python/pyiree/tensorflow/register_tensorflow.h"
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
using namespace mlir; // NOLINT
namespace iree {
namespace python {
namespace {
// Imports a TensorFlow SavedModel and returns the resulting MLIR module
// rendered as textual assembly (printed with default OpPrintingFlags).
//
// Args:
//   saved_model_dir: Path of the SavedModel, forwarded to
//     tensorflow::SavedModelToMlirImport.
//   exported_names: Forwarded to the importer as an absl::Span. Taken by
//     value so a mutable span can be formed over it with absl::MakeSpan.
//   tags: Converted to a set and forwarded as the importer's tag set.
//
// Throws:
//   std::runtime_error (translated to a Python RuntimeError by pybind11) if
//   the importer returns a null module.
std::string ImportSavedModelToMlirAsm(const std::string& saved_model_dir,
                                      std::vector<std::string> exported_names,
                                      std::vector<std::string> tags) {
  std::unordered_set<std::string> tags_set(tags.begin(), tags.end());

  MLIRContext context;
  auto module = tensorflow::SavedModelToMlirImport(
      saved_model_dir, tags_set, absl::MakeSpan(exported_names), &context);
  // NOTE(review): the importer returns an owning module handle that is
  // presumably null on failure (confirm against tf_mlir_translate.h).
  // Dereferencing a null handle below would crash the host Python process,
  // so surface the failure as an exception instead.
  if (!module) {
    throw std::runtime_error("Failed to import SavedModel from '" +
                             saved_model_dir + "'");
  }

  // Print to asm.
  std::string asm_output;
  llvm::raw_string_ostream sout(asm_output);
  OpPrintingFlags print_flags;
  module->print(sout, print_flags);
  // str() flushes the stream into asm_output and returns it.
  return sout.str();
}
} // namespace
// Registers the TensorFlow-related Python functions on module `m`.
//
// Exposes `import_saved_model_to_mlir_asm(saved_model_dir,
// exported_names=[], tags=["serve"])`, which is backed by
// ImportSavedModelToMlirAsm.
void SetupTensorFlowBindings(pybind11::module m) {
  // Python-side defaults: an empty list of exported names, and a tag list
  // containing only "serve".
  const std::vector<std::string> kDefaultExportedNames{};
  const std::vector<std::string> kDefaultTags{std::string("serve")};

  m.def("import_saved_model_to_mlir_asm", &ImportSavedModelToMlirAsm,
        py::arg("saved_model_dir"),
        py::arg("exported_names") = kDefaultExportedNames,
        py::arg("tags") = kDefaultTags);
}
} // namespace python
} // namespace iree