[LinalgExt] Implement AggregateOpInterface for AttentionOp (#18890)

- Adds AggregateOpInterface for AttentionOp
- Move all aggregate interface tests to IR/test/decompose_aggregate_op
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 204ae35..7fc985b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -299,6 +299,212 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Attention Helpers
+//===----------------------------------------------------------------------===//
+
+Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
+                              Value key, Value scale, std::optional<Value> mask,
+                              AffineMap qMap, AffineMap kMap, AffineMap sMap,
+                              std::optional<AffineMap> maskMap,
+                              SmallVector<OpFoldResult> iterationDomain,
+                              Type sElementType, Region &elementwiseRegion,
+                              DictionaryAttr qkAttrs, bool lowPrecision) {
+  MLIRContext *ctx = b.getContext();
+  // Since we use exp2 for attention instead of the original exp, we have to
+  // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
+  // have better support for exp2 (we verified that we gain some speedup on
+  // some GPUs).
+  Value log2e = b.create<arith::ConstantOp>(
+      loc, b.getFloatAttr(scale.getType(), M_LOG2E));
+  scale = b.create<arith::MulFOp>(loc, scale, log2e);
+
+  auto qETy = getElementTypeOrSelf(query.getType());
+
+  AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
+                                      /*symbolCount=*/0, ctx);
+
+  // In the original algorithm, the scaling is done after the softmax:
+  //        softmax(Q @ K.T * scale) @ V
+  //
+  // But, it is mathematically equivalent to do it on Q first and then multiply
+  // it by K.T. This just allows us to do the scaling once, instead of each
+  // iteration of the loop. This is only valid for f16 or f32 types as f8
+  // is extremely limited on its dynamic range therefore this would
+  // significantly affect numerics.
+  if (!lowPrecision) {
+    query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
+                                                   query, scale);
+  }
+
+  // ---- QK Matmul ----
+
+  // Get sizes for S.
+  SmallVector<OpFoldResult> sSizes;
+  for (AffineExpr dimExpr : sMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    sSizes.push_back(iterationDomain[dim]);
+  }
+
+  // S = Q @ K
+  // SMap = QMap @ KMap
+  Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, sElementType);
+  Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(sElementType));
+  Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
+
+  s = computeMatmul(b, loc, qMap, kMap, sMap, query, key, s);
+  if (qkAttrs) {
+    s.getDefiningOp()->setAttrs(qkAttrs);
+  }
+
+  s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s);
+
+  if (lowPrecision) {
+    // For low bit-depth types we perform post Q @ K scaling. This is to avoid
+    // losing numerical precision due to the low dynamic range of fp8 types when
+    // pre applying the sclaing.
+    AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
+    AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
+                                        /*symbolCount=*/0, ctx);
+    s = elementwiseValueInPlace<arith::MulFOp>(b, loc, sMap, scaleMap, s,
+                                               scale);
+
+    // If we need to truncate to fp8 post softmax we apply a scaling to use the
+    // full fp8 range. We can do this with a offset as post `exp2` this equates
+    // to multiplying by a static value. We are able to do this as `max` and
+    // `sum` are scaled by the same value so the end result is the same.
+    auto fpTy = cast<FloatType>(qETy);
+    double mx =
+        APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
+            .convertToDouble();
+    Value offset = b.create<arith::ConstantOp>(
+        loc, b.getFloatAttr(sElementType, clAttentionSoftmaxMax / mx));
+    s = elementwiseValueInPlace<arith::AddFOp>(b, loc, sMap, scaleMap, s,
+                                               offset);
+  }
+
+  // S += mask
+  if (mask != nullptr) {
+    s = applyMask(b, loc, sMap, *maskMap, s, mask.value());
+  }
+
+  return s;
+}
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
+  Location loc = getLoc();
+  Value query = getQuery();
+  Value key = getKey();
+  Value value = getValue();
+  std::optional<Value> mask = getMask();
+  DictionaryAttr config = getDecompositionConfigAttr();
+
+  DictionaryAttr qkAttrs, pvAttrs;
+  if (config) {
+    qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+    pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+  }
+  Value output = getOutput();
+
+  FailureOr<AttentionOpDetail> maybeOpInfo =
+      AttentionOpDetail::get(getIndexingMapsArray());
+  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+  AttentionOpDetail opInfo = maybeOpInfo.value();
+
+  SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+      getIterationDomain(b), [](Range x) { return x.size; });
+
+  AffineMap qMap = getQueryMap();
+  AffineMap kMap = getKeyMap();
+  AffineMap sMap = opInfo.getSMap();
+
+  auto qETy = getElementTypeOrSelf(query.getType());
+  bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
+
+  // We compute output of first matmul in f32.
+  Type f32Type = b.getF32Type();
+
+  // ---- QK Matmul + elementwise math ----
+  Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap,
+                                    kMap, sMap, getMaskMap(), sizes, f32Type,
+                                    getRegion(), qkAttrs, lowPrecision);
+
+  // ---- Softmax ----
+
+  AffineMap accMap = getOutputMap();
+
+  llvm::SmallBitVector projectedK2Dims(opInfo.getDomainRank(), false);
+  for (auto dim : opInfo.getK2Dims()) {
+    projectedK2Dims.set(dim);
+  }
+
+  AffineMap maxMap = projectDims(sMap, projectedK2Dims).dropZeroResults();
+  AffineMap sumMap = maxMap;
+
+  SmallVector<OpFoldResult> rowRedSize =
+      applyPermutationMap<OpFoldResult>(maxMap, sizes);
+
+  Value rowRedEmpty = b.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
+
+  Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf,
+                                          getElementTypeOrSelf(output), b, loc,
+                                          /*useOnlyFiniteValue=*/true);
+  Value maxInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, b, loc,
+                              /*useOnlyFiniteValue=*/true);
+  Value sumInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, b, loc);
+
+  Value accFill =
+      b.create<linalg::FillOp>(loc, ValueRange{accInit}, output).getResult(0);
+  Value maxFill =
+      b.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
+          .getResult(0);
+  Value sumFill =
+      b.create<linalg::FillOp>(loc, ValueRange{sumInit}, rowRedEmpty)
+          .getResult(0);
+
+  // max = rowMax(S)
+  Value max = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, maxFill);
+
+  // P = exp2(S - max)
+  AffineMap pMap = sMap;
+  Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s);
+
+  // sum = rowSum(P)
+  Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);
+
+  // P = P / sum
+  p = elementwiseValueInPlace<arith::DivFOp>(b, loc, pMap, sumMap, p, sum);
+
+  // ---- Scale and truncate LHS to match RHS ----
+  SmallVector<OpFoldResult> sSizes;
+  for (AffineExpr dimExpr : sMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    sSizes.push_back(sizes[dim]);
+  }
+
+  auto pETy = getElementTypeOrSelf(p.getType());
+  auto vETy = getElementTypeOrSelf(value.getType());
+  if (pETy != vETy && isa<FloatType>(vETy)) {
+    Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
+    p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
+  }
+
+  // result = P @ V + acc
+  Value result =
+      computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, accFill);
+  if (pvAttrs) {
+    result.getDefiningOp()->setAttrs(pvAttrs);
+  }
+
+  return SmallVector<Value>{result};
+}
+
+//===----------------------------------------------------------------------===//
 // OnlineAttentionOp
 //===----------------------------------------------------------------------===//
 
