blob: 842dcee034100a353d6cb116d2ce4ec4c05504d3 [file] [log] [blame]
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// pybind11 (and therefore Python.h) must be included before any standard
// library header.
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>

#include "flatbuffers/util.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace py = pybind11;
void align_tflite_model(const char* input_file_name,
const char* output_file_name) {
std::string model_file;
// Read the file into a string using the included util API call:
flatbuffers::LoadFile(input_file_name, false, &model_file);
// Parse the string into a C++ class. Model is the root object of a tflite
// flatbuffer file.
const tflite::Model* model = tflite::GetModel(model_file.c_str());
// A packed model is basically the file format mmaped into memory.
// Unpacking it and then packing it with the C++ API should yield
// a file with the force_align attributes respected.
// ModelT is just the unpacked version of the model file.
tflite::ModelT* unpacked_model = model->UnPack();
flatbuffers::FlatBufferBuilder fbb;
auto new_model = tflite::Model::Pack(fbb, unpacked_model);
fbb.Finish(new_model, tflite::ModelIdentifier());
flatbuffers::SaveFile(output_file_name,
reinterpret_cast<char*>(fbb.GetBufferPointer()),
fbb.GetSize(), /*binary*/ true);
}
// Python extension entry point. It exposes align_tflite_model() with keyword
// arguments `input_file_name` and `output_file_name`.
PYBIND11_MODULE(tflite_flatbuffer_align_wrapper, module_handle) {
  static constexpr const char kModuleDoc[] = "tflite_flatbuffer_align_wrapper";
  static constexpr const char kAlignDoc[] =
      "Aligns the tflite flatbuffer to (16), by unpacking and repacking via "
      "the flatbuffer C++ API.";

  module_handle.doc() = kModuleDoc;
  module_handle.def("align_tflite_model", &align_tflite_model, kAlignDoc,
                    py::arg("input_file_name"),
                    py::arg("output_file_name"));
}