blob: 74f08cf2196478940121414d05dd4460e60d9bde [file]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Utils/ConversionUtils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::iree_compiler {
static void emitLegalizationErrors(Location loc,
const DenseSet<Operation *> &illegalOps) {
// Print op errors for each of the illegal ops that still remain.
llvm::MapVector<StringRef, int> opNameCounts;
for (Operation *illegalOp : illegalOps) {
StringRef opName = illegalOp->getName().getStringRef();
opNameCounts[opName]++;
illegalOp->emitOpError() << ": illegal op still exists";
}
std::vector<std::string> errorMessages;
errorMessages.reserve(opNameCounts.size());
for (const auto &opInfo : opNameCounts) {
errorMessages.push_back(
llvm::formatv("\t{} (count: {})", opInfo.first, opInfo.second));
}
emitError(loc) << "The following illegal operations still remain: \n"
<< llvm::join(errorMessages, "\n") << "\n";
}
LogicalResult verifyAllOperationsAreLegal(Operation *op,
const ConversionTarget &target) {
// We don't just use applyPartialConversion with no patterns because this pass
// shouldn't alter the IR at all (including via folding or canonicalizations
// that dialect conversion does automatically).
DenseSet<Operation *> illegalOps;
op->walk([&](Operation *op) {
if (!target.isLegal(op)) {
illegalOps.insert(op);
}
});
if (illegalOps.empty())
return success();
emitLegalizationErrors(op->getLoc(), illegalOps);
return failure();
}
Attribute convertAttribute(Location loc, Attribute oldAttr,
const TypeConverter &typeConverter) {
// Type attributes get their nested type converted.
if (auto oldTypeAttr = llvm::dyn_cast<TypeAttr>(oldAttr)) {
return TypeAttr::get(typeConverter.convertType(oldTypeAttr.getValue()));
}
// Return the same attribute if it doesn't have a type.
auto typedOldAttr = llvm::dyn_cast<TypedAttr>(oldAttr);
if (!typedOldAttr)
return oldAttr;
// Convert the attribute type - if it's the same then it's already legal.
auto oldType = typedOldAttr.getType();
auto newType = typeConverter.convertType(oldType);
if (oldType == newType)
return typedOldAttr;
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(typedOldAttr)) {
APInt value = intAttr.getValue();
if (newType.isSignedInteger()) {
value = value.truncSSat(newType.getIntOrFloatBitWidth());
} else if (newType.isUnsignedInteger()) {
value = value.truncUSat(newType.getIntOrFloatBitWidth());
} else {
value = value.trunc(newType.getIntOrFloatBitWidth());
}
return IntegerAttr::get(newType, value);
} else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(typedOldAttr)) {
auto newFloatType = llvm::cast<FloatType>(newType);
APFloat value = floatAttr.getValue();
bool losesInfo = false;
value.convert(newFloatType.getFloatSemantics(), APFloat::rmTowardZero,
&losesInfo);
return FloatAttr::get(newType, value);
} else if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(typedOldAttr)) {
// NOTE: splats are also dense but this way we avoid needing to convert the
// same splat value N times.
return SplatElementsAttr::get(
llvm::cast<ShapedType>(newType),
convertAttribute(loc, splatAttr.getSplatValue<Attribute>(),
typeConverter));
} else if (auto denseAttr =
llvm::dyn_cast<DenseIntElementsAttr>(typedOldAttr)) {
auto newElementType = llvm::cast<ShapedType>(newType).getElementType();
auto newElementBitWidth = newElementType.getIntOrFloatBitWidth();
if (newElementType.isSignedInteger()) {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.truncSSat(newElementBitWidth);
});
} else if (newElementType.isUnsignedInteger()) {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.truncUSat(newElementBitWidth);
});
} else {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.trunc(newElementBitWidth);
});
}
} else if (auto denseAttr =
llvm::dyn_cast<DenseFPElementsAttr>(typedOldAttr)) {
auto newElementType =
llvm::cast<FloatType>(cast<ShapedType>(newType).getElementType());
const auto &newFloatSemantics = newElementType.getFloatSemantics();
return denseAttr.mapValues(newElementType, [&](APFloat src) {
bool losesInfo = false;
src.convert(newFloatSemantics, APFloat::rmTowardZero, &losesInfo);
return src.bitcastToAPInt();
});
}
return oldAttr;
}
} // namespace mlir::iree_compiler