blob: 9c3bdb451058edbb91ed9e75ff5df525b041ef80 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree_tf_compiler/TFL/Passes.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
namespace iree_integrations {
namespace TFL {
// Extract the input and output names
static void splitFunctionIONames(StringAttr namesAttr,
llvm::SmallVectorImpl<std::string> &names) {
SmallVector<StringRef, 4> namesRef;
llvm::SplitString(namesAttr.getValue(), namesRef, ",");
for (auto nameRef : namesRef) {
names.push_back(nameRef.str());
}
}
class ConvertModuleMetadataPass
: public PassWrapper<ConvertModuleMetadataPass, OperationPass<ModuleOp>> {
public:
StringRef getArgument() const override {
return "iree-tflite-convert-module-metadata";
}
StringRef getDescription() const override {
return "Converts TFLite attributes to IREE attributes on modules";
}
void runOnOperation() override {
// None currently handled.
}
};
class ConvertFunctionMetadataPass
: public PassWrapper<ConvertFunctionMetadataPass, OperationPass<FuncOp>> {
public:
StringRef getArgument() const override {
return "iree-tflite-convert-function-metadata";
}
StringRef getDescription() const override {
return "Converts TFLite attributes to IREE attributes on functions";
}
void runOnOperation() override {
auto funcOp = getOperation();
// Setup TF entry functions as an IREE entry point and preserve the
// associated metadata. Note that TFLite uses `tf.entry_function`.
auto entryFunctionAttr =
funcOp->getAttrOfType<DictionaryAttr>("tf.entry_function");
if (entryFunctionAttr) {
setupEntryPointAttrs(funcOp, entryFunctionAttr);
}
}
private:
// TF/TFL pack their I/O names on an annoying dictionary. We want our shape
// names to match up with those for readability so we extract them here.
// Is this ugly? Yeah - but such is what we have to deal with here.
void setupEntryPointAttrs(FuncOp funcOp, DictionaryAttr entryFunctionAttr) {
funcOp.setPublic();
auto inputsAttr =
entryFunctionAttr.get("inputs").template dyn_cast_or_null<StringAttr>();
auto outputsAttr = entryFunctionAttr.get("outputs")
.template dyn_cast_or_null<StringAttr>();
if (!inputsAttr || !outputsAttr) {
funcOp.emitError() << "functions with tf.entry_function must have "
"input and output names to be handled by IREE";
signalPassFailure();
return;
}
SmallVector<std::string, 4> inputNames;
SmallVector<std::string, 4> outputNames;
splitFunctionIONames(inputsAttr, inputNames);
splitFunctionIONames(outputsAttr, outputNames);
if (inputNames.size() != funcOp.getNumArguments() ||
outputNames.size() != funcOp.getNumResults()) {
funcOp.emitError()
<< "tf.entry_function attribute malformed: inputs/outputs don't "
"match the function signature";
signalPassFailure();
return;
}
for (unsigned i = 0; i < inputNames.size(); ++i) {
funcOp.setArgAttr(i, "iree.identifier",
StringAttr::get(&getContext(), inputNames[i]));
}
for (unsigned i = 0; i < outputNames.size(); ++i) {
funcOp.setResultAttr(i, "iree.identifier",
StringAttr::get(&getContext(), outputNames[i]));
}
}
};
std::unique_ptr<OperationPass<ModuleOp>> createConvertModuleMetadataPass() {
return std::make_unique<ConvertModuleMetadataPass>();
}
std::unique_ptr<OperationPass<FuncOp>> createConvertFunctionMetadataPass() {
return std::make_unique<ConvertFunctionMetadataPass>();
}
static PassRegistration<ConvertModuleMetadataPass> modulePass;
static PassRegistration<ConvertFunctionMetadataPass> funcPass;
} // namespace TFL
} // namespace iree_integrations
} // namespace mlir