| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" |
| #include "iree/compiler/GlobalOptimization/Passes.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Tensor/Utils/Utils.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "iree-global-opt-fuse-dequantization-matmul" |
| #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
| |
| namespace mlir::iree_compiler::GlobalOptimization { |
| |
| #define GEN_PASS_DEF_FUSEDEQUANTIZATIONMATMULPASS |
| #include "iree/compiler/GlobalOptimization/Passes.h.inc" |
| |
| namespace { |
| |
| //----------------------------------------------------------------------------// |
| // Utility |
| //----------------------------------------------------------------------------// |
| |
| // Checks if the passed op is a contraction with two reduction dimensions |
| // This function checks that the genericOp: |
| // 1. isaContractionOpInterface |
| // 2. Has 2 reduction dimensions |
| static LogicalResult |
| isContractionWithTwoReductions(linalg::GenericOp genericOp) { |
| unsigned numLoops = genericOp.getNumLoops(); |
| linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(genericOp.getOperation()); |
| if (numLoops == 0) { |
| return failure(); |
| } |
| if (!linalg::isaContractionOpInterface(linalgOp)) { |
| return failure(); |
| } |
| if (genericOp.getNumReductionLoops() != 2) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| // Checks if the passed op is a dequantization on grouped input |
| // This function checks that the genericOp: |
| // 1. Has a body like: |
| // arith.extui |
| // arith.uitofp |
| // arith.subf |
| // arith.mulf |
| // 2. Increases the bit width of the input |
| // 3. Has 3 parallel dims |
| // 4. Has 2 (weights, scales) or 3 (weights, scales, zero points) |
| // inputs and 1 output |
| static LogicalResult isGroupedDequantizationOp(linalg::GenericOp genericOp) { |
| // Check for 1 result, and 2 (input, scales) or 3 (input, scales, zero points) |
| // inputs |
| if (genericOp.getNumDpsInits() != 1) { |
| return failure(); |
| } |
| if (genericOp.getNumDpsInputs() != 2 && genericOp.getNumDpsInputs() != 3) { |
| return failure(); |
| } |
| // Check that the rank is at least 3 and all loops are parallel |
| unsigned numLoops = genericOp.getNumLoops(); |
| unsigned numParallelLoops = genericOp.getNumParallelLoops(); |
| if (numLoops < 3) { |
| return failure(); |
| } |
| if (numLoops != numParallelLoops) { |
| return failure(); |
| } |
| |
| // Work back from linalg.yield and check body of genericOp. |
| // The genericOp should yield the result of an arith.mulf, |
| // preceded by an arith.subf, arith.uitofp, and arith.extui |
| auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator()); |
| Value producerOutput; |
| Operation *producer; |
| |
| // Producer of linalg.yield op is arith.mulf |
| { |
| producerOutput = yieldOp->getOperand(0); |
| producer = producerOutput.getDefiningOp(); |
| if (!producer || producer->getNumOperands() == 0) { |
| return failure(); |
| } |
| if (!matchPattern(producer, m_Op<arith::MulFOp>())) { |
| return failure(); |
| } |
| } |
| |
| // Producer of arith.mulf op is arith.subf |
| { |
| producerOutput = producer->getOperand(0); |
| producer = producerOutput.getDefiningOp(); |
| if (!producer || producer->getNumOperands() == 0) { |
| return failure(); |
| } |
| if (!matchPattern(producer, m_Op<arith::SubFOp>())) { |
| return failure(); |
| } |
| } |
| |
| // Producer of arith.subf op is arith.uitofp |
| { |
| producerOutput = producer->getOperand(0); |
| producer = producerOutput.getDefiningOp(); |
| if (!producer || producer->getNumOperands() == 0) { |
| return failure(); |
| } |
| if (!matchPattern(producer, m_Op<arith::UIToFPOp>())) { |
| return failure(); |
| } |
| } |
| |
| // Producer of arith.uitofp op is arith.extui |
| { |
| producerOutput = producer->getOperand(0); |
| producer = producerOutput.getDefiningOp(); |
| if (!producer) { |
| return failure(); |
| } |
| if (!matchPattern(producer, m_Op<arith::ExtUIOp>())) { |
| return failure(); |
| } |
| } |
| |
| // Ensure that the dequantization increases the |
| // bitwidth from the input to the output |
| auto elementTypeOut = |
| cast<ShapedType>(genericOp.getOutputs()[0].getType()).getElementType(); |
| if (!elementTypeOut.isIntOrFloat()) { |
| return failure(); |
| } |
| unsigned bitWidthOut = elementTypeOut.getIntOrFloatBitWidth(); |
| auto elementTypeIn = |
| cast<ShapedType>(genericOp.getInputs()[0].getType()).getElementType(); |
| if (!elementTypeIn.isIntOrFloat()) { |
| return failure(); |
| } |
| unsigned bitWidthIn = elementTypeIn.getIntOrFloatBitWidth(); |
| if (bitWidthIn >= bitWidthOut) { |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| static SmallVector<utils::IteratorType> |
| getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) { |
| SmallVector<utils::IteratorType> res(nLoops - nReduction, |
| utils::IteratorType::parallel); |
| res.append(nReduction, utils::IteratorType::reduction); |
| return res; |
| } |
| |
// Stateful helper that drives the dequantization + matmul reassociation
// rewrite. Construct it with the matched pair of ops, check `precondition()`,
// then call the `generate*` methods in order to emit the new op sequence.
struct QuantizedMatmulRewriter {
  QuantizedMatmulRewriter(RewriterBase &rewriter, linalg::GenericOp dequant,
                          linalg::GenericOp matmul, int quantizedBitWidth);
  // Identifies the operands of `dequant` and `matmul` (see `ins` below);
  // returns std::nullopt if they cannot all be matched.
  std::optional<SmallVector<OpOperand *>> getDequantMatmulInputs();
  // Returns the indexing maps and iterator types for a reduction along the
  // innermost group dimension of `inputOperand`.
  std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>>
  getGroupReductionMapsAndIterators(OpOperand *inputOperand);
  // Builds a zero-filled init tensor for a group reduction of `input`.
  Value getGroupReductionInit(Value input);
  // Emits the generic computing the per-group absolute max of the input.
  Value generateGroupMaxGeneric();
  // Emits the generic computing per-group scales from the group max.
  Value generateScalesGeneric(Value groupMax);
  // Emits the generic computing per-group sums of the input.
  Value generateGroupSumsGeneric();
  // Emits the 4 generics that perform symmetric dynamic quantization of the
  // unquantized input; returns {groupMax, scales, groupSums, quantized}.
  SmallVector<Value> generateQuantizationGenerics();
  // Emits the integer matmul over the two quantized operands.
  linalg::GenericOp
  generateQuantizedMatmulGeneric(SmallVector<Value> quantizeResults);
  // Emits the final generic doing the reassociated dequantization math and
  // the remaining reduction over the group-count dimension.
  linalg::GenericOp
  generateReassociatedDequantizationGeneric(SmallVector<Value> quantizeResults,
                                            Value quantizedIntegerMatmul);
  // Verifies the matched ops have the expected form; call before rewriting.
  LogicalResult precondition();

private:
  // rewriter used to create all new ops
  RewriterBase &rewriter;
  // linalg::GenericOp that performs the dequantization
  linalg::GenericOp dequant;
  // linalg::GenericOp that performs the matmul
  linalg::GenericOp matmul;
  // Destination bit width for dynamic quantization of the unquantized input
  // (`ins[1]`)
  int quantizedBitWidth;
  // Type for accumulation of integer matmul
  IntegerType accType;
  // Type for multiplication in integer matmul (should probably be same as
  // `accType`)
  IntegerType mulType;
  // Type for dynamic quantization of unquantized input
  IntegerType quantType;
  // inputs to the `dequant` and `matmul` ops
  // ins[0] = quantized matrix input to `dequant`
  // ins[1] = unquantized input to `matmul`
  // ins[2] = scales input to `dequant`
  // ins[3] = zero points input to `dequant`
  // ins[4] = dequantized input to `matmul`
  SmallVector<OpOperand *> ins;
  // Location for rewrite
  Location loc;
};
| |
| // Takes as input the dequantization `linalg.generic` op and the matmul |
| // `linalg.generic` op, and returns the scales, zero points, quantized |
| // input matrix, unquantizaed input matrix, and dequantized result |
| // matrix. |
| // TODO(#) Have stricter matching on inputs. There may be cases where |
| // the current matching fails |
| std::optional<SmallVector<OpOperand *>> |
| QuantizedMatmulRewriter::getDequantMatmulInputs() { |
| assert(!failed(isContractionWithTwoReductions(matmul)) && |
| "expected `matmul` to be a contraction with two reduction dimensions"); |
| assert(!failed(isGroupedDequantizationOp(dequant)) && |
| "expected `dequant` to be a grouped dequantization"); |
| OpOperand *scales, *zps, *quantMat, *unquantMat, *dequantMat; |
| auto maps = dequant.getIndexingMapsArray(); |
| for (auto [idx, map] : enumerate(ArrayRef<AffineMap>(maps).drop_back())) { |
| if (map.isIdentity()) { |
| quantMat = dequant.getDpsInputOperand(idx); |
| } else if (map.isProjectedPermutation(true)) { |
| for (Operation &bodyOp : dequant.getBlock()->getOperations()) { |
| if (isa<arith::MulFOp>(bodyOp)) { |
| if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) { |
| scales = dequant.getDpsInputOperand(idx); |
| break; |
| } |
| } else if (isa<arith::SubFOp>(bodyOp)) { |
| if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) { |
| zps = dequant.getDpsInputOperand(idx); |
| break; |
| } |
| } |
| } |
| } |
| } |
| Value dequantOut = dequant.getResult(0); |
| if (matmul.getDpsInputOperand(0)->get() == dequantOut) { |
| unquantMat = matmul.getDpsInputOperand(1); |
| dequantMat = matmul.getDpsInputOperand(0); |
| } else { |
| unquantMat = matmul.getDpsInputOperand(0); |
| dequantMat = matmul.getDpsInputOperand(1); |
| } |
| if (scales && zps && quantMat && unquantMat) { |
| return SmallVector<OpOperand *>( |
| {quantMat, unquantMat, scales, zps, dequantMat}); |
| } |
| return std::nullopt; |
| } |
| |
| QuantizedMatmulRewriter::QuantizedMatmulRewriter(RewriterBase &rewriter, |
| linalg::GenericOp dequant, |
| linalg::GenericOp matmul, |
| int quantizedBitWidth) |
| : rewriter(rewriter), dequant(dequant), matmul(matmul), |
| quantizedBitWidth(quantizedBitWidth), loc(dequant.getLoc()) { |
| accType = rewriter.getI32Type(); |
| mulType = rewriter.getI32Type(); |
| quantType = rewriter.getIntegerType(quantizedBitWidth); |
| std::optional<SmallVector<OpOperand *>> inputs = getDequantMatmulInputs(); |
| if (inputs) { |
| ins = *inputs; |
| } |
| } |
| |
// Verifies that the matched `dequant` + `matmul` pair has the exact form the
// reassociation rewrite expects; emits a match-failure diagnostic otherwise.
LogicalResult QuantizedMatmulRewriter::precondition() {
  // `ins` is filled by getDequantMatmulInputs() in the constructor; any size
  // other than 5 means operand identification failed.
  if (ins.size() != 5) {
    return rewriter.notifyMatchFailure(
        matmul,
        "expected `ins` to have 5 inputs for quantized matmul reassociation");
  }
  OpOperand *unquantizedInputOperand = ins[1];
  Value unquantizedInput = ins[1]->get();
  RankedTensorType unquantizedInputType =
      cast<RankedTensorType>(unquantizedInput.getType());
  SmallVector<int64_t> unquantizedInputShape(unquantizedInputType.getShape());
  AffineMap indexingMap =
      matmul.getMatchingIndexingMap(unquantizedInputOperand);
  SmallVector<utils::IteratorType> matmulIteratorTypes =
      matmul.getIteratorTypesArray();
  if (unquantizedInputShape.size() < 2) {
    return rewriter.notifyMatchFailure(
        matmul, "input expected to have a rank of at least 2");
  }
  // The two innermost loops of the matmul (group count and group size) must
  // both be reductions for the reassociation to apply.
  if (matmulIteratorTypes[indexingMap.getNumDims() - 1] ==
          utils::IteratorType::parallel ||
      matmulIteratorTypes[indexingMap.getNumDims() - 2] ==
          utils::IteratorType::parallel) {
    return rewriter.notifyMatchFailure(
        matmul, "inner 2 dimensions of matmul expected to be reduction");
  }
  // The unquantized input's two innermost result exprs must map directly to
  // those innermost reduction dims (no permutation, no broadcast).
  auto affineExprs = indexingMap.getResults();
  auto innerDim0 = dyn_cast<AffineDimExpr>(affineExprs.back());
  auto innerDim1 = dyn_cast<AffineDimExpr>(affineExprs[affineExprs.size() - 2]);
  if (!innerDim0 || !innerDim1 ||
      innerDim0.getPosition() != indexingMap.getNumDims() - 1 ||
      innerDim1.getPosition() != indexingMap.getNumDims() - 2) {
    return rewriter.notifyMatchFailure(
        matmul, "inner shape of input expected to be reduced in matmul");
  }
  // All floating-point operands must actually be floats: the rewrite emits
  // float arithmetic (divf/mulf/subf/addf) on them.
  Value scales = ins[2]->get();
  Value zps = ins[3]->get();
  if (!isa<FloatType>(unquantizedInputType.getElementType()) ||
      !isa<FloatType>(getElementTypeOrSelf(scales)) ||
      !isa<FloatType>(getElementTypeOrSelf(zps))) {
    return rewriter.notifyMatchFailure(matmul, "expected float type");
  }
  OpOperand *matmulDequantizedOperand = ins[4];
  auto matmulDequantizedInputExprs =
      matmul.getMatchingIndexingMap(matmulDequantizedOperand).getResults();
  auto scalesType = dyn_cast<RankedTensorType>(scales.getType());
  auto zpsType = dyn_cast<RankedTensorType>(zps.getType());
  if (!scalesType || !zpsType) {
    return rewriter.notifyMatchFailure(
        dequant, "expected scales and zero points to have RankedTensorType");
  }
  // Scales and zero points may either drop the group-size dim entirely or
  // keep it as a trailing unit dim; any other rank/shape is unsupported.
  if (scalesType.getShape().size() != matmulDequantizedInputExprs.size() - 1) {
    if (scalesType.getShape().size() != matmulDequantizedInputExprs.size() ||
        scalesType.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(dequant, "unexpected rank for scales");
    }
  }
  if (zpsType.getShape().size() != matmulDequantizedInputExprs.size() - 1) {
    if (zpsType.getShape().size() != matmulDequantizedInputExprs.size() ||
        zpsType.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(dequant,
                                         "unexpected rank for zero points");
    }
  }
  return success();
}
| |
| std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>> |
| QuantizedMatmulRewriter::getGroupReductionMapsAndIterators( |
| OpOperand *inputOperand) { |
| Value input = inputOperand->get(); |
| RankedTensorType inputType = cast<RankedTensorType>(input.getType()); |
| SmallVector<int64_t> inputShape(inputType.getShape()); |
| AffineMap indexingMap = matmul.getMatchingIndexingMap(inputOperand); |
| SmallVector<utils::IteratorType> matmulIteratorTypes = |
| matmul.getIteratorTypesArray(); |
| assert(inputShape.size() >= 2 && |
| "input expected to have a rank of at least 2"); |
| assert((matmulIteratorTypes[indexingMap.getNumDims() - 1] == |
| utils::IteratorType::reduction && |
| matmulIteratorTypes[indexingMap.getNumDims() - 2] == |
| utils::IteratorType::reduction) && |
| "inner 2 dimensions of matmul expected to be reduction"); |
| auto affineExprs = indexingMap.getResults(); |
| auto innerDim0 = dyn_cast<AffineDimExpr>(affineExprs.back()); |
| auto innerDim1 = dyn_cast<AffineDimExpr>(affineExprs[affineExprs.size() - 2]); |
| assert(innerDim0 && innerDim1 && |
| innerDim0.getPosition() == indexingMap.getNumDims() - 1 && |
| innerDim1.getPosition() == indexingMap.getNumDims() - 2 && |
| "inner shape of input expected to be reduced in matmul"); |
| (void)innerDim0; |
| (void)innerDim1; |
| |
| SmallVector<utils::IteratorType> iterators(inputShape.size(), |
| utils::IteratorType::parallel); |
| iterators.back() = utils::IteratorType::reduction; |
| AffineMap inputMap = rewriter.getMultiDimIdentityMap(inputShape.size()); |
| AffineMap outputMap = rewriter.getMultiDimIdentityMap(inputShape.size()) |
| .getMajorSubMap(inputShape.size() - 1); |
| SmallVector<AffineMap> maps{inputMap, outputMap}; |
| return std::make_pair(maps, iterators); |
| } |
| |
| // Helper to create an init Value for reductions along the group dimension. |
| Value QuantizedMatmulRewriter::getGroupReductionInit(Value input) { |
| RankedTensorType inputType = cast<RankedTensorType>(input.getType()); |
| assert(isa<FloatType>(inputType.getElementType()) && "expected float type"); |
| Value zero = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getFloatAttr(inputType.getElementType(), 0.0)); |
| SmallVector<int64_t> inputShape(inputType.getShape()); |
| SmallVector<int64_t> outputShape(llvm::drop_end(inputShape)); |
| RankedTensorType outputType = |
| RankedTensorType::get(outputShape, inputType.getElementType()); |
| Value emptyOut = tensor::EmptyOp::create(rewriter, loc, outputType.getShape(), |
| outputType.getElementType()); |
| return linalg::FillOp::create(rewriter, loc, zero, emptyOut).result(); |
| } |
| |
| // Creates a generic that computes the absolute max along the group |
| // dimension, and returns the result. |
| Value QuantizedMatmulRewriter::generateGroupMaxGeneric() { |
| OpOperand *inputOperand = ins[1]; |
| std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>> |
| mapsAndIterators = getGroupReductionMapsAndIterators(inputOperand); |
| auto maps = mapsAndIterators.first; |
| auto iterators = mapsAndIterators.second; |
| Value input = inputOperand->get(); |
| Value output = getGroupReductionInit(input); |
| auto groupMaxOp = linalg::GenericOp::create( |
| rewriter, loc, output.getType(), input, output, maps, iterators, |
| [&](OpBuilder &b, Location nestedLoc, ValueRange args) { |
| Value abs = math::AbsFOp::create(b, nestedLoc, args[0]); |
| Value max = arith::MaximumFOp::create(b, nestedLoc, abs, args[1]); |
| linalg::YieldOp::create(b, nestedLoc, max); |
| }); |
| LLVM_DEBUG(DBGS() << "groupMaxOp: " << groupMaxOp << "\n"); |
| return groupMaxOp.getResult(0); |
| } |
| |
| // Creates a generic that computes the scales for each group, and |
| // returns the result. |
| Value QuantizedMatmulRewriter::generateScalesGeneric(Value groupMax) { |
| auto groupMaxType = cast<RankedTensorType>(groupMax.getType()); |
| assert(isa<FloatType>(groupMaxType.getElementType()) && |
| "expected float type"); |
| Value cst = arith::ConstantOp::create( |
| rewriter, loc, |
| rewriter.getFloatAttr(groupMaxType.getElementType(), |
| (1 << (quantizedBitWidth - 1)) - 1)); |
| Value output = tensor::EmptyOp::create(rewriter, loc, groupMaxType.getShape(), |
| groupMaxType.getElementType()); |
| SmallVector<AffineMap> maps( |
| 2, rewriter.getMultiDimIdentityMap(groupMaxType.getShape().size())); |
| auto scalesOp = linalg::GenericOp::create( |
| rewriter, loc, output.getType(), groupMax, output, maps, |
| getParallelAndReductionIterators(groupMaxType.getRank(), 0), |
| [&](OpBuilder &b, Location nestedLoc, ValueRange args) { |
| Value scale = arith::DivFOp::create(b, nestedLoc, args[0], cst); |
| linalg::YieldOp::create(b, nestedLoc, scale); |
| }); |
| LLVM_DEBUG(DBGS() << "scalesOp: " << scalesOp << "\n"); |
| return scalesOp.getResult(0); |
| } |
| |
| // Creates a generic that computes the sums for each group, and |
| // returns the result. |
| Value QuantizedMatmulRewriter::generateGroupSumsGeneric() { |
| OpOperand *inputOperand = ins[1]; |
| std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>> |
| mapsAndIterators = getGroupReductionMapsAndIterators(inputOperand); |
| auto maps = mapsAndIterators.first; |
| auto iterators = mapsAndIterators.second; |
| Value input = inputOperand->get(); |
| Value output = getGroupReductionInit(input); |
| auto groupSumsOp = linalg::GenericOp::create( |
| rewriter, loc, output.getType(), input, output, maps, iterators, |
| [&](OpBuilder &b, Location nestedLoc, ValueRange args) { |
| Value sum = arith::AddFOp::create(b, nestedLoc, args[0], args[1]); |
| linalg::YieldOp::create(b, nestedLoc, sum); |
| }); |
| LLVM_DEBUG(DBGS() << "groupSumsOp: " << groupSumsOp << "\n"); |
| return groupSumsOp.getResult(0); |
| } |
| |
| // Creates 4 linalg::GenericOps that collectively perform a symmetric |
| // quantization of the unquantized input, and returns the results. |
| SmallVector<Value> QuantizedMatmulRewriter::generateQuantizationGenerics() { |
| assert(ins.size() == 5 && "expected `ins` to have 5 inputs"); |
| OpOperand *unquantizedInputOperand = ins[1]; |
| Value unquantizedInput = unquantizedInputOperand->get(); |
| |
| auto unquantizedType = cast<RankedTensorType>(unquantizedInput.getType()); |
| SmallVector<int64_t> unquantizedShape(unquantizedType.getShape()); |
| |
| IntegerType quantizedElementType = rewriter.getIntegerType(quantizedBitWidth); |
| Value groupMax = generateGroupMaxGeneric(); |
| Value scales = generateScalesGeneric(groupMax); |
| Value groupSums = generateGroupSumsGeneric(); |
| |
| Value output = tensor::EmptyOp::create(rewriter, loc, unquantizedShape, |
| quantizedElementType); |
| AffineMap inputMap = rewriter.getMultiDimIdentityMap(unquantizedShape.size()); |
| AffineMap scalesMap = rewriter.getMultiDimIdentityMap(unquantizedShape.size()) |
| .getMajorSubMap(unquantizedShape.size() - 1); |
| AffineMap outputMap = |
| rewriter.getMultiDimIdentityMap(unquantizedShape.size()); |
| SmallVector<AffineMap> maps{inputMap, scalesMap, outputMap}; |
| auto quantizeOp = linalg::GenericOp::create( |
| rewriter, loc, output.getType(), ValueRange{unquantizedInput, scales}, |
| output, maps, |
| getParallelAndReductionIterators(unquantizedShape.size(), 0), |
| [&](OpBuilder &b, Location nestedLoc, ValueRange args) { |
| Value scaled = arith::DivFOp::create(b, nestedLoc, args[0], args[1]); |
| Value quant = |
| arith::FPToSIOp::create(b, nestedLoc, quantizedElementType, scaled); |
| linalg::YieldOp::create(b, nestedLoc, quant); |
| }); |
| LLVM_DEBUG(DBGS() << "quantizeOp: " << quantizeOp << "\n"); |
| Value newQuantizedInput = quantizeOp.getResult(0); |
| SmallVector<Value> results{groupMax, scales, groupSums, newQuantizedInput}; |
| return results; |
| } |
| |
| // Creates a generic that computes the main matmul computation in integer |
| // arithmetic, while values are still quantized, and returns the result. |
| linalg::GenericOp QuantizedMatmulRewriter::generateQuantizedMatmulGeneric( |
| SmallVector<Value> quantizeResults) { |
| Value newQuantizedInput = quantizeResults.back(); |
| assert(ins.size() == 5 && "expected `ins` to have 5 inputs"); |
| Value quantizedInput = ins[0]->get(); |
| OpOperand *matmulUnquantizedOperand = ins[1]; |
| OpOperand *matmulDequantizedOperand = ins[4]; |
| OpOperand *matmulOutputOperand = matmul.getDpsInitOperand(0); |
| |
| SmallVector<utils::IteratorType> iterators = |
| getParallelAndReductionIterators(matmul.getNumLoops(), 1); |
| SmallVector<AffineMap> maps; |
| AffineMap matmulUnquantizedMap = |
| matmul.getMatchingIndexingMap(matmulUnquantizedOperand); |
| AffineMap matmulDequantizedMap = |
| matmul.getMatchingIndexingMap(matmulDequantizedOperand); |
| AffineMap matmulOutputMap = |
| matmul.getMatchingIndexingMap(matmulOutputOperand); |
| maps.push_back(matmulUnquantizedMap); |
| maps.push_back(matmulDequantizedMap); |
| SmallVector<AffineExpr> outputExprs(matmulOutputMap.getResults()); |
| outputExprs.push_back(rewriter.getAffineDimExpr(matmul.getNumLoops() - 2)); |
| maps.push_back(AffineMap::get(iterators.size(), 0, outputExprs, |
| outputExprs.front().getContext())); |
| |
| SmallVector<int64_t> newQuantizedInputShape( |
| cast<RankedTensorType>(newQuantizedInput.getType()).getShape()); |
| assert(newQuantizedInputShape.size() >= 2 && |
| "expected new quantized input to have a rank of at least 2"); |
| SmallVector<int64_t> outputShape( |
| cast<RankedTensorType>(matmulOutputOperand->get().getType()).getShape()); |
| outputShape.push_back( |
| newQuantizedInputShape[newQuantizedInputShape.size() - 2]); |
| Value zero = arith::ConstantOp::create(rewriter, loc, |
| rewriter.getIntegerAttr(accType, 0.0)); |
| Value emptyOut = tensor::EmptyOp::create(rewriter, loc, outputShape, accType); |
| Value output = linalg::FillOp::create(rewriter, loc, zero, emptyOut).result(); |
| auto integerMatmulOp = linalg::GenericOp::create( |
| rewriter, loc, output.getType(), |
| ValueRange{newQuantizedInput, quantizedInput}, output, maps, iterators, |
| [&](OpBuilder &b, Location nestedLoc, ValueRange args) { |
| Value mul; |
| if (quantType == mulType) { |
| Value ext1 = arith::ExtUIOp::create(b, nestedLoc, mulType, args[1]); |
| mul = arith::MulIOp::create(b, nestedLoc, args[0], ext1); |
| } else { |
| Value ext0 = arith::ExtSIOp::create(b, nestedLoc, mulType, args[0]); |
| Value ext1 = arith::ExtUIOp::create(b, nestedLoc, mulType, args[1]); |
| mul = arith::MulIOp::create(b, nestedLoc, ext0, ext1); |
| } |
| Value sum; |
| if (mulType == accType) { |
| sum = arith::AddIOp::create(b, nestedLoc, mul, args[2]); |
| } else { |
| Value extMul = arith::ExtSIOp::create(b, nestedLoc, accType, mul); |
| sum = arith::AddIOp::create(b, nestedLoc, extMul, args[2]); |
| } |
| linalg::YieldOp::create(b, nestedLoc, sum); |
| }); |
| LLVM_DEBUG(DBGS() << "integerMatmulOp: " << integerMatmulOp << "\n"); |
| return integerMatmulOp; |
| } |
| |
// Creates a generic that does the reassociated dequantization math, and does
// the final reduction along the number of groups dimension.
linalg::GenericOp
QuantizedMatmulRewriter::generateReassociatedDequantizationGeneric(
    SmallVector<Value> quantizeResults, Value quantizedIntegerMatmul) {
  assert(quantizeResults.size() == 4 &&
         "expected 4 ops from quantization step");
  // `scales`/`zps` are the static quantization parameters from the original
  // dequantization op; `newScales`/`groupSums` come from the dynamic
  // quantization of the unquantized input.
  Value scales = ins[2]->get();
  Value zps = ins[3]->get();
  Value newScales = quantizeResults[1];
  Value groupSums = quantizeResults[2];

  // One fewer loop than the matmul: the group-size dim was already reduced
  // by the integer matmul; only the group-count reduction remains.
  SmallVector<utils::IteratorType> iterators =
      getParallelAndReductionIterators(matmul.getNumLoops() - 1, 1);
  SmallVector<AffineMap> maps;
  AffineMap quantizedIntegerMatmulMap =
      rewriter.getMultiDimIdentityMap(iterators.size());
  maps.push_back(quantizedIntegerMatmulMap);

  // Maps for `newScales` and `groupSums`: the unquantized input's indexing
  // map with the innermost (group-size) expr dropped.
  {
    OpOperand *matmulUnquantizedOperand = ins[1];
    auto exprs =
        matmul.getMatchingIndexingMap(matmulUnquantizedOperand).getResults();
    exprs = exprs.slice(0, exprs.size() - 1);
    auto newScalesAndSumsMap =
        AffineMap::get(iterators.size(), 0, exprs, exprs.front().getContext());
    maps.push_back(newScalesAndSumsMap);
    maps.push_back(newScalesAndSumsMap);
  }

  // Maps for `scales` and `zps`: the dequantized operand's indexing map with
  // the innermost expr dropped; if scales/zps keep a trailing unit dim, index
  // it with constant 0.
  {
    OpOperand *matmulDequantizedOperand = ins[4];
    auto exprsRef =
        matmul.getMatchingIndexingMap(matmulDequantizedOperand).getResults();
    SmallVector<AffineExpr> exprs(exprsRef.slice(0, exprsRef.size() - 1));
    RankedTensorType scalesType = cast<RankedTensorType>(scales.getType());
    RankedTensorType zpsType = cast<RankedTensorType>(zps.getType());
    if (exprs.size() < scalesType.getShape().size() &&
        scalesType.getShape().back() == 1 && zpsType.getShape().back() == 1) {
      exprs.push_back(rewriter.getAffineConstantExpr(0));
    }
    assert(exprs.size() == scalesType.getShape().size() &&
           "unexpected rank for scales");
    assert(exprs.size() == zpsType.getShape().size() &&
           "unexpected rank for zero points");
    auto scalesAndZpsMap =
        AffineMap::get(iterators.size(), 0, exprs, exprs.front().getContext());
    maps.push_back(scalesAndZpsMap);
    maps.push_back(scalesAndZpsMap);
  }

  // The output reuses the original matmul's result indexing.
  OpOperand *matmulOutputOperand = matmul.getDpsInitOperand(0);
  SmallVector<AffineExpr> outputExprs(
      matmul.getMatchingIndexingMap(matmulOutputOperand).getResults());
  auto outputMap = AffineMap::get(iterators.size(), 0, outputExprs,
                                  outputExprs.front().getContext());
  maps.push_back(outputMap);

  Type floatType = getElementTypeOrSelf(scales);
  Value output = matmulOutputOperand->get();
  auto reassociatedDequantizationOp = linalg::GenericOp::create(
      rewriter, loc, output.getType(),
      ValueRange{quantizedIntegerMatmul, newScales, groupSums, scales, zps},
      output, maps, iterators,
      // args[0] = integer matmul result, args[1] = new scales,
      // args[2] = group sums, args[3] = scales, args[4] = zero points,
      // args[5] = output accumulator.
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dq;
        dq = arith::SIToFPOp::create(b, loc, floatType, args[0]);
        // Rescale: intResult * newScale * scale.
        Value scaledRes0 = arith::MulFOp::create(b, loc, dq, args[1]);
        Value scaledRes1 = arith::MulFOp::create(b, loc, scaledRes0, args[3]);
        // Zero-point correction: zp * scale * groupSum.
        Value scaledZp0 = arith::MulFOp::create(b, loc, args[4], args[3]);
        Value scaledZp1 = arith::MulFOp::create(b, loc, scaledZp0, args[2]);
        Value groupRes = arith::SubFOp::create(b, loc, scaledRes1, scaledZp1);
        // Accumulate across groups into the output.
        Value sum = arith::AddFOp::create(b, loc, groupRes, args[5]);
        linalg::YieldOp::create(b, loc, sum);
      });
  LLVM_DEBUG(DBGS() << "reassociatedDequantizationOp: "
                    << reassociatedDequantizationOp << "\n");
  return reassociatedDequantizationOp;
}
| |
| // This function does the bulk of the rewrite for the dequantization + matmul. |
| // |
| // Starting with 2 `linalg.generic` ops (dequantization->matmul) |
| // %arg0 = quantized input |
| // %arg1 = scales |
| // %arg2 = zero points |
| // %arg3 = unquantized input |
| // ```mlir |
| // %0 = linalg.generic ins(%arg0, %arg1, %arg2 : tensor<8x4x2xi4>, |
| // tensor<8x4x1xf32>, tensor<8x4x1xf32>) outs(%1 : tensor<8x4x2xf32>) { |
| // ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32): |
| // %9 = arith.extui %in : i4 to i32 |
| // %10 = arith.uitofp %9 : i32 to f32 |
| // %11 = arith.subf %10, %in_1 : f32 |
| // %12 = arith.mulf %11, %in_0 : f32 |
| // linalg.yield %12 : f32 |
| // } -> tensor<8x4x2xf32> |
| // %2 = linalg.generic ins(%arg3, %0 : tensor<4x2xf32>, tensor<8x4x2xf32>) |
| // outs(%3 : tensor<8xf32>) { ^bb0(%in: f32, %in_0: f32, %out: f32): |
| // %9 = arith.mulf %in, %in_0 : f32 |
| // %10 = arith.addf %9, %out : f32 |
| // linalg.yield %10 : f32 |
| // } -> tensor<8xf32> |
| // ``` |
| // |
| // This function rewrites the above ops as the following new sequence of 6 ops |
| // that does the following: |
| // |
| // a) Dynamically quantize the unquantized input: |
| // 1. Compute the absolute max of the unquantized input (%arg3) within each |
| // group. |
| // 2. Compute scales for %arg3 by dividing the absolute max by (1 << |
| // newBitWidth) - 1), |
| // where newBitWidth is the bitwidth of the new quantized type, |
| // currently set to `i16`. |
| // 3. Compute the sum along groups of the unquantized input. This is not |
| // necessary for |
| // the quantization step, but it is needed to efficiently perform a |
| // reassociated quantized matmul in steps 5-6. |
| // 4. Quantize the unquantized input (%arg3) by dividing elements in each |
| // group by |
| // the corresponding scale. This quantization is symmetric with no zero |
| // point. |
| // b) Perform the reassociated quantized matmul, keeping the bulk of the |
| // computation in |
| // integer arithmetic: |
| // 5. The first op performs a matmul-like operation that reduces the |
| // innermost group |
| // dimension. Note the indexing maps in the following example: |
| // ```mlir |
| // %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> |
| // (d1, d2)>, |
| // affine_map<(d0, d1, d2) -> |
| // (d0, d1, d2)>, affine_map<(d0, |
| // d1, d2) -> (d0, d1)>], |
| // iterator_types = ["parallel", "parallel", |
| // "reduction"]} ins(%17, %0 : tensor<4x2xi16>, |
| // tensor<8x4x2xi4>) outs(%19 : tensor<8x4xi32>) { |
| // ^bb0(%in: i16, %in_4: i4, %out: i32): |
| // %24 = arith.extsi %in : i16 to i32 |
| // %25 = arith.extui %in_4 : i4 to i32 |
| // %26 = arith.muli %24, %25 : i32 |
| // %27 = arith.addi %26, %out : i32 |
| // linalg.yield %27 : i32 |
| // } -> tensor<8x4xi32> |
| // ``` |
| // This op also extends the inputs to the accumulation type, i32 in this |
| // case, to target specific x86 instructions. We perform the matrix |
| // multiplication before the dequantization arithmetic, which has been |
| // reassociated into op 6. |
| // 6. The final op performs the remaining reduction across groups and does |
| // the |
| // dequantization arithmetic: |
| // ```mlir |
| // %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, |
| // d1)>, |
| // affine_map<(d0, d1) -> (d1)>, |
| // affine_map<(d0, d1) -> (d1)>, |
| // affine_map<(d0, d1) -> (d0, |
| // d1)>, affine_map<(d0, d1) -> |
| // (d0, d1)>, affine_map<(d0, d1) |
| // -> (d0)>], |
| // iterator_types = ["parallel", "reduction"]} |
| // ins(tensor<8x4xi32>, tensor<4xf32>, |
| // tensor<4xf32>, tensor<8x4xf32>, |
| // tensor<8x4xf32>) outs(tensor<8xf32>) { |
| // ^bb0(%in: i32, %in_4: f32, %in_5: f32, %in_6: f32, %in_7: f32, |
| // %out: f32): |
| // %24 = arith.sitofp %in : i32 to f32 |
| // %25 = arith.mulf %24, %in_4 : f32 |
| // %26 = arith.mulf %25, %in_6 : f32 |
| // %27 = arith.mulf %in_7, %in_6 : f32 |
| // %28 = arith.mulf %27, %in_5 : f32 |
| // %29 = arith.subf %26, %28 : f32 |
| // %30 = arith.addf %29, %out : f32 |
| // linalg.yield %30 : f32 |
| // } -> tensor<8xf32> |
| // ``` |
| // |
| // ** Note that this rewrite introduces precision loss in the matmul, and is a |
| // tradeoff between precision and performance. This rewrite should most |
| // likely be opt-in only. ** |
| static LogicalResult reassociateDequantMatmul(RewriterBase &rewriter, |
| linalg::GenericOp dequant, |
| linalg::GenericOp matmul, |
| int quantizeBitWidth) { |
| QuantizedMatmulRewriter qmr(rewriter, dequant, matmul, quantizeBitWidth); |
| if (failed(qmr.precondition())) { |
| return success(); |
| } |
| SmallVector<Value> quantizeResults = qmr.generateQuantizationGenerics(); |
| linalg::GenericOp quantizedIntegerMatmul = |
| qmr.generateQuantizedMatmulGeneric(quantizeResults); |
| linalg::GenericOp reassociatedDequantization = |
| qmr.generateReassociatedDequantizationGeneric( |
| quantizeResults, quantizedIntegerMatmul.getResult(0)); |
| |
| rewriter.replaceOp(matmul, reassociatedDequantization.getResult(0)); |
| |
| return success(); |
| } |
| |
// Pass that finds grouped dequantization ops feeding matmuls and rewrites
// each pair into a reassociated, mostly-integer computation (see the comment
// above reassociateDequantMatmul).
struct FuseDequantizationMatmulPass
    : public impl::FuseDequantizationMatmulPassBase<
          FuseDequantizationMatmulPass> {
  // The rewrite creates linalg/math/flow ops that may not already be loaded.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect, IREE::Flow::FlowDialect,
                    math::MathDialect>();
  }
  void runOnOperation() override;
};
| |
| } // namespace |
| |
| void FuseDequantizationMatmulPass::runOnOperation() { |
| MLIRContext *context = &getContext(); |
| mlir::FunctionOpInterface funcOp = getOperation(); |
| |
| int quantizeBitWidth = 16; |
| SmallVector<std::pair<linalg::GenericOp, linalg::GenericOp>> candidates; |
| for (auto genericOp : funcOp.getFunctionBody().getOps<linalg::GenericOp>()) { |
| if (failed(isContractionWithTwoReductions(genericOp))) { |
| continue; |
| } |
| |
| OpOperand *lhs = genericOp.getDpsInputOperand(0); |
| OpOperand *rhs = genericOp.getDpsInputOperand(1); |
| auto lhsOp = lhs->get().getDefiningOp<linalg::GenericOp>(); |
| auto rhsOp = rhs->get().getDefiningOp<linalg::GenericOp>(); |
| if (!cast<ShapedType>(genericOp.getInputs()[0].getType()) |
| .hasStaticShape() || |
| !cast<ShapedType>(genericOp.getInputs()[1].getType()) |
| .hasStaticShape() || |
| !cast<ShapedType>(genericOp.getResults()[0].getType()) |
| .hasStaticShape()) { |
| // Codegen can't handle the dynamic case yet. |
| continue; |
| } |
| if (lhsOp) { |
| if (!failed(isGroupedDequantizationOp(lhsOp))) { |
| candidates.push_back(std::make_pair(lhsOp, genericOp)); |
| continue; |
| } |
| } |
| if (rhsOp) { |
| if (!failed(isGroupedDequantizationOp(rhsOp))) { |
| candidates.push_back(std::make_pair(rhsOp, genericOp)); |
| } |
| } |
| } |
| IRRewriter rewriter(context); |
| for (auto candidate : candidates) { |
| rewriter.setInsertionPointAfter(candidate.second); |
| if (failed(reassociateDequantMatmul(rewriter, candidate.first, |
| candidate.second, quantizeBitWidth))) { |
| return signalPassFailure(); |
| } |
| } |
| } |
| } // namespace mlir::iree_compiler::GlobalOptimization |