@@ -329,87 +535,17 @@
   SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
       getIterationDomain(b), [](Range x) { return x.size; });
 
-  // Since we use exp2 for attention instead of the original exp, we have to
-  // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
-  // have better support for exp2 (we verified that we gain some speedup on
-  // some GPUs).
-  Value scale = getScale();
-  Value log2e = b.create<arith::ConstantOp>(
-      loc, b.getFloatAttr(scale.getType(), M_LOG2E));
-  scale = b.create<arith::MulFOp>(loc, scale, log2e);
+  AffineMap qMap = getQueryMap();
+  AffineMap kMap = getKeyMap();
+  AffineMap sMap = opInfo.getSMap();
 
   auto qETy = getElementTypeOrSelf(query.getType());
-  auto vETy = getElementTypeOrSelf(value.getType());
-
-  AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(),
-                                      /*symbolCount=*/0, getContext());
-
-  // In the original algorithm, the scaling is done after the softmax:
-  //        softmax(Q @ K.T * scale) @ V
-  //
-  // But, it is mathematically equivalent to do it on Q first and then multiply
-  // it by K.T. This just allows us to do the scaling once, instead of each
-  // iteration of the loop. This is only valid for f16 or f32 types as f8
-  // is extremely limited on its dynamic range therefore this would
-  // significantly affect numerics.
-  if (qETy.getIntOrFloatBitWidth() > 8) {
-    AffineMap qMap = getQueryMap();
-    query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
-                                                   query, scale);
-  }
-
-  // ---- Matmul 1 ----
-
-  // Get sizes for S.
-  AffineMap sMap = opInfo.getSMap();
-  SmallVector<OpFoldResult> sSizes;
-  for (AffineExpr dimExpr : sMap.getResults()) {
-    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    sSizes.push_back(sizes[dim]);
-  }
-
-  // S = Q @ K
-  // SMap = QMap @ KMap
-  Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
-  Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
-  Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
-
-  s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
-  if (qkAttrs) {
-    s.getDefiningOp()->setDiscardableAttrs(qkAttrs);
-  }
-
-  s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);
-
   bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
