blob: e3d70258e85a14ae5354611f9bd1e4786d4fb6b7 [file] [log] [blame] [edit]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/GlobalOptimization/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::GlobalOptimization {
#define GEN_PASS_DEF_LINALGQUANTIZEDMATMULTOMATMULPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"
namespace {
bool isConstantZero(Value val) {
auto constIntOp = val.getDefiningOp<arith::ConstantIntOp>();
return constIntOp && constIntOp.value() == 0;
}
// Pattern lowering quantized_matmul to matmul and quantized_batch_matmul to
// batch_matmul op.
// This is implementing the math explained in Section 2.3 of
// https://arxiv.org/abs/1712.05877.
struct QuantizedMatmulToMatmul
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(linalg::LinalgOp op,
PatternRewriter &rewriter) const override {
// Fails when the operation is neither quantized_matmul or
// quantized_batch_matmul.
if (!isa<linalg::QuantizedMatmulOp, linalg::QuantizedBatchMatmulOp>(op)) {
return failure();
}
Location loc = op.getLoc();
SmallVector<Value> inputs = op.getDpsInputs();
bool batch = isa<linalg::QuantizedBatchMatmulOp>(op) ? true : false;
ImplicitLocOpBuilder builder(loc, rewriter);
assert(inputs.size() == 4);
Value lhs = inputs[0];
Value rhs = inputs[1];
Value lhsZp = inputs[2];
Value rhsZp = inputs[3];
auto lhsTy = dyn_cast<ShapedType>(lhs.getType());
unsigned lhsRank = lhsTy.getRank();
Value acc = op.getDpsInits()[0];
// Compute the matmul part.
Value matmul = batch ? linalg::BatchMatmulOp::create(
builder, ValueRange{lhs, rhs}, ValueRange{acc})
.getResult(0)
: linalg::MatmulOp::create(
builder, ValueRange{lhs, rhs}, ValueRange{acc})
.getResult(0);
bool lhsZpIsConstantZero = isConstantZero(lhsZp);
bool rhsZpIsConstantZero = isConstantZero(rhsZp);
if (lhsZpIsConstantZero && rhsZpIsConstantZero) {
// Easy case: both zero points are constant zeros, so the quantized_matmul
// was just a matmul all along.
rewriter.replaceOp(op, matmul);
return success();
// return matmul;
}
// Create the result. No need to zero-fill it as we will overwrite it.
ShapedType accType = cast<ShapedType>(acc.getType());
Value initResult = tensor::EmptyOp::create(
builder, tensor::getMixedSizes(builder, loc, acc),
accType.getElementType());
// Create the indexing maps for the generic.
MLIRContext *context = rewriter.getContext();
AffineExpr b, m, n;
batch ? bindDims(context, b, m, n) : bindDims(context, m, n);
AffineMap mapToNone = AffineMap::get(lhsRank, 0, context);
AffineMap mapToRowDim = batch ? AffineMap::get(lhsRank, 0, {b, m}, context)
: AffineMap::get(lhsRank, 0, m, context);
AffineMap mapToColumnDim = batch
? AffineMap::get(lhsRank, 0, {b, n}, context)
: AffineMap::get(lhsRank, 0, n, context);
AffineMap mapIdentity = batch
? AffineMap::get(lhsRank, 0, {b, m, n}, context)
: AffineMap::get(lhsRank, 0, {m, n}, context);
SmallVector<AffineMap> indexingMaps;
SmallVector<Value> ins;
auto addInput = [&](Value val, AffineMap map) -> int {
ins.push_back(val);
indexingMaps.push_back(map);
return ins.size() - 1;
};
int indexOfMatmulInput = addInput(matmul, mapIdentity);
int indexOfLhsSumsInput = 0;
int indexOfLhsZpInput = 0;
int indexOfRhsSumsInput = 0;
int indexOfRhsZpInput = 0;
int indexOfLhsZpTimesRhsZpTimesKSizeInput = 0;
Type accElTy = accType.getElementType();
if (!rhsZpIsConstantZero) {
SmallVector<bool> colRedIterator(lhsRank, false);
colRedIterator.back() = true;
Value lhsSums =
sumReduceDimensionSubset(builder, lhs, accElTy, colRedIterator);
indexOfLhsSumsInput = addInput(lhsSums, mapToRowDim);
indexOfRhsZpInput = addInput(rhsZp, mapToNone);
}
if (!lhsZpIsConstantZero) {
SmallVector<bool> rowRedIterator(lhsRank, false);
rowRedIterator[static_cast<int>(batch)] = true;
Value rhsSums =
sumReduceDimensionSubset(builder, rhs, accElTy, rowRedIterator);
indexOfRhsSumsInput = addInput(rhsSums, mapToColumnDim);
indexOfLhsZpInput = addInput(lhsZp, mapToNone);
}
if (!lhsZpIsConstantZero && !rhsZpIsConstantZero) {
Value lhsZpTimesRhsZp = arith::MulIOp::create(builder, lhsZp, rhsZp);
Value kSize = arith::IndexCastOp::create(
rewriter, loc, accElTy,
tensor::DimOp::create(builder, lhs, batch ? 2 : 1));
Value lhsZpTimesRhsZpTimesKSize =
arith::MulIOp::create(builder, lhsZpTimesRhsZp, kSize);
indexOfLhsZpTimesRhsZpTimesKSizeInput =
addInput(lhsZpTimesRhsZpTimesKSize, mapToNone);
}
// Add the indexing map for the initResult 'output' even though it's unused
indexingMaps.push_back(mapIdentity);
// Create the generic putting all the terms together.
SmallVector<utils::IteratorType> iterators(lhsRank,
utils::IteratorType::parallel);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, acc.getType(), ins, ValueRange{initResult}, indexingMaps, iterators,
[=](OpBuilder &b, Location loc, ValueRange args) {
Value matmulEl = args[indexOfMatmulInput];
Value lhsSumsEl = args[indexOfLhsSumsInput];
Value rhsSumsEl = args[indexOfRhsSumsInput];
Value lhsZp = args[indexOfLhsZpInput];
Value rhsZp = args[indexOfRhsZpInput];
Value lhsZpTimesRhsZpTimesKSize =
args[indexOfLhsZpTimesRhsZpTimesKSizeInput];
Value result = matmulEl;
// If the rhs zero-point is not a constant zero, we need to add it
// times the sums along rows of lhs.
if (!rhsZpIsConstantZero) {
Value lhsSumsElTimesRhsZp =
arith::MulIOp::create(b, loc, lhsSumsEl, rhsZp);
result = arith::SubIOp::create(b, loc, result, lhsSumsElTimesRhsZp);
}
// If the lhs zero-point is not a constant zero, we need to add it
// times the sums along columns of rhs.
if (!lhsZpIsConstantZero) {
Value rhsSumsElTimesLhsZp =
arith::MulIOp::create(b, loc, rhsSumsEl, lhsZp);
result = arith::SubIOp::create(b, loc, result, rhsSumsElTimesLhsZp);
}
// Add the final correction term, if neither zero-point is cst zero.
if (!lhsZpIsConstantZero && !rhsZpIsConstantZero) {
result = arith::AddIOp::create(b, loc, result,
lhsZpTimesRhsZpTimesKSize);
}
linalg::YieldOp::create(b, loc, result);
});
return success();
}
};
/// Pass that lowers quantized_matmul to matmul.
class LinalgQuantizedMatmulToMatmulPass final
: public impl::LinalgQuantizedMatmulToMatmulPassBase<
LinalgQuantizedMatmulToMatmulPass> {
public:
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
patterns.add<QuantizedMatmulToMatmul>(context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
signalPassFailure();
}
}
};
} // namespace
} // namespace mlir::iree_compiler::GlobalOptimization