blob: dbb82adc8775e5e84785ad4a51823c923f22fa1a [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "integrations/tensorflow/compiler/Passes.h"
#include "iree/base/signature_mangle.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
namespace mlir {
namespace iree_compiler {
using ::iree::SipSignatureMangler;
namespace {
/// Converts a `tf_saved_model.index_path` array into SIP index keys and
/// registers them with `mangler` at raw position `rawIndex`.
///
/// Each element of `indexPathAttr` becomes one key:
///   * a StringAttr element is keyed by its string value;
///   * an IntegerAttr element is keyed by its integer value.
/// Any other attribute kind is rejected.
///
/// Returns failure (after emitting an error on `funcOp`) when a path component
/// has an unsupported kind or when `mangler.SetRawSignatureIndex` rejects the
/// assembled key path.
LogicalResult setRawSignatureIndex(FuncOp funcOp, SipSignatureMangler &mangler,
                                   int rawIndex, ArrayAttr indexPathAttr) {
  llvm::SmallVector<SipSignatureMangler::Key, 8> keys;
  for (Attribute component : indexPathAttr) {
    if (auto asString = component.dyn_cast<StringAttr>()) {
      // The mangler's Key is built from absl::string_view, so bridge from
      // llvm::StringRef explicitly.
      StringRef name = asString.getValue();
      keys.emplace_back(absl::string_view(name.data(), name.size()));
      continue;
    }
    if (auto asInteger = component.dyn_cast<IntegerAttr>()) {
      keys.emplace_back(asInteger.getInt());
      continue;
    }
    return funcOp.emitError()
           << "Each index path component must be a string or integer";
  }

  if (!mangler.SetRawSignatureIndex(rawIndex, keys)) {
    return funcOp.emitError()
           << "Unable to generate mangled form for index path";
  }
  return success();
}
} // namespace
/// Lowers `tf_saved_model` exported functions to IREE exported functions.
///
/// For every FuncOp in the module that has exported names, the pass:
///   * renames the function to its (single) exported name;
///   * tags it with the `iree.module.export` unit attribute;
///   * removes the `tf_saved_model.index_path` attribute from every argument
///     and result, feeding each path into a SIP signature mangler;
///   * attaches an `iree.reflection` dictionary whose entries are `abi` = "sip",
///     `abiv` = 1 and `sip` = the encoded function signature;
///   * removes the `tf_saved_model.exported_names` attribute.
/// Functions without exported names are left untouched. Finally, the
/// module-level `tf_saved_model.semantics` attribute is removed.
class TFSavedModelLowerExportedFunctions
    : public PassWrapper<TFSavedModelLowerExportedFunctions,
                         OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override {
    if (failed(run())) {
      signalPassFailure();
    }
  }

  /// Performs the lowering described on the class.
  ///
  /// Returns failure (after emitting an error on the offending function) when:
  ///   * a function has more than one exported name;
  ///   * an argument or result lacks `tf_saved_model.index_path`;
  ///   * an index path cannot be mangled;
  ///   * the SIP function signature cannot be generated.
  LogicalResult run() {
    mlir::Builder builder(getOperation());
    Identifier savedModelIndexPathIdent =
        builder.getIdentifier("tf_saved_model.index_path");
    Identifier ireeReflectionIdent = builder.getIdentifier("iree.reflection");
    Identifier ireeModuleExportIdent =
        builder.getIdentifier("iree.module.export");
    Identifier sipIdent = builder.getIdentifier("sip");
    Identifier abiIdent = builder.getIdentifier("abi");
    Identifier abiVersionIdent = builder.getIdentifier("abiv");

    // Handle saved model exported functions.
    for (auto func : getOperation().getOps<FuncOp>()) {
      // Transfer exported names to IREE.
      auto exportedNames = mlir::tf_saved_model::GetExportedNames(func);
      if (exportedNames.empty()) continue;
      // TODO(laurenzo): After VM rework, we should just keep the
      // function name as-is and create explicit export ops for each exported
      // function.
      if (exportedNames.size() > 1) {
        return func.emitError() << "Multiple exported names not supported yet";
      }
      func.setName(exportedNames.front());

      // Tag it as an IREE exported function.
      func.setAttr(ireeModuleExportIdent, builder.getUnitAttr());

      // Consume the per-argument index paths, recording each one in the
      // inputs half of the SIP signature.
      SipSignatureMangler inputsMangler;
      for (int i = 0, e = func.getNumArguments(); i < e; i++) {
        auto indexPathAttr =
            func.getArgAttrOfType<mlir::ArrayAttr>(i, savedModelIndexPathIdent);
        if (!indexPathAttr) {
          return func.emitError()
                 << "Missing argument attribute: " << savedModelIndexPathIdent;
        }
        func.removeArgAttr(i, savedModelIndexPathIdent);
        if (failed(
                setRawSignatureIndex(func, inputsMangler, i, indexPathAttr))) {
          return failure();
        }
      }

      // Consume the per-result index paths into the results half.
      SipSignatureMangler resultsMangler;
      for (int i = 0, e = func.getNumResults(); i < e; i++) {
        auto indexPathAttr = func.getResultAttrOfType<mlir::ArrayAttr>(
            i, savedModelIndexPathIdent);
        if (!indexPathAttr) {
          return func.emitError()
                 << "Missing result attribute: " << savedModelIndexPathIdent;
        }
        func.removeResultAttr(i, savedModelIndexPathIdent);
        if (failed(
                setRawSignatureIndex(func, resultsMangler, i, indexPathAttr))) {
          return failure();
        }
      }

      auto functionSignature = SipSignatureMangler::ToFunctionSignature(
          inputsMangler, resultsMangler);
      if (!functionSignature) {
        return func.emitError() << "Unable to generate sip function signature";
      }

      // Function level reflection attributes. All three entries are always
      // present, so the dictionary is attached unconditionally (the former
      // `!empty()` guard could never be false).
      SmallVector<NamedAttribute, 3> funcReflectAttrs;
      funcReflectAttrs.push_back(
          builder.getNamedAttr(abiIdent, builder.getStringAttr(sipIdent)));
      funcReflectAttrs.push_back(
          builder.getNamedAttr(abiVersionIdent, builder.getI32IntegerAttr(1)));
      funcReflectAttrs.push_back(builder.getNamedAttr(
          sipIdent, builder.getStringAttr(functionSignature->encoded())));
      func.setAttr(ireeReflectionIdent,
                   builder.getDictionaryAttr(funcReflectAttrs));

      // Remove its designation as a saved model export.
      func.removeAttr("tf_saved_model.exported_names");
    }

    // We should have now removed anything requiring saved model semantics.
    getOperation().removeAttr("tf_saved_model.semantics");
    return success();
  }
};
std::unique_ptr<OperationPass<ModuleOp>>
createTFSavedModelLowerExportedFunctions() {
return std::make_unique<TFSavedModelLowerExportedFunctions>();
}
// Static registration of the pass under the argument name
// "iree-tf-saved-model-lower-exported-functions".
static PassRegistration<TFSavedModelLowerExportedFunctions> pass{
    "iree-tf-saved-model-lower-exported-functions",
    "Lower tf_saved_model exported functions."};
} // namespace iree_compiler
} // namespace mlir