blob: df75799425e365fb1bea66737073d82884330b57 [file] [log] [blame] [edit]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-global-opt-fuse-dequantization-matmul"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
namespace mlir::iree_compiler::GlobalOptimization {
#define GEN_PASS_DEF_FUSEDEQUANTIZATIONMATMULPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"
namespace {
//----------------------------------------------------------------------------//
// Utility
//----------------------------------------------------------------------------//
// Determines whether `genericOp` is a contraction whose iteration space
// contains exactly two reduction dimensions (the shape produced by a
// grouped-quantization matmul). Succeeds iff the op:
// 1. isaContractionOpInterface
// 2. Has 2 reduction dimensions
static LogicalResult
isContractionWithTwoReductions(linalg::GenericOp genericOp) {
  auto linalgOp = cast<linalg::LinalgOp>(genericOp.getOperation());
  // A zero-loop op cannot be the contraction we are looking for.
  if (genericOp.getNumLoops() == 0 ||
      !linalg::isaContractionOpInterface(linalgOp) ||
      genericOp.getNumReductionLoops() != 2) {
    return failure();
  }
  return success();
}
// Matches a grouped dequantization `linalg.generic`.
// Succeeds iff the genericOp:
// 1. Has a body that (walking back from the yield along operand 0) is the
//    chain: arith.mulf <- arith.subf <- arith.uitofp <- arith.extui
// 2. Widens the element bitwidth from input to output
// 3. Has >= 3 loops, all of them parallel
// 4. Has 2 (weights, scales) or 3 (weights, scales, zero points)
//    inputs and 1 output
static LogicalResult isGroupedDequantizationOp(linalg::GenericOp genericOp) {
  // Structural checks: a single init and two or three inputs.
  if (genericOp.getNumDpsInits() != 1)
    return failure();
  unsigned numInputs = genericOp.getNumDpsInputs();
  if (numInputs != 2 && numInputs != 3)
    return failure();
  // Rank must be at least 3 and the iteration space fully parallel.
  unsigned numLoops = genericOp.getNumLoops();
  if (numLoops < 3 || numLoops != genericOp.getNumParallelLoops())
    return failure();
  // Walk the body backwards from linalg.yield, expecting
  // mulf -> subf -> uitofp -> extui along operand 0 of each producer.
  auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
  Operation *node = yieldOp->getOperand(0).getDefiningOp();
  // Checks that the current `node` exists, (optionally) has operands to
  // keep walking through, and matches the expected op kind.
  auto expect = [&](auto opMatcher, bool requireOperands) -> bool {
    if (!node)
      return false;
    if (requireOperands && node->getNumOperands() == 0)
      return false;
    return matchPattern(node, opMatcher);
  };
  if (!expect(m_Op<arith::MulFOp>(), /*requireOperands=*/true))
    return failure();
  node = node->getOperand(0).getDefiningOp();
  if (!expect(m_Op<arith::SubFOp>(), /*requireOperands=*/true))
    return failure();
  node = node->getOperand(0).getDefiningOp();
  if (!expect(m_Op<arith::UIToFPOp>(), /*requireOperands=*/true))
    return failure();
  node = node->getOperand(0).getDefiningOp();
  // The extui is the start of the chain; it needs no further operands.
  if (!expect(m_Op<arith::ExtUIOp>(), /*requireOperands=*/false))
    return failure();
  // A dequantization must strictly widen the element type.
  auto elementTypeOut =
      cast<ShapedType>(genericOp.getOutputs()[0].getType()).getElementType();
  if (!elementTypeOut.isIntOrFloat())
    return failure();
  auto elementTypeIn =
      cast<ShapedType>(genericOp.getInputs()[0].getType()).getElementType();
  if (!elementTypeIn.isIntOrFloat())
    return failure();
  if (elementTypeIn.getIntOrFloatBitWidth() >=
      elementTypeOut.getIntOrFloatBitWidth())
    return failure();
  return success();
}
// Builds an iterator-type array of `nLoops` entries: the leading
// `nLoops - nReduction` entries are parallel, the trailing `nReduction`
// entries are reduction.
static SmallVector<utils::IteratorType>
getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) {
  SmallVector<utils::IteratorType> iterators;
  iterators.reserve(nLoops);
  iterators.assign(nLoops - nReduction, utils::IteratorType::parallel);
  iterators.append(nReduction, utils::IteratorType::reduction);
  return iterators;
}
// Stateful helper that rewrites a matched (grouped dequantization, matmul)
// pair into a dynamically quantized, reassociated sequence of
// `linalg.generic` ops. Construct it, check `precondition()`, then call the
// generate* methods in order (see `reassociateDequantMatmul`).
struct QuantizedMatmulRewriter {
  QuantizedMatmulRewriter(RewriterBase &rewriter, linalg::GenericOp dequant,
                          linalg::GenericOp matmul, int quantizedBitWidth);
  // Matches the operands of `dequant` and `matmul` (layout documented on
  // `ins` below); returns std::nullopt when they cannot be identified.
  std::optional<SmallVector<OpOperand *>> getDequantMatmulInputs();
  // Returns indexing maps and iterator types for a reduction along the
  // innermost (group) dimension of `inputOperand`.
  std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>>
  getGroupReductionMapsAndIterators(OpOperand *inputOperand);
  // Creates a zero-filled init tensor for a group reduction over `input`.
  Value getGroupReductionInit(Value input);
  // Emits a generic computing the per-group absolute maximum of `ins[1]`.
  Value generateGroupMaxGeneric();
  // Emits a generic computing per-group scales from `groupMax`.
  Value generateScalesGeneric(Value groupMax);
  // Emits a generic computing per-group sums of `ins[1]`.
  Value generateGroupSumsGeneric();
  // Emits the dynamic-quantization ops; returns
  // {groupMax, scales, groupSums, newQuantizedInput}.
  SmallVector<Value> generateQuantizationGenerics();
  // Emits the integer matmul over the quantized operands.
  linalg::GenericOp
  generateQuantizedMatmulGeneric(SmallVector<Value> quantizeResults);
  // Emits the final group reduction plus dequantization arithmetic.
  linalg::GenericOp
  generateReassociatedDequantizationGeneric(SmallVector<Value> quantizeResults,
                                            Value quantizedIntegerMatmul);
  // Checks that the matched ops are in a form this rewriter supports.
  LogicalResult precondition();

private:
  // rewriter
  RewriterBase &rewriter;
  // linalg::GenericOp that performs the dequantization
  linalg::GenericOp dequant;
  // linalg::GenericOp that performs the matmul
  linalg::GenericOp matmul;
  // Destination bit width for dynamic quantization of the unquantized input
  // (`ins[1]`)
  int quantizedBitWidth;
  // Type for accumulation of integer matmul
  IntegerType accType;
  // Type for multiplication in integer matmul (should probably be same as
  // `accType`)
  IntegerType mulType;
  // Type for dynamic quantization of unquantized input
  IntegerType quantType;
  // inputs to the `dequant` and `matmul` ops
  // ins[0] = quantized matrix input to `dequant`
  // ins[1] = unquantized input to `matmul`
  // ins[2] = scales input to `dequant`
  // ins[3] = zero points input to `dequant`
  // ins[4] = dequantized input to `matmul`
  SmallVector<OpOperand *> ins;
  // Location for rewrite
  Location loc;
};
// Takes as input the dequantization `linalg.generic` op and the matmul
// `linalg.generic` op, and returns the scales, zero points, quantized
// input matrix, unquantized input matrix, and dequantized result
// matrix.
// TODO(#) Have stricter matching on inputs. There may be cases where
// the current matching fails
std::optional<SmallVector<OpOperand *>>
QuantizedMatmulRewriter::getDequantMatmulInputs() {
  assert(!failed(isContractionWithTwoReductions(matmul)) &&
         "expected `matmul` to be a contraction with two reduction dimensions");
  assert(!failed(isGroupedDequantizationOp(dequant)) &&
         "expected `dequant` to be a grouped dequantization");
  // Null-initialize so the presence check at the end is well-defined even
  // when the matching below fails to identify one of the operands.
  // (Previously these were uninitialized, making the final check read
  // indeterminate pointers when no match was found.)
  OpOperand *scales = nullptr, *zps = nullptr, *quantMat = nullptr,
            *unquantMat = nullptr, *dequantMat = nullptr;
  auto maps = dequant.getIndexingMapsArray();
  // Identify dequant inputs by their indexing maps: the identity-mapped
  // operand is the quantized matrix; projected-permutation operands are
  // scales (rhs of mulf) or zero points (rhs of subf).
  for (auto [idx, map] : enumerate(ArrayRef<AffineMap>(maps).drop_back())) {
    if (map.isIdentity()) {
      quantMat = dequant.getDpsInputOperand(idx);
    } else if (map.isProjectedPermutation(true)) {
      for (Operation &bodyOp : dequant.getBlock()->getOperations()) {
        if (isa<arith::MulFOp>(bodyOp)) {
          if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) {
            scales = dequant.getDpsInputOperand(idx);
            break;
          }
        } else if (isa<arith::SubFOp>(bodyOp)) {
          if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) {
            zps = dequant.getDpsInputOperand(idx);
            break;
          }
        }
      }
    }
  }
  // Whichever matmul input is produced by `dequant` is the dequantized
  // operand; the other is the unquantized one.
  Value dequantOut = dequant.getResult(0);
  if (matmul.getDpsInputOperand(0)->get() == dequantOut) {
    unquantMat = matmul.getDpsInputOperand(1);
    dequantMat = matmul.getDpsInputOperand(0);
  } else {
    unquantMat = matmul.getDpsInputOperand(0);
    dequantMat = matmul.getDpsInputOperand(1);
  }
  if (scales && zps && quantMat && unquantMat) {
    return SmallVector<OpOperand *>(
        {quantMat, unquantMat, scales, zps, dequantMat});
  }
  return std::nullopt;
}
// Initializes rewriter state: i32 accumulation/multiplication types, the
// requested quantized storage type, and the matched operand list.
QuantizedMatmulRewriter::QuantizedMatmulRewriter(RewriterBase &rewriter,
                                                 linalg::GenericOp dequant,
                                                 linalg::GenericOp matmul,
                                                 int quantizedBitWidth)
    : rewriter(rewriter), dequant(dequant), matmul(matmul),
      quantizedBitWidth(quantizedBitWidth), loc(dequant.getLoc()) {
  accType = rewriter.getI32Type();
  mulType = rewriter.getI32Type();
  quantType = rewriter.getIntegerType(quantizedBitWidth);
  // Leave `ins` empty on a failed match; `precondition()` rejects that case.
  if (std::optional<SmallVector<OpOperand *>> matched =
          getDequantMatmulInputs()) {
    ins = *matched;
  }
}
// Verifies that the matched ops are in a rewritable form:
//   - `ins` holds all 5 expected operands,
//   - the unquantized input has rank >= 2 and its two innermost indexing
//     dims are the matmul's two innermost (reduction) dims,
//   - the unquantized input, scales, and zero points are floating point,
//   - scales/zero points have the rank expected for grouped quantization
//     (optionally carrying a trailing unit dim).
LogicalResult QuantizedMatmulRewriter::precondition() {
  // `ins` is left empty by the constructor when operand matching failed.
  if (ins.size() != 5) {
    return rewriter.notifyMatchFailure(
        matmul,
        "expected `ins` to have 5 inputs for quantized matmul reassociation");
  }
  OpOperand *unquantizedInputOperand = ins[1];
  Value unquantizedInput = ins[1]->get();
  RankedTensorType unquantizedInputType =
      cast<RankedTensorType>(unquantizedInput.getType());
  SmallVector<int64_t> unquantizedInputShape(unquantizedInputType.getShape());
  AffineMap indexingMap =
      matmul.getMatchingIndexingMap(unquantizedInputOperand);
  SmallVector<utils::IteratorType> matmulIteratorTypes =
      matmul.getIteratorTypesArray();
  if (unquantizedInputShape.size() < 2) {
    return rewriter.notifyMatchFailure(
        matmul, "input expected to have a rank of at least 2");
  }
  // The two innermost loop dims (group, group-element) must both be
  // reductions in the matmul.
  if (matmulIteratorTypes[indexingMap.getNumDims() - 1] ==
          utils::IteratorType::parallel ||
      matmulIteratorTypes[indexingMap.getNumDims() - 2] ==
          utils::IteratorType::parallel) {
    return rewriter.notifyMatchFailure(
        matmul, "inner 2 dimensions of matmul expected to be reduction");
  }
  // The unquantized input's trailing indexing results must be exactly those
  // two innermost loop dims (no transposition of the inner shape).
  auto affineExprs = indexingMap.getResults();
  auto innerDim0 = dyn_cast<AffineDimExpr>(affineExprs.back());
  auto innerDim1 = dyn_cast<AffineDimExpr>(affineExprs[affineExprs.size() - 2]);
  if (!innerDim0 || !innerDim1 ||
      innerDim0.getPosition() != indexingMap.getNumDims() - 1 ||
      innerDim1.getPosition() != indexingMap.getNumDims() - 2) {
    return rewriter.notifyMatchFailure(
        matmul, "inner shape of input expected to be reduced in matmul");
  }
  Value scales = ins[2]->get();
  Value zps = ins[3]->get();
  // The dequantization arithmetic emitted later is floating point.
  if (!isa<FloatType>(unquantizedInputType.getElementType()) ||
      !isa<FloatType>(getElementTypeOrSelf(scales)) ||
      !isa<FloatType>(getElementTypeOrSelf(zps))) {
    return rewriter.notifyMatchFailure(matmul, "expected float type");
  }
  OpOperand *matmulDequantizedOperand = ins[4];
  auto matmulDequantizedInputExprs =
      matmul.getMatchingIndexingMap(matmulDequantizedOperand).getResults();
  auto scalesType = dyn_cast<RankedTensorType>(scales.getType());
  auto zpsType = dyn_cast<RankedTensorType>(zps.getType());
  if (!scalesType || !zpsType) {
    return rewriter.notifyMatchFailure(
        dequant, "expected scales and zero points to have RankedTensorType");
  }
  // Scales must have one dim per non-group-element dim of the dequantized
  // operand, or the same rank with a trailing unit dim.
  if (scalesType.getShape().size() != matmulDequantizedInputExprs.size() - 1) {
    if (scalesType.getShape().size() != matmulDequantizedInputExprs.size() ||
        scalesType.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(dequant, "unexpected rank for scales");
    }
  }
  // Same rank constraint for zero points.
  if (zpsType.getShape().size() != matmulDequantizedInputExprs.size() - 1) {
    if (zpsType.getShape().size() != matmulDequantizedInputExprs.size() ||
        zpsType.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(dequant,
                                         "unexpected rank for zero points");
    }
  }
  return success();
}
// Computes the indexing maps and iterator types for a reduction over the
// innermost (group) dimension of `inputOperand`: an identity map for the
// input, a drop-last-dim map for the output, and all-parallel iterators
// except the trailing reduction.
std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>>
QuantizedMatmulRewriter::getGroupReductionMapsAndIterators(
    OpOperand *inputOperand) {
  Value input = inputOperand->get();
  auto inputType = cast<RankedTensorType>(input.getType());
  unsigned rank = inputType.getShape().size();
  AffineMap indexingMap = matmul.getMatchingIndexingMap(inputOperand);
  SmallVector<utils::IteratorType> matmulIteratorTypes =
      matmul.getIteratorTypesArray();
  assert(rank >= 2 && "input expected to have a rank of at least 2");
  unsigned numDims = indexingMap.getNumDims();
  assert((matmulIteratorTypes[numDims - 1] == utils::IteratorType::reduction &&
          matmulIteratorTypes[numDims - 2] == utils::IteratorType::reduction) &&
         "inner 2 dimensions of matmul expected to be reduction");
  ArrayRef<AffineExpr> resultExprs = indexingMap.getResults();
  auto lastDim = dyn_cast<AffineDimExpr>(resultExprs.back());
  auto secondLastDim =
      dyn_cast<AffineDimExpr>(resultExprs[resultExprs.size() - 2]);
  assert(lastDim && secondLastDim &&
         lastDim.getPosition() == numDims - 1 &&
         secondLastDim.getPosition() == numDims - 2 &&
         "inner shape of input expected to be reduced in matmul");
  (void)lastDim;
  (void)secondLastDim;
  (void)numDims;
  // All dims stay parallel except the trailing group dim, which is reduced.
  SmallVector<utils::IteratorType> iterators(rank,
                                             utils::IteratorType::parallel);
  iterators.back() = utils::IteratorType::reduction;
  AffineMap inputMap = rewriter.getMultiDimIdentityMap(rank);
  AffineMap outputMap = inputMap.getMajorSubMap(rank - 1);
  return std::make_pair(SmallVector<AffineMap>{inputMap, outputMap},
                        iterators);
}
// Creates a zero-filled tensor shaped like `input` with its trailing
// (group) dimension dropped, serving as the init value for a group
// reduction.
Value QuantizedMatmulRewriter::getGroupReductionInit(Value input) {
  auto inputType = cast<RankedTensorType>(input.getType());
  Type elemType = inputType.getElementType();
  assert(isa<FloatType>(elemType) && "expected float type");
  ArrayRef<int64_t> reducedShape = inputType.getShape().drop_back();
  Value cstZero = arith::ConstantOp::create(
      rewriter, loc, rewriter.getFloatAttr(elemType, 0.0));
  Value empty =
      tensor::EmptyOp::create(rewriter, loc, reducedShape, elemType);
  return linalg::FillOp::create(rewriter, loc, cstZero, empty).result();
}
// Emits a generic that reduces the unquantized input (`ins[1]`) along its
// group dimension with max(|x|, acc), returning the per-group absolute
// maximum.
Value QuantizedMatmulRewriter::generateGroupMaxGeneric() {
  OpOperand *unquantizedOperand = ins[1];
  auto [maps, iterators] =
      getGroupReductionMapsAndIterators(unquantizedOperand);
  Value src = unquantizedOperand->get();
  Value init = getGroupReductionInit(src);
  auto groupMaxOp = linalg::GenericOp::create(
      rewriter, loc, init.getType(), src, init, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value absVal = math::AbsFOp::create(b, nestedLoc, args[0]);
        Value runningMax =
            arith::MaximumFOp::create(b, nestedLoc, absVal, args[1]);
        linalg::YieldOp::create(b, nestedLoc, runningMax);
      });
  LLVM_DEBUG(DBGS() << "groupMaxOp: " << groupMaxOp << "\n");
  return groupMaxOp.getResult(0);
}
// Emits an elementwise generic that converts per-group absolute maxima into
// per-group scales for symmetric quantization:
//   scale = absMax / (2^(quantizedBitWidth - 1) - 1)
Value QuantizedMatmulRewriter::generateScalesGeneric(Value groupMax) {
  auto groupMaxType = cast<RankedTensorType>(groupMax.getType());
  Type elemType = groupMaxType.getElementType();
  assert(isa<FloatType>(elemType) && "expected float type");
  // Largest representable magnitude of the signed quantized type.
  Value maxQuantValue = arith::ConstantOp::create(
      rewriter, loc,
      rewriter.getFloatAttr(elemType, (1 << (quantizedBitWidth - 1)) - 1));
  Value init = tensor::EmptyOp::create(rewriter, loc, groupMaxType.getShape(),
                                       elemType);
  SmallVector<AffineMap> maps(
      2, rewriter.getMultiDimIdentityMap(groupMaxType.getShape().size()));
  auto scalesOp = linalg::GenericOp::create(
      rewriter, loc, init.getType(), groupMax, init, maps,
      getParallelAndReductionIterators(groupMaxType.getRank(), 0),
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value scale =
            arith::DivFOp::create(b, nestedLoc, args[0], maxQuantValue);
        linalg::YieldOp::create(b, nestedLoc, scale);
      });
  LLVM_DEBUG(DBGS() << "scalesOp: " << scalesOp << "\n");
  return scalesOp.getResult(0);
}
// Emits a generic that reduces the unquantized input (`ins[1]`) along its
// group dimension by summation, returning the per-group sums (needed by the
// reassociated dequantization arithmetic).
Value QuantizedMatmulRewriter::generateGroupSumsGeneric() {
  OpOperand *unquantizedOperand = ins[1];
  auto [maps, iterators] =
      getGroupReductionMapsAndIterators(unquantizedOperand);
  Value src = unquantizedOperand->get();
  Value init = getGroupReductionInit(src);
  auto groupSumsOp = linalg::GenericOp::create(
      rewriter, loc, init.getType(), src, init, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value runningSum =
            arith::AddFOp::create(b, nestedLoc, args[0], args[1]);
        linalg::YieldOp::create(b, nestedLoc, runningSum);
      });
  LLVM_DEBUG(DBGS() << "groupSumsOp: " << groupSumsOp << "\n");
  return groupSumsOp.getResult(0);
}
// Emits the 4 linalg::GenericOps that together symmetrically quantize the
// unquantized input: group absolute max, per-group scales, per-group sums,
// and the quantization itself. Returns
// {groupMax, scales, groupSums, newQuantizedInput}.
SmallVector<Value> QuantizedMatmulRewriter::generateQuantizationGenerics() {
  assert(ins.size() == 5 && "expected `ins` to have 5 inputs");
  Value unquantizedInput = ins[1]->get();
  auto unquantizedType = cast<RankedTensorType>(unquantizedInput.getType());
  ArrayRef<int64_t> shape = unquantizedType.getShape();
  unsigned rank = shape.size();
  IntegerType quantizedElementType = rewriter.getIntegerType(quantizedBitWidth);
  // Steps 1-3: per-group statistics.
  Value groupMax = generateGroupMaxGeneric();
  Value scales = generateScalesGeneric(groupMax);
  Value groupSums = generateGroupSumsGeneric();
  // Step 4: elementwise quantization x -> fptosi(x / scale).
  Value init =
      tensor::EmptyOp::create(rewriter, loc, shape, quantizedElementType);
  AffineMap identityMap = rewriter.getMultiDimIdentityMap(rank);
  AffineMap scalesMap = identityMap.getMajorSubMap(rank - 1);
  SmallVector<AffineMap> maps{identityMap, scalesMap, identityMap};
  auto quantizeOp = linalg::GenericOp::create(
      rewriter, loc, init.getType(), ValueRange{unquantizedInput, scales},
      init, maps, getParallelAndReductionIterators(rank, 0),
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value scaled = arith::DivFOp::create(b, nestedLoc, args[0], args[1]);
        Value quantized =
            arith::FPToSIOp::create(b, nestedLoc, quantizedElementType, scaled);
        linalg::YieldOp::create(b, nestedLoc, quantized);
      });
  LLVM_DEBUG(DBGS() << "quantizeOp: " << quantizeOp << "\n");
  return {groupMax, scales, groupSums, quantizeOp.getResult(0)};
}
// Creates a generic that computes the main matmul computation in integer
// arithmetic, while values are still quantized, and returns the op. Only the
// innermost (group-element) dim is reduced here; the group dim stays
// parallel and is appended to the output shape, leaving the cross-group
// reduction to the reassociated dequantization op.
linalg::GenericOp QuantizedMatmulRewriter::generateQuantizedMatmulGeneric(
    SmallVector<Value> quantizeResults) {
  Value newQuantizedInput = quantizeResults.back();
  assert(ins.size() == 5 && "expected `ins` to have 5 inputs");
  Value quantizedInput = ins[0]->get();
  OpOperand *matmulUnquantizedOperand = ins[1];
  OpOperand *matmulDequantizedOperand = ins[4];
  OpOperand *matmulOutputOperand = matmul.getDpsInitOperand(0);
  SmallVector<utils::IteratorType> iterators =
      getParallelAndReductionIterators(matmul.getNumLoops(), 1);
  SmallVector<AffineMap> maps;
  AffineMap matmulUnquantizedMap =
      matmul.getMatchingIndexingMap(matmulUnquantizedOperand);
  AffineMap matmulDequantizedMap =
      matmul.getMatchingIndexingMap(matmulDequantizedOperand);
  AffineMap matmulOutputMap =
      matmul.getMatchingIndexingMap(matmulOutputOperand);
  maps.push_back(matmulUnquantizedMap);
  maps.push_back(matmulDequantizedMap);
  // The output keeps the original result dims plus the (parallel) group dim.
  SmallVector<AffineExpr> outputExprs(matmulOutputMap.getResults());
  outputExprs.push_back(rewriter.getAffineDimExpr(matmul.getNumLoops() - 2));
  maps.push_back(AffineMap::get(iterators.size(), 0, outputExprs,
                                outputExprs.front().getContext()));
  SmallVector<int64_t> newQuantizedInputShape(
      cast<RankedTensorType>(newQuantizedInput.getType()).getShape());
  assert(newQuantizedInputShape.size() >= 2 &&
         "expected new quantized input to have a rank of at least 2");
  SmallVector<int64_t> outputShape(
      cast<RankedTensorType>(matmulOutputOperand->get().getType()).getShape());
  outputShape.push_back(
      newQuantizedInputShape[newQuantizedInputShape.size() - 2]);
  // Use an integer literal; getIntegerAttr takes an int64_t value (the
  // previous `0.0` relied on an implicit double-to-integer conversion).
  Value zero = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getIntegerAttr(accType, 0));
  Value emptyOut = tensor::EmptyOp::create(rewriter, loc, outputShape, accType);
  Value output = linalg::FillOp::create(rewriter, loc, zero, emptyOut).result();
  auto integerMatmulOp = linalg::GenericOp::create(
      rewriter, loc, output.getType(),
      ValueRange{newQuantizedInput, quantizedInput}, output, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // args[0] = dynamically quantized (signed) input,
        // args[1] = original (unsigned) quantized input,
        // args[2] = accumulator.
        Value mul;
        if (quantType == mulType) {
          // args[0] is already at mulType width; only extend args[1].
          Value ext1 = arith::ExtUIOp::create(b, nestedLoc, mulType, args[1]);
          mul = arith::MulIOp::create(b, nestedLoc, args[0], ext1);
        } else {
          Value ext0 = arith::ExtSIOp::create(b, nestedLoc, mulType, args[0]);
          Value ext1 = arith::ExtUIOp::create(b, nestedLoc, mulType, args[1]);
          mul = arith::MulIOp::create(b, nestedLoc, ext0, ext1);
        }
        Value sum;
        if (mulType == accType) {
          sum = arith::AddIOp::create(b, nestedLoc, mul, args[2]);
        } else {
          Value extMul = arith::ExtSIOp::create(b, nestedLoc, accType, mul);
          sum = arith::AddIOp::create(b, nestedLoc, extMul, args[2]);
        }
        linalg::YieldOp::create(b, nestedLoc, sum);
      });
  LLVM_DEBUG(DBGS() << "integerMatmulOp: " << integerMatmulOp << "\n");
  return integerMatmulOp;
}
// Creates a generic that does the reassociated dequantization math, and does
// the final reduction along the number-of-groups dimension:
//   out += (sitofp(intMatmul) * newScale - zp * groupSum) * scale
// Operand order (mirrored by args[] in the body region):
//   args[0] = quantized integer matmul result
//   args[1] = newScales (scales of the dynamic quantization)
//   args[2] = groupSums (per-group sums of the unquantized input)
//   args[3] = scales    (original dequantization scales)
//   args[4] = zps       (original dequantization zero points)
//   args[5] = output accumulator
linalg::GenericOp
QuantizedMatmulRewriter::generateReassociatedDequantizationGeneric(
    SmallVector<Value> quantizeResults, Value quantizedIntegerMatmul) {
  assert(quantizeResults.size() == 4 &&
         "expected 4 ops from quantization step");
  Value scales = ins[2]->get();
  Value zps = ins[3]->get();
  Value newScales = quantizeResults[1];
  Value groupSums = quantizeResults[2];
  // One fewer loop than the matmul (the group-element dim is already
  // reduced); the trailing group dim is the remaining reduction.
  SmallVector<utils::IteratorType> iterators =
      getParallelAndReductionIterators(matmul.getNumLoops() - 1, 1);
  SmallVector<AffineMap> maps;
  AffineMap quantizedIntegerMatmulMap =
      rewriter.getMultiDimIdentityMap(iterators.size());
  maps.push_back(quantizedIntegerMatmulMap);
  // newScales and groupSums index like the unquantized input with its
  // innermost (group-element) dim dropped.
  {
    OpOperand *matmulUnquantizedOperand = ins[1];
    auto exprs =
        matmul.getMatchingIndexingMap(matmulUnquantizedOperand).getResults();
    exprs = exprs.slice(0, exprs.size() - 1);
    auto newScalesAndSumsMap =
        AffineMap::get(iterators.size(), 0, exprs, exprs.front().getContext());
    maps.push_back(newScalesAndSumsMap);
    maps.push_back(newScalesAndSumsMap);
  }
  // scales and zps index like the dequantized input with its innermost dim
  // dropped, re-adding a constant-0 trailing index when they carry a
  // trailing unit dim (see `precondition`).
  {
    OpOperand *matmulDequantizedOperand = ins[4];
    auto exprsRef =
        matmul.getMatchingIndexingMap(matmulDequantizedOperand).getResults();
    SmallVector<AffineExpr> exprs(exprsRef.slice(0, exprsRef.size() - 1));
    RankedTensorType scalesType = cast<RankedTensorType>(scales.getType());
    RankedTensorType zpsType = cast<RankedTensorType>(zps.getType());
    if (exprs.size() < scalesType.getShape().size() &&
        scalesType.getShape().back() == 1 && zpsType.getShape().back() == 1) {
      exprs.push_back(rewriter.getAffineConstantExpr(0));
    }
    assert(exprs.size() == scalesType.getShape().size() &&
           "unexpected rank for scales");
    assert(exprs.size() == zpsType.getShape().size() &&
           "unexpected rank for zero points");
    auto scalesAndZpsMap =
        AffineMap::get(iterators.size(), 0, exprs, exprs.front().getContext());
    maps.push_back(scalesAndZpsMap);
    maps.push_back(scalesAndZpsMap);
  }
  OpOperand *matmulOutputOperand = matmul.getDpsInitOperand(0);
  SmallVector<AffineExpr> outputExprs(
      matmul.getMatchingIndexingMap(matmulOutputOperand).getResults());
  auto outputMap = AffineMap::get(iterators.size(), 0, outputExprs,
                                  outputExprs.front().getContext());
  maps.push_back(outputMap);
  Type floatType = getElementTypeOrSelf(scales);
  // Accumulate directly into the original matmul's init.
  Value output = matmulOutputOperand->get();
  auto reassociatedDequantizationOp = linalg::GenericOp::create(
      rewriter, loc, output.getType(),
      ValueRange{quantizedIntegerMatmul, newScales, groupSums, scales, zps},
      output, maps, iterators,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dq;
        dq = arith::SIToFPOp::create(b, loc, floatType, args[0]);
        Value scaledRes0 = arith::MulFOp::create(b, loc, dq, args[1]);
        Value scaledRes1 = arith::MulFOp::create(b, loc, scaledRes0, args[3]);
        // zero-point correction term: zp * scale * groupSum.
        Value scaledZp0 = arith::MulFOp::create(b, loc, args[4], args[3]);
        Value scaledZp1 = arith::MulFOp::create(b, loc, scaledZp0, args[2]);
        Value groupRes = arith::SubFOp::create(b, loc, scaledRes1, scaledZp1);
        Value sum = arith::AddFOp::create(b, loc, groupRes, args[5]);
        linalg::YieldOp::create(b, loc, sum);
      });
  LLVM_DEBUG(DBGS() << "reassociatedDequantizationOp: "
                    << reassociatedDequantizationOp << "\n");
  return reassociatedDequantizationOp;
}
// This function does the bulk of the rewrite for the dequantization + matmul.
//
// Starting with 2 `linalg.generic` ops (dequantization->matmul)
// %arg0 = quantized input
// %arg1 = scales
// %arg2 = zero points
// %arg3 = unquantized input
// ```mlir
// %0 = linalg.generic ins(%arg0, %arg1, %arg2 : tensor<8x4x2xi4>,
// tensor<8x4x1xf32>, tensor<8x4x1xf32>) outs(%1 : tensor<8x4x2xf32>) {
// ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
// %9 = arith.extui %in : i4 to i32
// %10 = arith.uitofp %9 : i32 to f32
// %11 = arith.subf %10, %in_1 : f32
// %12 = arith.mulf %11, %in_0 : f32
// linalg.yield %12 : f32
// } -> tensor<8x4x2xf32>
// %2 = linalg.generic ins(%arg3, %0 : tensor<4x2xf32>, tensor<8x4x2xf32>)
// outs(%3 : tensor<8xf32>) { ^bb0(%in: f32, %in_0: f32, %out: f32):
// %9 = arith.mulf %in, %in_0 : f32
// %10 = arith.addf %9, %out : f32
// linalg.yield %10 : f32
// } -> tensor<8xf32>
// ```
//
// This function rewrites the above ops as the following new sequence of 6 ops
// that does the following:
//
// a) Dynamically quantize the unquantized input:
// 1. Compute the absolute max of the unquantized input (%arg3) within each
// group.
// 2. Compute scales for %arg3 by dividing the absolute max by
// ((1 << (newBitWidth - 1)) - 1),
// where newBitWidth is the bitwidth of the new quantized type,
// currently set to `i16`.
// 3. Compute the sum along groups of the unquantized input. This is not
// necessary for
// the quantization step, but it is needed to efficiently perform a
// reassociated quantized matmul in steps 5-6.
// 4. Quantize the unquantized input (%arg3) by dividing elements in each
// group by
// the corresponding scale. This quantization is symmetric with no zero
// point.
// b) Perform the reassociated quantized matmul, keeping the bulk of the
// computation in
// integer arithmetic:
// 5. The first op performs a matmul-like operation that reduces the
// innermost group
// dimension. Note the indexing maps in the following example:
// ```mlir
// %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) ->
// (d1, d2)>,
// affine_map<(d0, d1, d2) ->
// (d0, d1, d2)>, affine_map<(d0,
// d1, d2) -> (d0, d1)>],
// iterator_types = ["parallel", "parallel",
// "reduction"]} ins(%17, %0 : tensor<4x2xi16>,
// tensor<8x4x2xi4>) outs(%19 : tensor<8x4xi32>) {
// ^bb0(%in: i16, %in_4: i4, %out: i32):
// %24 = arith.extsi %in : i16 to i32
// %25 = arith.extui %in_4 : i4 to i32
// %26 = arith.muli %24, %25 : i32
// %27 = arith.addi %26, %out : i32
// linalg.yield %27 : i32
// } -> tensor<8x4xi32>
// ```
// This op also extends the inputs to the accumulation type, i32 in this
// case, to target specific x86 instructions. We perform the matrix
// multiplication before the dequantization arithmetic, which has been
// reassociated into op 6.
// 6. The final op performs the remaining reduction across groups and does
// the
// dequantization arithmetic:
// ```mlir
// %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0,
// d1)>,
// affine_map<(d0, d1) -> (d1)>,
// affine_map<(d0, d1) -> (d1)>,
// affine_map<(d0, d1) -> (d0,
// d1)>, affine_map<(d0, d1) ->
// (d0, d1)>, affine_map<(d0, d1)
// -> (d0)>],
// iterator_types = ["parallel", "reduction"]}
// ins(tensor<8x4xi32>, tensor<4xf32>,
// tensor<4xf32>, tensor<8x4xf32>,
// tensor<8x4xf32>) outs(tensor<8xf32>) {
// ^bb0(%in: i32, %in_4: f32, %in_5: f32, %in_6: f32, %in_7: f32,
// %out: f32):
// %24 = arith.sitofp %in : i32 to f32
// %25 = arith.mulf %24, %in_4 : f32
// %26 = arith.mulf %25, %in_6 : f32
// %27 = arith.mulf %in_7, %in_6 : f32
// %28 = arith.mulf %27, %in_5 : f32
// %29 = arith.subf %26, %28 : f32
// %30 = arith.addf %29, %out : f32
// linalg.yield %30 : f32
// } -> tensor<8xf32>
// ```
//
// ** Note that this rewrite introduces precision loss in the matmul, and is a
// tradeoff between precision and performance. This rewrite should most
// likely be opt-in only. **
// Drives the full rewrite for one (dequant, matmul) candidate: emits the
// dynamic quantization, the integer matmul, and the reassociated
// dequantization, then replaces the original matmul with the final result.
static LogicalResult reassociateDequantMatmul(RewriterBase &rewriter,
                                              linalg::GenericOp dequant,
                                              linalg::GenericOp matmul,
                                              int quantizeBitWidth) {
  QuantizedMatmulRewriter qmr(rewriter, dequant, matmul, quantizeBitWidth);
  // A failed precondition means this candidate is not rewritable; return
  // success() so the pass simply skips it instead of failing.
  if (failed(qmr.precondition())) {
    return success();
  }
  SmallVector<Value> quantizeResults = qmr.generateQuantizationGenerics();
  linalg::GenericOp quantizedIntegerMatmul =
      qmr.generateQuantizedMatmulGeneric(quantizeResults);
  linalg::GenericOp reassociatedDequantization =
      qmr.generateReassociatedDequantizationGeneric(
          quantizeResults, quantizedIntegerMatmul.getResult(0));
  rewriter.replaceOp(matmul, reassociatedDequantization.getResult(0));
  return success();
}
// Pass that finds grouped dequantization ops feeding matmuls and rewrites
// each pair into a dynamically quantized, reassociated op sequence.
// NOTE: the rewrite trades precision for performance (see the comment above
// `reassociateDequantMatmul`).
struct FuseDequantizationMatmulPass
    : public impl::FuseDequantizationMatmulPassBase<
          FuseDequantizationMatmulPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // The rewrite emits linalg, flow, and math ops.
    registry.insert<linalg::LinalgDialect, IREE::Flow::FlowDialect,
                    math::MathDialect>();
  }
  void runOnOperation() override;
};
} // namespace
void FuseDequantizationMatmulPass::runOnOperation() {
MLIRContext *context = &getContext();
mlir::FunctionOpInterface funcOp = getOperation();
int quantizeBitWidth = 16;
SmallVector<std::pair<linalg::GenericOp, linalg::GenericOp>> candidates;
for (auto genericOp : funcOp.getFunctionBody().getOps<linalg::GenericOp>()) {
if (failed(isContractionWithTwoReductions(genericOp))) {
continue;
}
OpOperand *lhs = genericOp.getDpsInputOperand(0);
OpOperand *rhs = genericOp.getDpsInputOperand(1);
auto lhsOp = lhs->get().getDefiningOp<linalg::GenericOp>();
auto rhsOp = rhs->get().getDefiningOp<linalg::GenericOp>();
if (!cast<ShapedType>(genericOp.getInputs()[0].getType())
.hasStaticShape() ||
!cast<ShapedType>(genericOp.getInputs()[1].getType())
.hasStaticShape() ||
!cast<ShapedType>(genericOp.getResults()[0].getType())
.hasStaticShape()) {
// Codegen can't handle the dynamic case yet.
continue;
}
if (lhsOp) {
if (!failed(isGroupedDequantizationOp(lhsOp))) {
candidates.push_back(std::make_pair(lhsOp, genericOp));
continue;
}
}
if (rhsOp) {
if (!failed(isGroupedDequantizationOp(rhsOp))) {
candidates.push_back(std::make_pair(rhsOp, genericOp));
}
}
}
IRRewriter rewriter(context);
for (auto candidate : candidates) {
rewriter.setInsertionPointAfter(candidate.second);
if (failed(reassociateDequantMatmul(rewriter, candidate.first,
candidate.second, quantizeBitWidth))) {
return signalPassFailure();
}
}
}
} // namespace mlir::iree_compiler::GlobalOptimization