[Attention] Only clamp attention for low precision types (#18848)
Post-softmax, the range of output is between 0, 1. For low-precision
types (like fp8), we scale the output range to be between 0, fpMax, so
we can use more of the fp range. For higher precision types like fp16,
we don't do this, so we don't need to clamp the range.
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index b6cd4cc..02d1e71 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -79,7 +79,8 @@
}
static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
- AffineMap outputMap, Value value, Value output) {
+ AffineMap outputMap, Value value, Value output,
+ bool clampToFPRange) {
SmallVector<AffineMap> compressedMaps =
compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
inputMap = compressedMaps[0];
@@ -94,19 +95,23 @@
auto srcTy = cast<FloatType>(args[0].getType());
auto dstTy = cast<FloatType>(args[1].getType());
- double mxDbl =
- APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
- .convertToDouble();
+ Value input = args[0];
- // Clamp input to dstTy(usually `fp8`) MAX value to prevent NaNs.
- // We do not clamp for `-MAX` because this function meant to only be
- // used by attention's exp2 who's value is always > 0.
- Value mx = builder.create<arith::ConstantOp>(
- loc, builder.getFloatAttr(srcTy, mxDbl));
- Value clamp = b.create<arith::MinimumFOp>(loc, mx, args[0]);
+ if (clampToFPRange) {
+ double mxDbl =
+ APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
+ .convertToDouble();
+
+ // Clamp input to dstTy(usually `fp8`) MAX value to prevent NaNs.
+ // We do not clamp for `-MAX` because this function meant to only be
+ // used by attention's exp2 who's value is always > 0.
+ Value mx = builder.create<arith::ConstantOp>(
+ loc, builder.getFloatAttr(srcTy, mxDbl));
+ input = b.create<arith::MinimumFOp>(loc, mx, input);
+ }
// Convert scale to the same datatype as input.
- Value trunc = convertScalarToDtype(b, loc, clamp, dstTy,
+ Value trunc = convertScalarToDtype(b, loc, input, dstTy,
/*isUnsignedCast=*/false);
b.create<linalg::YieldOp>(loc, trunc);
});
@@ -370,7 +375,8 @@
s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);
- if (qETy.getIntOrFloatBitWidth() <= 8) {
+ bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
+ if (lowPrecision) {
// For low bit-depth types we perform post Q @ K scaling. This is to avoid
// losing numerical precision due to the low dynamic range of fp8 types when
// pre applying the sclaing.
@@ -432,7 +438,7 @@
auto pETy = getElementTypeOrSelf(p.getType());
if (pETy != vETy && isa<FloatType>(vETy)) {
Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
- p = truncateFloat(b, loc, pMap, pMap, p, convertP);
+ p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
}
Value newAcc = elementwiseValueInPlace<arith::MulFOp>(b, loc, accMap, normMap,