-  if (lowPrecision) {
-    // For low bit-depth types we perform post Q @ K scaling. This is to avoid
-    // losing numerical precision due to the low dynamic range of fp8 types when
-    // pre applying the sclaing.
-    AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
-    AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
-                                        /*symbolCount=*/0, getContext());
-    s = elementwiseValueInPlace<arith::MulFOp>(b, loc, sMap, scaleMap, s,
-                                               scale);
 
-    // If we need to truncate to fp8 post softmax we apply a scaling to use the
-    // full fp8 range. We can do this with a offset as post `exp2` this equates
-    // to multiplying by a static value. We are able to do this as `max` and
-    // `sum` are scaled by the same value so the end result is the same.
-    auto fpTy = cast<FloatType>(qETy);
-    double mx =
-        APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
-            .convertToDouble();
-    Value offset = b.create<arith::ConstantOp>(
-        loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / mx));
-    s = elementwiseValueInPlace<arith::AddFOp>(b, loc, sMap, scaleMap, s,
-                                               offset);
-  }
-
-  // S += mask
-  if (mask != nullptr) {
-    s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value());
-  }
+  // ---- QK Matmul + elementwise math ----
+  Value s = computeQKAndElementwise(
+      loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
+      sizes, elementType, getRegion(), qkAttrs, lowPrecision);
 
   // TODO: This decomposition should be in a seperate op called
   // "online softmax".
@@ -441,7 +577,14 @@
   AffineMap accMap = getOutputMap();
 
   // ---- Scale and truncate LHS to match RHS ----
+  SmallVector<OpFoldResult> sSizes;
+  for (AffineExpr dimExpr : sMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    sSizes.push_back(sizes[dim]);
+  }
+
   auto pETy = getElementTypeOrSelf(p.getType());
