Updated `linalg_ext.online_attention` for `fp8` support (#17808)

We add support for `fp8` inputs, including scaling after Q @ K,
exp(QK) normalization, and the iterative normalization required by `fp8`'s
limited range. This has been numerically evaluated on `llvm-cpu`. The
`rocm` path is currently numerically incorrect.

---------

Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 360ee7e..ce40d00 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1275,12 +1275,16 @@
   }
   bool isTiled = numOutputs == 3;
 
-  SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
-
   // Check if indexing maps can represent attention.
+  SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
   FailureOr<AttentionOpDetail> maybeOpInfo =
       AttentionOpDetail::get(indexingMaps);
 
+  FloatType scaleElementType = dyn_cast<FloatType>(getScale().getType());
+  if (!scaleElementType) {
+    return attnOp->emitOpError("expected scale to be of floating point type");
+  }
+
   // Check shape compatibility based on indexing maps.
   SmallVector<int64_t> shape(getIterationDomainRank());
   SmallVector<bool> foundDims(getIterationDomainRank(), false);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index 1db43d8..cfe994a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -44,6 +44,72 @@
   return genericOp.getResult(0);
 }
 
+static Value reciprocalValue(OpBuilder &b, Location loc, Value input,
+                             Value output) {
+  int64_t rank = cast<ShapedType>(input.getType()).getRank();
+  SmallVector<AffineMap> maps = {b.getMultiDimIdentityMap(rank),
+                                 b.getMultiDimIdentityMap(rank)};
+
+  SmallVector<utils::IteratorType> iteratorTypes(rank,
+                                                 utils::IteratorType::parallel);
+  auto genericOp = b.create<linalg::GenericOp>(
+      loc, output.getType(), ValueRange{input}, output, maps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
+                                        /*isUnsignedCast=*/false);
+        // Element-wise reciprocal (1 / in), computed in the output element type.
+        Value one =
+            b.create<arith::ConstantOp>(loc, b.getFloatAttr(in.getType(), 1.0));
+        Value result = b.create<arith::DivFOp>(loc, one, in);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
+                           AffineMap outputMap, Value value, Value output) {
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
+  inputMap = compressedMaps[0];
+  outputMap = compressedMaps[1];
+
+  SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
+                                                 utils::IteratorType::parallel);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), value, output,
+      SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        auto srcTy = cast<FloatType>(args[0].getType());
+        auto dstTy = cast<FloatType>(args[1].getType());
+
+        // Clamp bounds: the largest finite -/+ values representable in dstTy.
+        double mnDbl =
+            APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
+                .convertToDouble();
+        double mxDbl =
+            APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
+                .convertToDouble();
+
+        // Clamp to the destination range (in srcTy) to avoid nan values.
+        Value mn =
+            b.create<arith::ConstantOp>(loc, b.getFloatAttr(srcTy, mnDbl));
+        Value mx =
+            b.create<arith::ConstantOp>(loc, b.getFloatAttr(srcTy, mxDbl));
+        Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
+                                           args[0], mx);
+        Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+                                           args[0], mn);
+        Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
+        Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);
+
+        // Convert the clamped value to the destination element type.
+        Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
+                                           /*isUnsignedCast=*/false);
+        b.create<linalg::YieldOp>(loc, trunc);
+      });
+  return genericOp.getResult(0);
+}
+
 template <typename T>
 static Value reduce(OpBuilder &builder, Location loc, AffineMap inputMap,
                     AffineMap outputMap, Value input, Value output) {
@@ -177,7 +243,8 @@
   Value oldAcc = getOutput();
   Value oldMax = getMax();
   Value oldSum = getSum();
-  Type elementType = getQuery().getType().getElementType();
+  Type elementType = getElementTypeOrSelf(getOutput().getType());
+  Type reduceType = getElementTypeOrSelf(oldMax.getType());
 
   FailureOr<AttentionOpDetail> maybeOpInfo =
       AttentionOpDetail::get(getIndexingMapsArray());
@@ -192,20 +259,27 @@
   // have better support for exp2 (we verified that we gain some speedup on
   // some GPUs).
   Value scale = getScale();
-  Value log2e =
-      b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, M_LOG2E));
+  Value log2e = b.create<arith::ConstantOp>(
+      loc, b.getFloatAttr(scale.getType(), M_LOG2E));
   scale = b.create<arith::MulFOp>(loc, scale, log2e);
 
+  auto qETy = getElementTypeOrSelf(query.getType());
+  auto vETy = getElementTypeOrSelf(value.getType());
+
   // In the original algorithm, the scaling is done after the softmax:
   //        softmax(Q @ K.T * scale) @ V
   //
   // But, it is mathematically equivalent to do it on Q first and then multiply
   // it by K.T. This just allows us to do the scaling once, instead of each
-  // iteration of the loop.
-  AffineMap qMap = getQueryMap();
-  AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
-                                      /*symbolCount=*/0, getContext());
-  query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
+  // iteration of the loop. This is only valid for f16 or f32 types as f8
+  // is extremely limited on its dynamic range therefore this would
+  // significantly affect numerics.
+  if (qETy.getIntOrFloatBitWidth() > 8) {
+    AffineMap qMap = getQueryMap();
+    AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
+                                        /*symbolCount=*/0, getContext());
+    query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
+  }
 
   // ---- Matmul 1 ----
 
