Make fp8 attention use a statically quantized softmax (#17949)
Added a flag that selects a static softmax range for fp8 attention. This
allows the value to be tuned without plumbing it through the attention
operation. Eventually this should be a parameter passed to
`linalg_ext.attention`.
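As a concrete illustration, a minimal scalar sketch (plain standalone C++;
`softmaxMax` models the new flag and `largestDbl` the largest finite value of
the fp8 element type, e.g. 448 for f8E4M3FN) of the two scale factors the
decomposition now materializes as constants:

#include <cstdio>

int main() {
  double softmaxMax = 1.0;   // --iree-linalgext-attention-softmax-max (default 1.0)
  double largestDbl = 448.0; // assumed fp8 max; the real code reads it from
                             // the element type's float semantics
  // p and norm are multiplied by pScaleInv before the P @ V matmul so the
  // softmax output spans the full fp8 range...
  double pScaleInv = largestDbl / softmaxMax;
  // ...and the accumulator is multiplied by pScale afterwards to undo it.
  double pScale = softmaxMax / largestDbl;
  std::printf("pScaleInv = %g, pScale = %g, product = %g\n", pScaleInv, pScale,
              pScaleInv * pScale);
  return 0;
}

Since both factors are compile-time constants whose product is exactly 1, no
per-row reduction over `p` is needed anymore.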
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index cfe994a..23ba79c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -20,6 +21,13 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
+/// Command line options used purely for development purposes. Not to be relied
+/// on in any way.
+static llvm::cl::opt<float> clAttentionSoftmaxMax(
+ "iree-linalgext-attention-softmax-max",
+ llvm::cl::desc("maximum expected value from attention softmax"),
+ llvm::cl::init(1.0));
+
static Value scaleValueInPlace(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap scaleMap,
Value value, Value scale) {
@@ -244,7 +252,6 @@
Value oldMax = getMax();
Value oldSum = getSum();
Type elementType = getElementTypeOrSelf(getOutput().getType());
- Type reduceType = getElementTypeOrSelf(oldMax.getType());
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(getIndexingMapsArray());
@@ -346,36 +353,25 @@
return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
}));
- // We rescale `p` to use the full range and pass back the `pScale`.
- Value maxEmpty = b.create<tensor::EmptyOp>(loc, mSizes, reduceType);
- Value absMax =
- reduce<arith::MaximumFOp>(b, loc, pMap, maxMap, p, maxEmpty);
-
auto fpTy = cast<FloatType>(vETy);
double largestDbl =
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();
- AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
- /*symbolCount=*/0, getContext());
-
// We normalize p from [0, max] to [0, fp8.max] to guarantee we
// use the full `fp8` range, then renormalize post Softmax@V matmul
// to correct.
- Value largestInv = b.create<arith::ConstantOp>(
- loc, b.getFloatAttr(elementType, 1.0 / largestDbl));
- pScale = scaleValueInPlace(b, loc, maxMap, scaleMap, absMax, largestInv);
+ pScale = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));
// Compute the pre matmul scale to handle fp8 quantization:
- Value recInit = b.create<tensor::EmptyOp>(loc, mSizes, elementType);
- Value largest = b.create<arith::ConstantOp>(
- loc, b.getFloatAttr(elementType, largestDbl));
- Value pScaleInv = reciprocalValue(b, loc, absMax, recInit);
- pScaleInv =
- scaleValueInPlace(b, loc, maxMap, scaleMap, pScaleInv, largest);
+ Value pScaleInv = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));
- p = scaleValueInPlace(b, loc, pMap, maxMap, p, pScaleInv);
- norm = scaleValueInPlace(b, loc, normMap, maxMap, norm, pScaleInv);
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+ /*symbolCount=*/0, getContext());
+ p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
+ norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
}
Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
@@ -391,7 +387,9 @@
// Update for the FP8 dynamic scale:
if (pScale) {
- newAcc = scaleValueInPlace(b, loc, accMap, maxMap, newAcc, pScale);
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+ /*symbolCount=*/0, getContext());
+ newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
}
return SmallVector<Value>{newAcc, newMax, newSum};
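For contrast with the removed dynamic path, a hypothetical scalar model of the
two strategies (names invented for illustration; the real code operates on
tensors via linalg generics):

#include <algorithm>
#include <vector>

// Old path: reduce the row maximum of the softmax output p, then stretch it
// to the fp8 maximum. The scale varied per row, which required the
// tensor.empty + arith.maximumf reduction and the reciprocal deleted above.
double dynamicScaleInv(const std::vector<double> &pRow, double largestDbl) {
  double absMax = 0.0;
  for (double v : pRow)
    absMax = std::max(absMax, v);
  return largestDbl / absMax;
}

// New path: a single statically chosen constant, tunable via the flag.
double staticScaleInv(double softmaxMax, double largestDbl) {
  return largestDbl / softmaxMax;
}

The static variant trades per-row exactness of range usage for fewer ops in
the decomposed attention body.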