blob: fec5542053781b362167a83d0a8b690dbb5dd60d [file]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::GlobalOptimization {
#define GEN_PASS_DEF_OPTIMIZENUMERICSPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"
namespace {
int getNextPotBitWidth(int bitWidth, int minBitWidth = 8) {
for (int i = minBitWidth;; i *= 2) {
if (i >= bitWidth)
return i;
}
}
Type withNewElementType(Type originalType, Type elementType) {
if (auto st = llvm::dyn_cast<ShapedType>(originalType)) {
return st.clone(elementType);
} else {
return elementType;
}
}
Type makeLowPType(Type originalType, int bitWidth) {
auto *context = originalType.getContext();
auto elementType = IntegerType::get(context, bitWidth);
return withNewElementType(originalType, elementType);
}
Value castNumeric(Value origValue, Type toType, bool isSigned,
OpBuilder &builder) {
Location loc = origValue.getLoc();
Type origElementType = getElementTypeOrSelf(origValue.getType());
Type toElementType = getElementTypeOrSelf(toType);
if (llvm::isa<FloatType>(origElementType) &&
llvm::isa<IntegerType>(toElementType)) {
if (isSigned) {
return builder.create<arith::FPToSIOp>(loc, toType, origValue);
} else {
return builder.create<arith::FPToUIOp>(loc, toType, origValue);
}
} else if (llvm::isa<IntegerType>(origElementType) &&
llvm::isa<FloatType>(toElementType)) {
if (isSigned) {
return builder.create<arith::SIToFPOp>(loc, toType, origValue);
} else {
return builder.create<arith::UIToFPOp>(loc, toType, origValue);
}
} else {
// If we need int<->int and float<->float, implement those cases. Since
// this is just needed for things in this file, it is ok to leave it
// under implemented.
assert(false && "unsupported numeric cast");
return Value();
}
}
struct NarrowParams {
static std::optional<NarrowParams> forValue(Value value) {
if (auto narrowOp =
llvm::dyn_cast_or_null<IREE::Util::NumericOptionalNarrowOp>(
value.getDefiningOp())) {
NarrowParams params;
params.producer = narrowOp.getOperand();
params.fromType = value.getType();
params.toElementType = narrowOp.getSemanticType();
params.range = narrowOp.getIntegerRange();
return params;
}
return {};
}
bool isFromFloat() {
return llvm::isa<FloatType>(getElementTypeOrSelf(fromType));
}
bool isToInteger() { return llvm::isa<IntegerType>(toElementType); }
bool isToSigned() {
return llvm::cast<IntegerType>(toElementType).isSigned();
}
int getToBitWidth() {
return llvm::cast<IntegerType>(toElementType).getWidth();
}
Value producer;
Type fromType;
Type toElementType;
std::optional<std::pair<int64_t, int64_t>> range;
};
// Eliminates a cast produced by an empty by just initializing to that
// type directly.
struct TensorEmptyCast
: OpInterfaceRewritePattern<IREE::Util::NumericCastOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(IREE::Util::NumericCastOpInterface castOp,
PatternRewriter &rewriter) const override {
auto emptyOp = castOp.getInput().getDefiningOp<tensor::EmptyOp>();
if (!emptyOp)
return failure();
Type resultType = castOp.getCasted().getType();
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(castOp, resultType,
emptyOp.getDynamicSizes());
return success();
}
};
// For a cast produced by a fill, rewrites the cast to be on the fill operands.
struct LinalgFillCast
: public OpInterfaceRewritePattern<IREE::Util::NumericCastOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(IREE::Util::NumericCastOpInterface castOp,
PatternRewriter &rewriter) const override {
auto loc = castOp.getLoc();
auto fillOp = castOp.getInput().getDefiningOp<linalg::FillOp>();
if (!fillOp)
return failure();
Type toElementType = getElementTypeOrSelf(castOp.getCastedType());
Value fillInput = fillOp.value();
Value fillInit = fillOp.output();
fillInput = castOp
.cloneWithInput(
rewriter,
withNewElementType(fillInput.getType(), toElementType),
fillInput)
.getCasted();
fillInit =
castOp
.cloneWithInput(
rewriter, withNewElementType(fillInit.getType(), toElementType),
fillInit)
.getCasted();
Value fillResult =
rewriter.create<linalg::FillOp>(loc, fillInput, fillInit).result();
rewriter.replaceOp(castOp, fillResult);
return success();
}
};
// For narrowable inputs, selects
struct LinalgFpMatmulToLowP : public OpRewritePattern<linalg::MatmulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
PatternRewriter &rewriter) const override {
Location loc = matmulOp.getLoc();
Type origResultType = matmulOp.getResult(0).getType();
auto lhsParams = NarrowParams::forValue(matmulOp.getInputs()[0]);
auto rhsParams = NarrowParams::forValue(matmulOp.getInputs()[1]);
auto accumParams = NarrowParams::forValue(matmulOp.getOutputs()[0]);
if (!lhsParams || !rhsParams || !accumParams) {
return rewriter.notifyMatchFailure(matmulOp, "no narrowing annotations");
}
// TODO(#7987): This could be more flexible, allowing mix and match
// integer/float types.
if (!lhsParams->isFromFloat() || !rhsParams->isFromFloat()) {
return rewriter.notifyMatchFailure(matmulOp, "not from floating point");
}
// TODO(#7987): Could support partial conversion to integer.
if (!lhsParams->isToInteger() || !rhsParams->isToInteger() ||
!accumParams->isToInteger()) {
return rewriter.notifyMatchFailure(matmulOp, "not to an integer type");
}
int lhsBitWidth = lhsParams->getToBitWidth();
int rhsBitWidth = rhsParams->getToBitWidth();
// Handle signed/unsigned mismatch.
// TODO(#7987): Implement a proper unsigned->signed widening.
bool isSigned;
if (lhsParams->isToSigned() != rhsParams->isToSigned()) {
// Mixed signed/unsigned. Promote to signed.
isSigned = true;
if (!lhsParams->isToSigned()) {
lhsBitWidth += 1;
}
if (!rhsParams->isToSigned()) {
rhsBitWidth += 1;
}
} else {
// Uniform signed/unsigned.
isSigned = lhsParams->isToSigned();
}
// Round up to a suitable POT width.
lhsBitWidth = getNextPotBitWidth(lhsBitWidth);
rhsBitWidth = getNextPotBitWidth(rhsBitWidth);
// Promote accumulator to match signedness.
int accumBitWidth = accumParams->getToBitWidth();
if (isSigned && !accumParams->isToSigned()) {
// TODO(#7987): A proper unsigned widening based on range.
accumBitWidth += 1;
}
// Determine an appropriate accumulator size.
// TODO(#7987): Apply the clamp of:
// lhsBitWidth + rhsBitWidth + log2_ceil(contraction_dim + 1) to determine
// the accumulator size. Note: Can drop the +1 if one of lhs/rhs is signed
// and symmetric (i.e. does not use the asymmetric lower bound).
if (lhsBitWidth > 8 || rhsBitWidth > 8) {
return rewriter.notifyMatchFailure(matmulOp, "outside of low-p range");
}
accumBitWidth = getNextPotBitWidth(accumBitWidth, 32);
if (accumBitWidth > 32) {
return rewriter.notifyMatchFailure(matmulOp, "accumulator > 32 bits");
}
Type lhsLowPType = makeLowPType(lhsParams->fromType, lhsBitWidth);
Type rhsLowPType = makeLowPType(rhsParams->fromType, rhsBitWidth);
Type accumLowPType = makeLowPType(accumParams->fromType, accumBitWidth);
// Replace the matmul op.
Value newLhs =
castNumeric(lhsParams->producer, lhsLowPType, isSigned, rewriter);
Value newRhs =
castNumeric(rhsParams->producer, rhsLowPType, isSigned, rewriter);
Value newAccum =
castNumeric(accumParams->producer, accumLowPType, isSigned, rewriter);
auto newMatmulOp = rewriter.create<linalg::MatmulOp>(
loc, ValueRange{newLhs, newRhs}, ValueRange{newAccum});
if (!isSigned) {
newMatmulOp.setCast(linalg::TypeFn::cast_unsigned);
}
// Cast back.
Value newResult = castNumeric(newMatmulOp.getResult(0), origResultType,
isSigned, rewriter);
rewriter.replaceOp(matmulOp, ValueRange{newResult});
return success();
}
};
class OptimizeNumericsPass
: public impl::OptimizeNumericsPassBase<OptimizeNumericsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
// Precision reduction.
patterns.insert<LinalgFpMatmulToLowP>(context);
// Cast propagation.
patterns.insert<TensorEmptyCast>(context);
patterns.insert<LinalgFillCast>(context);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
} // namespace mlir::iree_compiler::GlobalOptimization