Make fp8 attention using static quantized softmax (#17949)

Added a flag that uses a statically selected softmax range for fp8
attention. This allows the value
to be tuned without plumbing through the attention operation. Eventually
this should be a parameter
passed to `linalg_ext.attention`.
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index cfe994a..23ba79c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -8,6 +8,7 @@
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -20,6 +21,13 @@
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
+/// Command line options used purely for development purposes. Not to be relied
+/// on in any way.
+static llvm::cl::opt<float> clAttentionSoftmaxMax(
+    "iree-linalgext-attention-softmax-max",
+    llvm::cl::desc("maximum expected value from attention softmax"),
+    llvm::cl::init(1.0));
+
 static Value scaleValueInPlace(OpBuilder &builder, Location loc,
                                AffineMap inputMap, AffineMap scaleMap,
                                Value value, Value scale) {
@@ -244,7 +252,6 @@
   Value oldMax = getMax();
   Value oldSum = getSum();
   Type elementType = getElementTypeOrSelf(getOutput().getType());
-  Type reduceType = getElementTypeOrSelf(oldMax.getType());
 
   FailureOr<AttentionOpDetail> maybeOpInfo =
       AttentionOpDetail::get(getIndexingMapsArray());
@@ -346,36 +353,25 @@
             return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
           }));
 
-      // We rescale `p` to use the full range and pass back the `pScale`.
-      Value maxEmpty = b.create<tensor::EmptyOp>(loc, mSizes, reduceType);
-      Value absMax =
-          reduce<arith::MaximumFOp>(b, loc, pMap, maxMap, p, maxEmpty);
-
       auto fpTy = cast<FloatType>(vETy);
       double largestDbl =
           APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
               .convertToDouble();
 
-      AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
-                                          /*symbolCount=*/0, getContext());
-
       // We normalize p from [0, max] to [0, fp8.max] to guarantee we
       // use the full `fp8` range, then renormlize post Softmax@V matmul
       // to correct.
-      Value largestInv = b.create<arith::ConstantOp>(
-          loc, b.getFloatAttr(elementType, 1.0 / largestDbl));
-      pScale = scaleValueInPlace(b, loc, maxMap, scaleMap, absMax, largestInv);
+      pScale = b.create<arith::ConstantOp>(
+          loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));
 
       // Compute the pre matmul scale to handle fp8 quantization:
-      Value recInit = b.create<tensor::EmptyOp>(loc, mSizes, elementType);
-      Value largest = b.create<arith::ConstantOp>(
-          loc, b.getFloatAttr(elementType, largestDbl));
-      Value pScaleInv = reciprocalValue(b, loc, absMax, recInit);
-      pScaleInv =
-          scaleValueInPlace(b, loc, maxMap, scaleMap, pScaleInv, largest);
+      Value pScaleInv = b.create<arith::ConstantOp>(
+          loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));
 
-      p = scaleValueInPlace(b, loc, pMap, maxMap, p, pScaleInv);
-      norm = scaleValueInPlace(b, loc, normMap, maxMap, norm, pScaleInv);
+      AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+                                          /*symbolCount=*/0, getContext());
+      p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
+      norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
     }
 
     Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
@@ -391,7 +387,9 @@
 
   // Update for for the FP8 dynamic scale:
   if (pScale) {
-    newAcc = scaleValueInPlace(b, loc, accMap, maxMap, newAcc, pScale);
+    AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+                                        /*symbolCount=*/0, getContext());
+    newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
   }
 
   return SmallVector<Value>{newAcc, newMax, newSum};