+  auto vETy = getElementTypeOrSelf(value.getType());
   if (pETy != vETy && isa<FloatType>(vETy)) {
     Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
     p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 329c79c..3b46114 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -475,6 +475,7 @@
       ["getIndexingMapsForResults", "getIndexingMapsForOperands",
        "getStaticLoopRanges"]>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel
index ade6135..aff5921 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel
@@ -17,6 +17,7 @@
     srcs = enforce_glob(
         [
             "canonicalize.mlir",
+            "decompose_aggregate_op.mlir",
             "invalid.mlir",
             "roundtrip.mlir",
         ],
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt
index f6d6730..36bdf43 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "canonicalize.mlir"
+    "decompose_aggregate_op.mlir"
     "invalid.mlir"
     "roundtrip.mlir"
   TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir
new file mode 100644
index 0000000..fae9e5b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir
@@ -0,0 +1,418 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s
+
+// Spec to decompose custom op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>,
+    %rhs1 : tensor<?x?xf32>, %rhs2 : tensor<?x?xf32>, %scalar : f32,
+    %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>)
+    -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
+  %0:2 = iree_linalg_ext.custom_op {
+        indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
+                         affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+                         affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
+                         affine_map<(d0, d1)[s0, s1] -> ()>,
+                         affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
+                         affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
+        iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+                          #iree_linalg_ext.iterator_type<parallel>]}
+        ins(%lhs1, %rhs1, %rhs2, %scalar
+            : tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32)
+        outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
+      ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
+           %s : f32, %t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
+        %0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+            outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
+            outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %2 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                             affine_map<(d0, d1) -> ()>,
+                             affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%1, %s : tensor<?x?xf32>, f32) outs(%1 : tensor<?x?xf32>) {
+          ^bb0(%b0 : f32, %b1 : f32, %b2 :f32):
+            %3 = arith.addf %b0, %b2 : f32
+            linalg.yield %3 : f32
+        } -> tensor<?x?xf32>
+        iree_linalg_ext.yield %0, %2 : tensor<?x?xf32>, tensor<?x?xf32>
+    } -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
+  return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
+}
+
+// CHECK-LABEL: func @custom_op_decomposition(
+//  CHECK-SAME:     %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+//  CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[RHS2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[SCALAR:[a-zA-Z0-9]+]]: f32
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+//       CHECK:   %[[MATMUL1:.+]] = linalg.matmul
+//  CHECK-SAME:       ins(%[[LHS1]], %[[RHS1]] :
+//  CHECK-SAME:       outs(%[[INIT1]] :
+//       CHECK:   %[[MATMUL2:.+]] = linalg.matmul
+//  CHECK-SAME:       ins(%[[MATMUL1]], %[[RHS2]] :
+//  CHECK-SAME:       outs(%[[INIT2]] :
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[MATMUL2]], %[[SCALAR]] :
+//  CHECK-SAME:       outs(%[[MATMUL2]] :
+//       CHECK:   return %[[MATMUL1]], %[[GENERIC]]
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @attention_f16(%query: tensor<192x1024x64xf16>,
+                         %key: tensor<192x1024x64xf16>,
+                         %value: tensor<192x1024x64xf16>,
+                         %output: tensor<192x1024x64xf32>)
+                         -> (tensor<192x1024x64xf32>) {
+  %scale = arith.constant 1.0 : f16
+
+  %out = iree_linalg_ext.attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO] }
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
+        outs(%output : tensor<192x1024x64xf32>) {
+                      ^bb0(%score: f32):
+                        iree_linalg_ext.yield %score: f32
+                     }
+        -> tensor<192x1024x64xf32>
+
+  return %out : tensor<192x1024x64xf32>
+}
+
+// We just want to check if we are using the correct algorithm
+// CHECK-LABEL: @attention_f16
+// Q = Q * scale
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK:   arith.extf
+// CHECK:   arith.extf
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// max = rowMax(S)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.maximumf
+// CHECK:   linalg.yield
+// P = exp2(S - max)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// sum = rowSum(P)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// P = P /= sum
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.divf
+// CHECK:   linalg.yield
+// truncf P : f32 to f16
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.truncf
+// CHECK:   linalg.yield
+// newAcc = P @ V
+// CHECK: linalg.generic
+// CHECK:   arith.extf
+// CHECK:   arith.extf
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_f16(%query: tensor<192x1024x64xf16>,
+                         %key: tensor<192x1024x64xf16>,
+                         %value: tensor<192x1024x64xf16>,
+                         %output: tensor<192x1024x64xf32>,
+                         %max: tensor<192x1024xf32>,
+                         %sum: tensor<192x1024xf32>)
+                         -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+  %scale = arith.constant 1.0 : f16
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
+        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+                      ^bb0(%score: f32):
+                        iree_linalg_ext.yield %score: f32
+                     }
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+
+// We just want to check if we are using the correct algorithm and the
+// correct number of extf/truncfs are emitted.
+// CHECK-LABEL: @online_attention_f16
+// Q = Q * scale
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK:   arith.extf
+// CHECK:   arith.extf
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.maximumf
+// CHECK:   linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// newSum = normSum + rowSum(P)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// newAcc = norm * oldAcc
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newAcc = P @ V + newAcc
+// CHECK: linalg.generic
+// CHECK:   arith.extf
+// CHECK:   arith.extf
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+                         %key: tensor<192x1024x64xf8E4M3FNUZ>,
+                         %value: tensor<192x1024x64xf8E4M3FNUZ>,
+                         %output: tensor<192x1024x64xf32>,
+                         %max: tensor<192x1024xf32>,
+                         %sum: tensor<192x1024xf32>)
+                         -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+  %scale = arith.constant 1.0 : f32
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
+        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+                      ^bb0(%score: f32):
+                        iree_linalg_ext.yield %score: f32
+                     }
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+
+// CHECK-LABEL: @online_attention_f8
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// S = S * scale
+// CHECK:   linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.mulf
+// CHECK-NEXT:   linalg.yield
+// S = S + F8_linear_offset
+// CHECK:   linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.addf
+// CHECK-NEXT:   linalg.yield
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.maximumf
+// CHECK:   linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// newSum = normSum + rowSum(P)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// clamp = clamp(norm)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.minimumf
+// CHECK:   arith.truncf
+// newAcc = norm * oldAcc
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newAcc = P @ V + newAcc
+// CHECK: linalg.generic
+// CHECK:   arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK:   arith.extf [[A:.+]] f8E4M3FNUZ to f32
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+
+// -----
+
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+                              %key: tensor<192x1024x64xf8E4M3FNUZ>,
+                              %value: tensor<192x1024x64xf8E4M3FNUZ>,
+                              %mask: tensor<192x1024x1024xf8E4M3FNUZ>,
+                              %output: tensor<192x1024x64xf32>,
+                              %max: tensor<192x1024xf32>,
+                              %sum: tensor<192x1024xf32>)
+                              -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+  %scale = arith.constant 1.0 : f16
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16, tensor<192x1024x1024xf8E4M3FNUZ>)
+        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
+                      ^bb0(%score: f32):
+                        iree_linalg_ext.yield %score: f32
+                     }
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+// CHECK-LABEL: @online_attention_f8_masked
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// S = S * scale
+// CHECK:   linalg.generic
+// CHECK:   arith.mulf
+// S = S + mask
+// CHECK:   arith.addf
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK:   arith.maximumf
+// CHECK:   linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newSum = normSum + rowMax(P)
+// CHECK: linalg.generic
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index efe463a6..6ba9d5c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -20,9 +20,7 @@
             "conv2d_to_winograd.mlir",
             "convert_to_loops.mlir",
             "convert_to_online_attention.mlir",