@@ -224,6 +298,16 @@
   Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
   s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
 
+  // For low bit-depth types we perform post Q @ K scaling. This is to avoid
+  // losing numerical precision due to the low dynamic range of fp8 types when
+  // pre-applying the scaling.
+  if (qETy.getIntOrFloatBitWidth() <= 8) {
+    AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
+    AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
+                                        /*symbolCount=*/0, getContext());
+    s = scaleValueInPlace(b, loc, sMap, scaleMap, s, scale);
+  }
+
   // TODO: This decomposition should be in a seperate op called
   // "online softmax".
   // ---- Online Softmax ----
@@ -246,11 +330,58 @@
   AffineMap sumMap = getSumMap();
   Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
 
-  // newSum = normSum + rowMax(P)
+  // newSum = normSum + rowSum(P)
   Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
 
   // newAcc = norm * oldAcc
   AffineMap accMap = getOutputMap();
+
+  // ---- Scale and truncate LHS to match RHS ----
+  Value pScale;
+  auto pETy = getElementTypeOrSelf(p.getType());
+  if (pETy != vETy && isa<FloatType>(vETy)) {
+    if (vETy.getIntOrFloatBitWidth() <= 8) {
+      SmallVector<OpFoldResult> mSizes(
+          llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
+            return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
+          }));
+
+      // We rescale `p` to use the full range and pass back the `pScale`.
+      Value maxEmpty = b.create<tensor::EmptyOp>(loc, mSizes, reduceType);
+      Value absMax =
+          reduce<arith::MaximumFOp>(b, loc, pMap, maxMap, p, maxEmpty);
+
+      auto fpTy = cast<FloatType>(vETy);
+      double largestDbl =
+          APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
+              .convertToDouble();
+
+      AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+                                          /*symbolCount=*/0, getContext());
+
+      // We normalize p from [0, max] to [0, fp8.max] to guarantee we
+      // use the full `fp8` range, then renormalize post Softmax@V matmul
+      // to correct.
+      Value largestInv = b.create<arith::ConstantOp>(
+          loc, b.getFloatAttr(elementType, 1.0 / largestDbl));
+      pScale = scaleValueInPlace(b, loc, maxMap, scaleMap, absMax, largestInv);
+
+      // Compute the pre matmul scale to handle fp8 quantization:
+      Value recInit = b.create<tensor::EmptyOp>(loc, mSizes, elementType);
+      Value largest = b.create<arith::ConstantOp>(
+          loc, b.getFloatAttr(elementType, largestDbl));
+      Value pScaleInv = reciprocalValue(b, loc, absMax, recInit);
+      pScaleInv =
+          scaleValueInPlace(b, loc, maxMap, scaleMap, pScaleInv, largest);
+
+      p = scaleValueInPlace(b, loc, pMap, maxMap, p, pScaleInv);
+      norm = scaleValueInPlace(b, loc, normMap, maxMap, norm, pScaleInv);
+    }
+
+    Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
+    p = truncateFloat(b, loc, pMap, pMap, p, convertP);
+  }
+
   Value newAcc = scaleValueInPlace(b, loc, accMap, normMap, oldAcc, norm);
 
   // ---- Matmul 2 ----
@@ -258,6 +389,11 @@
   // newAcc = P @ V + newAcc
   newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
 
+  // Update for the FP8 dynamic scale:
+  if (pScale) {
+    newAcc = scaleValueInPlace(b, loc, accMap, maxMap, newAcc, pScale);
+  }
+
   return SmallVector<Value>{newAcc, newMax, newSum};
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index 945cff8..3e323ed 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -26,6 +26,9 @@
 
 // We just want to check if we are using the correct algorithm.
 // CHECK-LABEL: @attention_f16
+// Q = Q * scale
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
 // S = Q @ K
 // CHECK: linalg.generic
 // CHECK:   arith.mulf
@@ -62,3 +65,81 @@
 // CHECK:   arith.mulf
 // CHECK:   arith.addf
 // CHECK:   linalg.yield
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+                         %key: tensor<192x1024x64xf8E4M3FNUZ>,
+                         %value: tensor<192x1024x64xf8E4M3FNUZ>,
+                         %output: tensor<192x1024x64xf32>,
+                         %max: tensor<192x1024xf32>,
+                         %sum: tensor<192x1024xf32>)
+                         -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+  %scale = arith.constant 1.0 : f16  // scale is f16; Q/K/V are f8E4M3FNUZ and the outs are f32
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16)
+        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>  // accumulator and running sum; the max result (%out#1) is not returned
+}
+
+// CHECK-LABEL: @attention_f8
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// S = S * scale
+// CHECK:   linalg.generic
+// CHECK:   arith.mulf
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK:   arith.maximumf
+// CHECK:   linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newSum = normSum + rowSum(P)
+// CHECK: linalg.generic
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// P = truncate(clamp(P)) to f8
+// CHECK: linalg.generic
+// CHECK:   arith.cmpf ogt
+// CHECK:   arith.cmpf olt
+// CHECK:   arith.select
+// CHECK:   arith.select
+// CHECK:   arith.truncf
+// newAcc = norm * oldAcc
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newAcc = P @ V + newAcc
+// CHECK: linalg.generic
+// CHECK:   arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK:   arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield