blob: 3b389198166fdc1768c6376bb88d04c062f9cf1c [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree_tf_compiler/TF/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
namespace mlir {
namespace iree_integrations {
namespace TF {
static bool isTFAttr(NamedAttribute &namedAttr) {
auto name = namedAttr.first.strref();
if (name.startswith("tf.") || name.startswith("tf_")) {
return true;
}
StringRef attrNamespace = namedAttr.second.getDialect().getNamespace();
return attrNamespace == mlir::TF::TensorFlowDialect::getDialectNamespace() ||
attrNamespace == mlir::tf_executor::TensorFlowExecutorDialect::
getDialectNamespace() ||
attrNamespace ==
mlir::tf_device::TensorFlowDeviceDialect::getDialectNamespace() ||
attrNamespace == mlir::tf_saved_model::TensorFlowSavedModelDialect::
getDialectNamespace();
}
// Module-level pass that removes every attribute for which isTFAttr() returns
// true from the top-level module operation itself. Nested ops are not visited.
class StripModuleMetadataPass
    : public PassWrapper<StripModuleMetadataPass, OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override {
    ModuleOp module = getOperation();

    // Collect the matching attributes before removing any of them, so the
    // attribute list is not modified while it is being walked.
    SmallVector<NamedAttribute, 4> tfAttrs;
    for (NamedAttribute attr : module->getAttrs()) {
      if (isTFAttr(attr)) tfAttrs.push_back(attr);
    }

    for (NamedAttribute &attr : tfAttrs) module->removeAttr(attr.first);
  }
};
// Function-level pass that removes every attribute for which isTFAttr()
// returns true from:
//   * the function op's own attribute list,
//   * each argument's attribute list, and
//   * each result's attribute list.
class StripFunctionMetadataPass
    : public PassWrapper<StripFunctionMetadataPass, OperationPass<FuncOp>> {
 public:
  void runOnOperation() override {
    FuncOp funcOp = getOperation();

    for (NamedAttribute namedAttr : collectTFAttrs(funcOp->getAttrs())) {
      funcOp->removeAttr(namedAttr.first);
    }

    // Use unsigned indices to avoid a signed/unsigned comparison against the
    // argument/result counts. Removing attributes does not change the counts,
    // so the bound is computed once.
    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
      for (NamedAttribute namedAttr : collectTFAttrs(funcOp.getArgAttrs(i))) {
        funcOp.removeArgAttr(i, namedAttr.first);
      }
    }
    for (unsigned i = 0, e = funcOp.getNumResults(); i < e; ++i) {
      for (NamedAttribute namedAttr :
           collectTFAttrs(funcOp.getResultAttrs(i))) {
        funcOp.removeResultAttr(i, namedAttr.first);
      }
    }
  }

 private:
  // Returns a copy of the entries of `attrs` for which isTFAttr() is true.
  // A copy is returned, so callers can remove those attributes without
  // mutating the list they were read from while iterating the result.
  static SmallVector<NamedAttribute, 4> collectTFAttrs(
      ArrayRef<NamedAttribute> attrs) {
    SmallVector<NamedAttribute, 4> result;
    for (NamedAttribute attr : attrs) {
      if (isTFAttr(attr)) result.push_back(attr);
    }
    return result;
  }
};
// Factory for the module-level TensorFlow attribute stripping pass.
std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<StripModuleMetadataPass>();
  return pass;
}
// Factory for the function-level TensorFlow attribute stripping pass.
std::unique_ptr<OperationPass<FuncOp>> createStripFunctionMetadataPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<StripFunctionMetadataPass>();
  return pass;
}
// Static registrations that make both passes available by their argument
// names ("iree-tf-strip-module-metadata" / "iree-tf-strip-function-metadata").
static PassRegistration<StripModuleMetadataPass>
    stripModuleMetadataRegistration(
        "iree-tf-strip-module-metadata",
        "Remove unneeded TensorFlow attributes from module ops");
static PassRegistration<StripFunctionMetadataPass>
    stripFunctionMetadataRegistration(
        "iree-tf-strip-function-metadata",
        "Remove unneeded TensorFlow attributes from func ops");
} // namespace TF
} // namespace iree_integrations
} // namespace mlir