-            "decompose_aggregate_op.mlir",
             "decompose_im2col.mlir",
-            "decompose_online_attention.mlir",
             "decompose_winograd.mlir",
             "distribution.mlir",
             "pad_contraction_to_block_size.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 3288c14..a912973 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -18,9 +18,7 @@
     "conv2d_to_winograd.mlir"
     "convert_to_loops.mlir"
     "convert_to_online_attention.mlir"
-    "decompose_aggregate_op.mlir"
     "decompose_im2col.mlir"
-    "decompose_online_attention.mlir"
     "decompose_winograd.mlir"
     "distribution.mlir"
     "pad_contraction_to_block_size.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir
deleted file mode 100644
index 80b0b7a..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir
+++ /dev/null
@@ -1,62 +0,0 @@
-// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s
-
-func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>,
-    %rhs1 : tensor<?x?xf32>, %rhs2 : tensor<?x?xf32>, %scalar : f32,
-    %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>)
-    -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
-  %0:2 = iree_linalg_ext.custom_op {
-        indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
-                         affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
-                         affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
-                         affine_map<(d0, d1)[s0, s1] -> ()>,
-                         affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
-                         affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
-        iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
-                          #iree_linalg_ext.iterator_type<parallel>]}
-        ins(%lhs1, %rhs1, %rhs2, %scalar
-            : tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32)
-        outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
-      ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
-           %s : f32, %t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
-        %0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
-            outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
-        %1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
-            outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
-        %2 = linalg.generic {
-            indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                             affine_map<(d0, d1) -> ()>,
-                             affine_map<(d0, d1) -> (d0, d1)>],
-            iterator_types = ["parallel", "parallel"]}
-            ins(%1, %s : tensor<?x?xf32>, f32) outs(%1 : tensor<?x?xf32>) {
-          ^bb0(%b0 : f32, %b1 : f32, %b2 :f32):
-            %3 = arith.addf %b0, %b2 : f32
-            linalg.yield %3 : f32
-        } -> tensor<?x?xf32>
-        iree_linalg_ext.yield %0, %2 : tensor<?x?xf32>, tensor<?x?xf32>
-    } -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
-  return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
-}
-module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
-    transform.yield
-  }
-}
-// CHECK-LABEL: func @custom_op_decomposition(
-//  CHECK-SAME:     %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-//  CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-//  CHECK-SAME:     %[[RHS2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-//  CHECK-SAME:     %[[SCALAR:[a-zA-Z0-9]+]]: f32
-//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-//       CHECK:   %[[MATMUL1:.+]] = linalg.matmul
-//  CHECK-SAME:       ins(%[[LHS1]], %[[RHS1]] :
-//  CHECK-SAME:       outs(%[[INIT1]] :
-//       CHECK:   %[[MATMUL2:.+]] = linalg.matmul
-//  CHECK-SAME:       ins(%[[MATMUL1]], %[[RHS2]] :
-//  CHECK-SAME:       outs(%[[INIT2]] :
-//       CHECK:   %[[GENERIC:.+]] = linalg.generic
-//  CHECK-SAME:       ins(%[[MATMUL2]], %[[SCALAR]] :
-//  CHECK-SAME:       outs(%[[MATMUL2]] :
-//       CHECK:   return %[[MATMUL1]], %[[GENERIC]]
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
deleted file mode 100644
index 19bd2ca..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ /dev/null
@@ -1,242 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-attention),canonicalize,cse)" %s | FileCheck %s
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f16(%query: tensor<192x1024x64xf16>,
-                         %key: tensor<192x1024x64xf16>,
-                         %value: tensor<192x1024x64xf16>,
-                         %output: tensor<192x1024x64xf32>,
-                         %max: tensor<192x1024xf32>,
-                         %sum: tensor<192x1024xf32>)
-                         -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
-  %scale = arith.constant 1.0 : f16
-
-  %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
-        ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
-        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
-                      ^bb0(%score: f32):
-                        iree_linalg_ext.yield %score: f32
-                     }
-        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
-  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-
-// We just want to check if we are using the correct algorithm and the
-// correct number of extf/truncfs are emitted.
-// CHECK-LABEL: @attention_f16
-// Q = Q * scale
-// CHECK: linalg.generic
-// CHECK:   arith.mulf
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK:   arith.extf
-// CHECK:   arith.extf
-// CHECK:   arith.mulf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.maximumf
-// CHECK:   linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.mulf
-// CHECK:   linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// newSum = normSum + rowSum(P)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-// newAcc = norm * oldAcc
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.mulf
-// CHECK:   linalg.yield
-// newAcc = P @ V + newAcc
-// CHECK: linalg.generic
-// CHECK:   arith.extf
-// CHECK:   arith.extf
-// CHECK:   arith.mulf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-
-// -----
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
-                         %key: tensor<192x1024x64xf8E4M3FNUZ>,
-                         %value: tensor<192x1024x64xf8E4M3FNUZ>,
-                         %output: tensor<192x1024x64xf32>,
-                         %max: tensor<192x1024xf32>,
-                         %sum: tensor<192x1024xf32>)
-                         -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
-  %scale = arith.constant 1.0 : f32
-
-  %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
-        ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
-        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
-                      ^bb0(%score: f32):
-                        iree_linalg_ext.yield %score: f32
-                     }
-        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
-  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-
-// CHECK-LABEL: @attention_f8
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK:   arith.mulf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-// S = S * scale
-// CHECK:   linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.mulf
-// CHECK-NEXT:   linalg.yield
-// S = S + F8_linear_offset
-// CHECK:   linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.addf
-// CHECK-NEXT:   linalg.yield
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.maximumf
-// CHECK:   linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.mulf
-// CHECK:   linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// newSum = normSum + rowSum(P)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-// clamp = clamp(norm)
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.minimumf
-// CHECK:   arith.truncf
-// newAcc = norm * oldAcc
-// CHECK: linalg.generic
-// CHECK-NOT: arith.extf
-// CHECK:   arith.mulf
-// CHECK:   linalg.yield
-// newAcc = P @ V + newAcc
-// CHECK: linalg.generic
-// CHECK:   arith.extf [[A:.+]] f8E4M3FNUZ to f32
-// CHECK:   arith.extf [[A:.+]] f8E4M3FNUZ to f32
-// CHECK:   arith.mulf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-
-// -----
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
-#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
-                              %key: tensor<192x1024x64xf8E4M3FNUZ>,
-                              %value: tensor<192x1024x64xf8E4M3FNUZ>,
-                              %mask: tensor<192x1024x1024xf8E4M3FNUZ>,
-                              %output: tensor<192x1024x64xf32>,
-                              %max: tensor<192x1024xf32>,
-                              %sum: tensor<192x1024xf32>)
-                              -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
-  %scale = arith.constant 1.0 : f16
-
-  %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
-        ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16, tensor<192x1024x1024xf8E4M3FNUZ>)
-        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
-                      ^bb0(%score: f32):
-                        iree_linalg_ext.yield %score: f32
-                     }
-        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
-  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-// CHECK-LABEL: @attention_f8_masked
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK:   arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
-// CHECK:   arith.mulf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-// S = S * scale
-// CHECK:   linalg.generic
-// CHECK:   arith.mulf
-// S = S + mask
-// CHECK:   arith.addf
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK:   arith.maximumf
-// CHECK:   linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK:   arith.mulf
-// CHECK:   linalg.yield
-// newSum = normSum + rowMax(P)
-// CHECK: linalg.generic
-// CHECK:   arith.addf
-// CHECK:   linalg.yield