blob: 7e098576216bb5b05a5ea4b98417a1b27944a0a4 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <memory>
#include <utility>
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Returns true when `type` is f32, or when f32 is reached by repeatedly
/// unwrapping `!util.ptr` target types and ShapedType element types. Any such
/// fp32-derived type is illegal after this pass and must be rewritten.
static bool isIllegalType(Type type) {
  Type current = type;
  for (;;) {
    if (current.isF32()) return true;
    if (auto ptrType = current.dyn_cast<IREE::Util::PtrType>()) {
      current = ptrType.getTargetType();
    } else if (auto shapedType = current.dyn_cast<ShapedType>()) {
      current = shapedType.getElementType();
    } else {
      // Neither f32 nor a wrapper around another type: legal.
      return false;
    }
  }
}
/// Type converter that demotes f32 to f16 in scalar float types, ranked
/// tensors, and `!util.ptr` types wrapping either of those. Every other type
/// converts to itself.
class FloatTypeConverter : public TypeConverter {
 public:
  /// Returns `type` with its element type replaced by f16 when the element
  /// type is f32, keeping the shape and encoding; otherwise returns `type`.
  static Type convertTensor(RankedTensorType type) {
    if (!type.getElementType().isF32()) return type;
    return RankedTensorType::get(type.getShape(),
                                 Float16Type::get(type.getContext()),
                                 type.getEncoding());
  }

  explicit FloatTypeConverter() {
    // MLIR tries the most recently registered conversion first, so this
    // identity fallback only applies when none of the specific ones match.
    addConversion([](Type type) { return type; });
    addConversion([](FloatType type) -> Type {
      if (type.isF32()) return Float16Type::get(type.getContext());
      return type;
    });
    addConversion(convertTensor);
    // Demote the pointee. isIllegalType() rejects both `!util.ptr<f32>` and
    // `!util.ptr<tensor<...xf32>>`, so both must be convertible here or the
    // full conversion can never succeed on them.
    addConversion([](IREE::Util::PtrType ptrType) -> Type {
      Type target = ptrType.getTargetType();
      Type newTarget = target;
      if (auto tensorType = target.dyn_cast<RankedTensorType>()) {
        newTarget = convertTensor(tensorType);
      } else if (target.isF32()) {
        newTarget = Float16Type::get(target.getContext());
      }
      if (newTarget == target) return ptrType;
      return IREE::Util::PtrType::get(newTarget);
    });
  }
};
// Generic pattern to convert FP32 values and attributes to FP16.
class GenericTypeConvert : public ConversionPattern {
public:
GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
: ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<NamedAttribute, 4> newAttr;
convertAttributes(op->getAttrs(), rewriter, newAttr);
llvm::SmallVector<Type, 4> newResults;
(void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
newResults, newAttr, op->getSuccessors());
for (Region &r : op->getRegions()) {
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
rewriter.applySignatureConversion(newRegion, result);
}
Operation *newOp = rewriter.createOperation(state);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
protected:
static void convertAttributes(ArrayRef<NamedAttribute> attrs,
ConversionPatternRewriter &rewriter,
SmallVectorImpl<NamedAttribute> &newAttrs) {
for (auto attr : attrs) {
if (auto fpAttr = attr.getValue().dyn_cast<DenseFPElementsAttr>()) {
std::vector<llvm::APFloat> args;
if (!fpAttr.getType().getElementType().isF32()) continue;
for (llvm::APFloat f : fpAttr.getValues<APFloat>()) {
bool losesInfo;
f.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero, &losesInfo);
args.push_back(f);
}
auto tensorType = RankedTensorType::get(fpAttr.getType().getShape(),
rewriter.getF16Type());
newAttrs.emplace_back(attr.getName(),
DenseElementsAttr::get(tensorType, args));
} else if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
if (isIllegalType(typeAttr.getValue())) {
if (auto tensorType =
typeAttr.getValue().dyn_cast<RankedTensorType>()) {
Type newType = RankedTensorType::get(tensorType.getShape(),
rewriter.getF16Type());
newAttrs.emplace_back(attr.getName(), TypeAttr::get(newType));
}
}
} else {
newAttrs.push_back(attr);
}
}
}
};
struct DemoteF32ToF16Pass : public DemoteF32ToF16Base<DemoteF32ToF16Pass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp moduleOp = getOperation();
FloatTypeConverter converter;
RewritePatternSet patterns(&getContext());
patterns.insert<GenericTypeConvert>(context, converter);
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
converter);
ConversionTarget target(*context);
// Operations are legal if they don't contain any illegal type.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (auto globalOp = dyn_cast<IREE::Util::GlobalOp>(op)) {
return !isIllegalType(globalOp.type());
}
if (auto funcOp = dyn_cast<FuncOp>(op)) {
for (Type type : funcOp.getType().getInputs()) {
if (isIllegalType(type)) return false;
}
for (Type type : funcOp.getType().getResults()) {
if (isIllegalType(type)) return false;
}
}
for (Type type : op->getResultTypes()) {
if (isIllegalType(type)) return false;
}
for (Type type : op->getOperandTypes()) {
if (isIllegalType(type)) return false;
}
return true;
});
if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
/// Creates a module pass that demotes f32-derived types, values, and
/// attributes to f16.
std::unique_ptr<OperationPass<ModuleOp>> createDemoteF32ToF16Pass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<DemoteF32ToF16Pass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir