blob: 0d3ff47d23a8ffda78d38513d810ead388ffac7b [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree_tf_compiler/TF/Passes.h"
#include "llvm/Support/JSON.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
namespace json = llvm::json;
namespace mlir {
namespace iree_integrations {
namespace TF {
// Attaches a default `iree.abi` reflection attribute to public functions.
//
// For every public FuncOp that does not already carry an `iree.abi` attribute,
// this pass serializes a JSON object of the form
//   {"v": 1, "a": [<argument records>], "r": [<result records>]}
// and stores it on the function as a string attribute. If any argument or
// result type cannot be described, a warning is emitted and the function is
// left unannotated.
class EmitDefaultIREEABIPass
    : public PassWrapper<EmitDefaultIREEABIPass, OperationPass<FuncOp>> {
 public:
  void runOnOperation() override {
    auto funcOp = getOperation();
    // Only public functions receive a default ABI descriptor.
    if (SymbolTable::getSymbolVisibility(funcOp) !=
        SymbolTable::Visibility::Public) {
      return;
    }
    // Never overwrite a descriptor that is already present.
    if (funcOp->hasAttr("iree.abi")) {
      return;
    }

    json::Array refArgs;
    for (Type t : funcOp.getArgumentTypes()) {
      auto descriptor = mapTypeToJsonTypeRecord(t);
      if (!descriptor) {
        funcOp.emitWarning()
            << "unable to generate reflection descriptor for argument type "
            << t;
        return;
      }
      refArgs.push_back(std::move(*descriptor));
    }

    json::Array refReturns;
    for (Type t : funcOp.getCallableResults()) {
      auto descriptor = mapTypeToJsonTypeRecord(t);
      if (!descriptor) {
        funcOp.emitWarning()
            << "unable to generate reflection descriptor for result type " << t;
        return;
      }
      refReturns.push_back(std::move(*descriptor));
    }

    json::Object refDict;
    refDict["v"] = json::Value(1);
    refDict["a"] = json::Value(std::move(refArgs));
    refDict["r"] = json::Value(std::move(refReturns));
    json::Value refDictValue(std::move(refDict));

    std::string refStr;
    llvm::raw_string_ostream refOut(refStr);
    refOut << refDictValue;
    refOut.flush();

    Builder builder(&getContext());
    funcOp->setAttr("iree.abi", builder.getStringAttr(refStr));
  }

  // Maps `type` to its JSON reflection record, or llvm::None if `type` (or,
  // for shaped types, its element type) is not supported:
  //   - ranked shaped type:   ["ndarray", <element record>, rank, d0, d1, ...]
  //                           where dynamic dimensions are emitted as null
  //   - unranked shaped type: ["ndarray", <element record>, null]
  //   - integer type:         "i<width>"
  //   - bf16:                 "bf16"
  //   - other float types:    "f<width>"
  llvm::Optional<json::Value> mapTypeToJsonTypeRecord(Type type) {
    if (auto shapedType = type.dyn_cast<ShapedType>()) {
      // Propagate failure instead of letting an unsupported element type be
      // serialized as `null` (json::Value converts an empty Optional to null),
      // so callers hit their "unable to generate" warning path.
      auto elementRecord =
          mapTypeToJsonTypeRecord(shapedType.getElementType());
      if (!elementRecord) {
        return llvm::None;
      }
      json::Array record;
      record.push_back(json::Value("ndarray"));
      record.push_back(std::move(*elementRecord));
      if (!shapedType.hasRank()) {
        record.push_back(json::Value(nullptr));
        return json::Value(std::move(record));
      }
      record.push_back(json::Value(shapedType.getRank()));
      for (int64_t dim : shapedType.getShape()) {
        record.push_back(dim == ShapedType::kDynamicSize
                             ? json::Value(nullptr)
                             : json::Value(dim));
      }
      return json::Value(std::move(record));
    }

    // Primitives.
    if (auto integerType = type.dyn_cast<IntegerType>()) {
      std::string name = (Twine("i") + Twine(integerType.getWidth())).str();
      return json::Value(std::move(name));
    }
    if (auto floatType = type.dyn_cast<FloatType>()) {
      // bf16 is 16 bits wide, so the generic "f<width>" spelling would make
      // it indistinguishable from IEEE half precision ("f16").
      if (floatType.isBF16()) {
        return json::Value("bf16");
      }
      std::string name = (Twine("f") + Twine(floatType.getWidth())).str();
      return json::Value(std::move(name));
    }
    return llvm::None;
  }
};
// Factory for the function-level pass that attaches default `iree.abi`
// reflection metadata. The caller takes ownership of the returned pass.
std::unique_ptr<OperationPass<FuncOp>> createEmitDefaultIREEABIPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<EmitDefaultIREEABIPass>();
  return pass;
}
// Static registration object that registers EmitDefaultIREEABIPass under the
// command-line name "iree-tf-emit-default-iree-abi".
static PassRegistration<EmitDefaultIREEABIPass>
    emitDefaultIREEABIPassRegistration(
        "iree-tf-emit-default-iree-abi",
        "Emits simple default ABI metadata");
} // namespace TF
} // namespace iree_integrations
} // namespace mlir