Updated `linalg_ext.online_attention` for `fp8` support (#17808)
We add additional support for `fp8` support include scaling post Q @ K,
exp(QK) normalization, and iterative normalization required for `fp8`s
limited range. This has been numerically evaluated on `llvm-cpu`. The
`rocm` path is currently numerically incorrect.
---------
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 360ee7e..ce40d00 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1275,12 +1275,16 @@
}
bool isTiled = numOutputs == 3;
- SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
-
// Check if indexing maps can represent attention.
+ SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(indexingMaps);
+ FloatType scaleElementType = dyn_cast<FloatType>(getScale().getType());
+ if (!scaleElementType) {
+ return attnOp->emitOpError("expected scale to be of floating point type");
+ }
+
// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank());
SmallVector<bool> foundDims(getIterationDomainRank(), false);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index 1db43d8..cfe994a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -44,6 +44,72 @@
return genericOp.getResult(0);
}
+static Value reciprocalValue(OpBuilder &b, Location loc, Value input,
+ Value output) {
+ int64_t rank = cast<ShapedType>(input.getType()).getRank();
+ SmallVector<AffineMap> maps = {b.getMultiDimIdentityMap(rank),
+ b.getMultiDimIdentityMap(rank)};
+
+ SmallVector<utils::IteratorType> iteratorTypes(rank,
+ utils::IteratorType::parallel);
+ auto genericOp = b.create<linalg::GenericOp>(
+ loc, output.getType(), ValueRange{input}, output, maps, iteratorTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
+ /*isUnsignedCast=*/false);
+ // Convert scale to the same datatype as input.
+ Value one =
+ b.create<arith::ConstantOp>(loc, b.getFloatAttr(in.getType(), 1.0));
+ Value result = b.create<arith::DivFOp>(loc, one, in);
+ b.create<linalg::YieldOp>(loc, result);
+ });
+ return genericOp.getResult(0);
+}
+
+static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
+ AffineMap outputMap, Value value, Value output) {
+ SmallVector<AffineMap> compressedMaps =
+ compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
+ inputMap = compressedMaps[0];
+ outputMap = compressedMaps[1];
+
+ SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
+ utils::IteratorType::parallel);
+ auto genericOp = builder.create<linalg::GenericOp>(
+ loc, output.getType(), value, output,
+ SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ auto srcTy = cast<FloatType>(args[0].getType());
+ auto dstTy = cast<FloatType>(args[1].getType());
+
+ // We clamp to the min / max of the floating point representation
+ double mnDbl =
+ APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
+ .convertToDouble();
+ double mxDbl =
+ APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
+ .convertToDouble();
+
+ // Truncate to the `fp8` range so avoid nan values.
+ Value mn = builder.create<arith::ConstantOp>(
+ loc, builder.getFloatAttr(srcTy, mnDbl));
+ Value mx = builder.create<arith::ConstantOp>(
+ loc, builder.getFloatAttr(srcTy, mxDbl));
+ Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
+ args[0], mx);
+ Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+ args[0], mn);
+ Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
+ Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);
+
+ // Convert scale to the same datatype as input.
+ Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
+ /*isUnsignedCast=*/false);
+ b.create<linalg::YieldOp>(loc, trunc);
+ });
+ return genericOp.getResult(0);
+}
+
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, AffineMap inputMap,
AffineMap outputMap, Value input, Value output) {
@@ -177,7 +243,8 @@
Value oldAcc = getOutput();
Value oldMax = getMax();
Value oldSum = getSum();
- Type elementType = getQuery().getType().getElementType();
+ Type elementType = getElementTypeOrSelf(getOutput().getType());
+ Type reduceType = getElementTypeOrSelf(oldMax.getType());
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(getIndexingMapsArray());
@@ -192,20 +259,27 @@
// have better support for exp2 (we verified that we gain some speedup on
// some GPUs).
Value scale = getScale();
- Value log2e =
- b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, M_LOG2E));
+ Value log2e = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(scale.getType(), M_LOG2E));
scale = b.create<arith::MulFOp>(loc, scale, log2e);
+ auto qETy = getElementTypeOrSelf(query.getType());
+ auto vETy = getElementTypeOrSelf(value.getType());
+
// In the original algorithm, the scaling is done after the softmax:
// softmax(Q @ K.T * scale) @ V
//
// But, it is mathematically equivalent to do it on Q first and then multiply
// it by K.T. This just allows us to do the scaling once, instead of each
- // iteration of the loop.
- AffineMap qMap = getQueryMap();
- AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
- /*symbolCount=*/0, getContext());
- query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
+ // iteration of the loop. This is only valid for f16 or f32 types as f8
+ // is extremely limited on its dynamic range therefore this would
+ // significantly affect numerics.
+ if (qETy.getIntOrFloatBitWidth() > 8) {
+ AffineMap qMap = getQueryMap();
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
+ /*symbolCount=*/0, getContext());
+ query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
+ }
// ---- Matmul 1 ----
@@ -224,6 +298,16 @@
Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
+ // For low bit-depth types we perform post Q @ K scaling. This is to avoid
+ // losing numerical precision due to the low dynamic range of fp8 types when
+ // pre applying the sclaing.
+ if (qETy.getIntOrFloatBitWidth() <= 8) {
+ AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
+ /*symbolCount=*/0, getContext());
+ s = scaleValueInPlace(b, loc, sMap, scaleMap, s, scale);
+ }
+
// TODO: This decomposition should be in a seperate op called
// "online softmax".
// ---- Online Softmax ----
@@ -246,11 +330,58 @@
AffineMap sumMap = getSumMap();
Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
- // newSum = normSum + rowMax(P)
+ // newSum = normSum + rowSum(P)
Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
// newAcc = norm * oldAcc
AffineMap accMap = getOutputMap();
+
+ // ---- Scale and truncate LHS to match RHS ----
+ Value pScale;
+ auto pETy = getElementTypeOrSelf(p.getType());
+ if (pETy != vETy && isa<FloatType>(vETy)) {
+ if (vETy.getIntOrFloatBitWidth() <= 8) {
+ SmallVector<OpFoldResult> mSizes(
+ llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
+ return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
+ }));
+
+ // We rescale `p` to use the full range and pass back the `pScale`.
+ Value maxEmpty = b.create<tensor::EmptyOp>(loc, mSizes, reduceType);
+ Value absMax =
+ reduce<arith::MaximumFOp>(b, loc, pMap, maxMap, p, maxEmpty);
+
+ auto fpTy = cast<FloatType>(vETy);
+ double largestDbl =
+ APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
+ .convertToDouble();
+
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+ /*symbolCount=*/0, getContext());
+
+ // We normalize p from [0, max] to [0, fp8.max] to guarantee we
+ // use the full `fp8` range, then renormlize post Softmax@V matmul
+ // to correct.
+ Value largestInv = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(elementType, 1.0 / largestDbl));
+ pScale = scaleValueInPlace(b, loc, maxMap, scaleMap, absMax, largestInv);
+
+ // Compute the pre matmul scale to handle fp8 quantization:
+ Value recInit = b.create<tensor::EmptyOp>(loc, mSizes, elementType);
+ Value largest = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(elementType, largestDbl));
+ Value pScaleInv = reciprocalValue(b, loc, absMax, recInit);
+ pScaleInv =
+ scaleValueInPlace(b, loc, maxMap, scaleMap, pScaleInv, largest);
+
+ p = scaleValueInPlace(b, loc, pMap, maxMap, p, pScaleInv);
+ norm = scaleValueInPlace(b, loc, normMap, maxMap, norm, pScaleInv);
+ }
+
+ Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
+ p = truncateFloat(b, loc, pMap, pMap, p, convertP);
+ }
+
Value newAcc = scaleValueInPlace(b, loc, accMap, normMap, oldAcc, norm);
// ---- Matmul 2 ----
@@ -258,6 +389,11 @@
// newAcc = P @ V + newAcc
newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
+ // Update for for the FP8 dynamic scale:
+ if (pScale) {
+ newAcc = scaleValueInPlace(b, loc, accMap, maxMap, newAcc, pScale);
+ }
+
return SmallVector<Value>{newAcc, newMax, newSum};
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index 945cff8..3e323ed 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -26,6 +26,9 @@
// We just want to check if we are using the correct algorithm.
// CHECK-LABEL: @attention_f16
+// Q = Q * scale
+// CHECK: linalg.generic
+// CHECK: arith.mulf
// S = Q @ K
// CHECK: linalg.generic
// CHECK: arith.mulf
@@ -62,3 +65,81 @@
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+ %key: tensor<192x1024x64xf8E4M3FNUZ>,
+ %value: tensor<192x1024x64xf8E4M3FNUZ>,
+ %output: tensor<192x1024x64xf32>,
+ %max: tensor<192x1024xf32>,
+ %sum: tensor<192x1024xf32>)
+ -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+ %scale = arith.constant 1.0 : f16
+
+ %out:3 = iree_linalg_ext.online_attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+ ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16)
+ outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+ -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+ return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+
+// CHECK-LABEL: @attention_f8
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// S = S * scale
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// newSum = normSum + rowMax(P)
+// CHECK: linalg.generic
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// clamp = clamp(norm)
+// CHECK: linalg.generic
+// CHECK: arith.cmpf ogt
+// CHECK: arith.cmpf olt
+// CHECK: arith.select
+// CHECK: arith.select
+// CHECK: arith.truncf
+// newAcc = norm * oldAcc
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// newAcc = P @ V + newAcc
+// CHECK: linalg.generic
+// CHECK: arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK: arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield