| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "integrations/tensorflow/compiler/Passes.h" |
| #include "iree/base/signature_mangle.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Pass/PassRegistry.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/Utils.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| using ::iree::SipSignatureMangler; |
| |
| namespace { |
| |
| LogicalResult setRawSignatureIndex(FuncOp funcOp, SipSignatureMangler &mangler, |
| int rawIndex, ArrayAttr indexPathAttr) { |
| llvm::SmallVector<SipSignatureMangler::Key, 8> indexKeys; |
| for (auto &indexAttr : indexPathAttr) { |
| if (auto stringAttr = indexAttr.dyn_cast<StringAttr>()) { |
| auto stringRef = stringAttr.getValue(); |
| indexKeys.emplace_back( |
| absl::string_view(stringRef.data(), stringRef.size())); |
| } else if (auto intAttr = indexAttr.dyn_cast<IntegerAttr>()) { |
| indexKeys.emplace_back(intAttr.getInt()); |
| } else { |
| return funcOp.emitError() |
| << "Each index path component must be a string or integer"; |
| } |
| } |
| |
| if (!mangler.SetRawSignatureIndex(rawIndex, indexKeys)) { |
| return funcOp.emitError() |
| << "Unable to generate mangled form for index path"; |
| } |
| |
| return success(); |
| } |
| |
| } // namespace |
| |
| class TFSavedModelLowerExportedFunctions |
| : public PassWrapper<TFSavedModelLowerExportedFunctions, |
| OperationPass<ModuleOp>> { |
| public: |
| void runOnOperation() override { |
| if (failed(run())) { |
| signalPassFailure(); |
| } |
| } |
| |
| LogicalResult run() { |
| mlir::Builder builder(getOperation()); |
| Identifier savedModelIndexPathIdent = |
| builder.getIdentifier("tf_saved_model.index_path"); |
| Identifier ireeReflectionIdent = builder.getIdentifier("iree.reflection"); |
| Identifier ireeModuleExportIdent = |
| builder.getIdentifier("iree.module.export"); |
| Identifier sipIdent = builder.getIdentifier("sip"); |
| Identifier abiIdent = builder.getIdentifier("abi"); |
| Identifier abiVersionIdent = builder.getIdentifier("abiv"); |
| |
| // Handle saved model exported functions. |
| for (auto func : getOperation().getOps<FuncOp>()) { |
| // Transfer exported names to IREE. |
| auto exported_names = mlir::tf_saved_model::GetExportedNames(func); |
| if (exported_names.empty()) continue; |
| |
| // TODO(laurenzo): After VM rework, we should just keep the |
| // function name as-is and create explicit export ops for each exported |
| // function. |
| if (exported_names.size() > 1) { |
| return func.emitError() << "Multiple exported names not supported yet"; |
| } |
| func.setName(exported_names.front()); |
| |
| // Function level reflection attributes. |
| SipSignatureMangler inputsMangler; |
| SipSignatureMangler resultsMangler; |
| SmallVector<NamedAttribute, 3> funcReflectAttrs; |
| funcReflectAttrs.push_back( |
| builder.getNamedAttr(abiIdent, builder.getStringAttr(sipIdent))); |
| funcReflectAttrs.push_back( |
| builder.getNamedAttr(abiVersionIdent, builder.getI32IntegerAttr(1))); |
| |
| // Tag it as an IREE exported function. |
| func.setAttr(ireeModuleExportIdent, builder.getUnitAttr()); |
| |
| // Process per-argument attrs and generate reflection metadata. |
| for (int i = 0, e = func.getNumArguments(); i < e; i++) { |
| auto indexPathAttr = |
| func.getArgAttrOfType<mlir::ArrayAttr>(i, savedModelIndexPathIdent); |
| if (!indexPathAttr) { |
| return func.emitError() |
| << "Missing argument attribute: " << savedModelIndexPathIdent; |
| } |
| func.removeArgAttr(i, savedModelIndexPathIdent); |
| |
| if (failed( |
| setRawSignatureIndex(func, inputsMangler, i, indexPathAttr))) { |
| return failure(); |
| } |
| } |
| |
| // Process per-result attrs and generate reflection metadata. |
| for (int i = 0, e = func.getNumResults(); i < e; i++) { |
| auto indexPathAttr = func.getResultAttrOfType<mlir::ArrayAttr>( |
| i, savedModelIndexPathIdent); |
| if (!indexPathAttr) { |
| return func.emitError() |
| << "Missing result attribute: " << savedModelIndexPathIdent; |
| } |
| func.removeResultAttr(i, savedModelIndexPathIdent); |
| |
| if (failed( |
| setRawSignatureIndex(func, resultsMangler, i, indexPathAttr))) { |
| return failure(); |
| } |
| } |
| |
| // Add the function level reflection attribute. |
| auto functionSignature = SipSignatureMangler::ToFunctionSignature( |
| inputsMangler, resultsMangler); |
| if (!functionSignature) { |
| return func.emitError() << "Unable to generate sip function signature"; |
| } |
| funcReflectAttrs.push_back(builder.getNamedAttr( |
| sipIdent, builder.getStringAttr(functionSignature->encoded()))); |
| |
| if (!funcReflectAttrs.empty()) { |
| func.setAttr(ireeReflectionIdent, |
| builder.getDictionaryAttr(funcReflectAttrs)); |
| } |
| |
| // Remove its designation as a saved model export. |
| func.removeAttr("tf_saved_model.exported_names"); |
| } |
| |
| // We should have now removed anything requiring saved model semantics. |
| getOperation().removeAttr("tf_saved_model.semantics"); |
| return success(); |
| } |
| }; |
| |
| std::unique_ptr<OperationPass<ModuleOp>> |
| createTFSavedModelLowerExportedFunctions() { |
| return std::make_unique<TFSavedModelLowerExportedFunctions>(); |
| } |
| |
| static PassRegistration<TFSavedModelLowerExportedFunctions> pass( |
| "iree-tf-saved-model-lower-exported-functions", |
| "Lower tf_saved_model exported functions."); |
| |
| } // namespace iree_compiler |
| } // namespace mlir |