| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree_tf_compiler/TFL/PassDetail.h" |
| #include "iree_tf_compiler/TFL/Passes.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| |
| namespace mlir { |
| namespace iree_integrations { |
| namespace TFL { |
| namespace { |
| |
| static bool isTFLAttr(NamedAttribute &namedAttr) { |
| // NOTE: tflite mixes tf and tfl, for some reason. |
| auto name = namedAttr.first.strref(); |
| if (name.startswith("tf.") || name.startswith("tf_") || |
| name.startswith("tfl.") || name.startswith("tfl_")) { |
| return true; |
| } |
| StringRef attrNamespace = namedAttr.second.getDialect().getNamespace(); |
| return attrNamespace == "tf" || attrNamespace == "tfl"; |
| } |
| |
| class StripModuleMetadataPass |
| : public StripModuleMetadataBase<StripModuleMetadataPass> { |
| public: |
| |
| void runOnOperation() override { |
| auto moduleOp = getOperation(); |
| auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( |
| moduleOp->getAttrs(), |
| [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); |
| for (auto namedAttr : stripAttrs) { |
| moduleOp->removeAttr(namedAttr.first); |
| } |
| } |
| }; |
| |
| class StripFunctionMetadataPass |
| : public StripFunctionMetadataBase<StripFunctionMetadataPass> { |
| public: |
| |
| void runOnOperation() override { |
| auto funcOp = getOperation(); |
| auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( |
| funcOp->getAttrs(), |
| [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); |
| for (auto namedAttr : stripAttrs) { |
| funcOp->removeAttr(namedAttr.first); |
| } |
| |
| for (int i = 0; i < funcOp.getNumArguments(); ++i) { |
| auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( |
| funcOp.getArgAttrs(i), |
| [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); |
| for (auto namedAttr : stripAttrs) { |
| funcOp.removeArgAttr(i, namedAttr.first); |
| } |
| } |
| |
| for (int i = 0; i < funcOp.getNumResults(); ++i) { |
| auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( |
| funcOp.getResultAttrs(i), |
| [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); |
| for (auto namedAttr : stripAttrs) { |
| funcOp.removeResultAttr(i, namedAttr.first); |
| } |
| } |
| } |
| }; |
| |
| } // anonymous namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createStripModuleMetadataPass() { |
| return std::make_unique<StripModuleMetadataPass>(); |
| } |
| |
| std::unique_ptr<OperationPass<FuncOp>> createStripFunctionMetadataPass() { |
| return std::make_unique<StripFunctionMetadataPass>(); |
| } |
| |
| } // namespace TFL |
| } // namespace iree_integrations |
| } // namespace mlir |