[LinalgExt] Reland: Add online_attention op (#17681)

This patch adds a new online_attention op. This op represents a
partially reduced attention op which can be tiled along it's k2
reduction dimension. This op also has indexing maps, supports tiling on
all dimensions other than k1 dimension, and can decompose based on any
given indexing maps.

This patch also makes the CPU backend use online attention to decompose
and tile reduction dimension, allowing it to be tiled along N and batch
dimensions, and tiling using LLVMCPUTile.

This is a reland of https://github.com/iree-org/iree/pull/17658 , with
more conservative tile size selection to not unroll too much.
diff --git a/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json b/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json
index 0e97930..083915b 100644
--- a/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json
+++ b/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json
@@ -16,10 +16,9 @@
   "expected_compile_failures": [
     // TODO(#17344): need to regenerate .mlirbc
     "opt-125M",
-    "resnet50",
-    // TODO(#17467): Remove the workaround once we have better support for attention op codegen.
-    "sdxl-vae-decode-tank"
+    "resnet50"
   ],
   "expected_run_failures": [
+    "sdxl-vae-decode-tank"
   ]
 }
diff --git a/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_cpu_llvm_task.json b/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_cpu_llvm_task.json
index 2deef58..d2dff6f 100644
--- a/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_cpu_llvm_task.json
+++ b/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_cpu_llvm_task.json
@@ -3,9 +3,7 @@
   "iree_compile_flags": [
     "--iree-hal-target-backends=llvm-cpu",
     "--iree-llvmcpu-target-cpu-features=host",
-    "--iree-input-demote-f64-to-f32",
-    // TODO(#17467): Remove the workaround once we have better support for attention op codegen.
-    "--iree-llvmcpu-fail-on-large-vector=false"
+    "--iree-input-demote-f64-to-f32"
   ],
   "iree_run_module_flags": [
     "--device=local-task",
@@ -21,7 +19,5 @@
   "skip_compile_tests": [],
   "skip_run_tests": [],
   "expected_compile_failures": [],
-  "expected_run_failures": [
-    "sdxl-scheduled-unet-3-tank"
-  ]
+  "expected_run_failures": []
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 9ae470d..175c1ae 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -890,7 +890,7 @@
 /// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
 /// reduction loops.
 static void splitParallelAndReductionTiles(
-    linalg::LinalgOp op, SmallVectorImpl<int64_t> &parallelSizes,
+    Operation *op, SmallVectorImpl<int64_t> &parallelSizes,
     SmallVectorImpl<int64_t> &reductionSizes,
     SmallVectorImpl<bool> *parallelScalableFlags = nullptr,
     SmallVectorImpl<bool> *reductionScalableFlags = nullptr) {
@@ -900,8 +900,9 @@
     reductionScalableFlags->assign(parallelScalableFlags->begin(),
                                    parallelScalableFlags->end());
   }
+  TilingInterface tilingOp = cast<TilingInterface>(op);
   for (auto [index, iteratorType] :
-       llvm::enumerate(op.getIteratorTypesArray())) {
+       llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
     if (iteratorType == utils::IteratorType::parallel) {
       reductionSizes[index] = 0;
       if (reductionScalableFlags)
@@ -1121,9 +1122,9 @@
   SmallVector<int64_t> parallelTileSizes = vecTileSizes;
   SmallVector<int64_t> reductionTileSizes;
   SmallVector<bool> reductionScalableFlags;
-  splitParallelAndReductionTiles(
-      cast<linalg::LinalgOp>(op.getOperation()), parallelTileSizes,
-      reductionTileSizes, &parallelScalableFlags, &reductionScalableFlags);
+  splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes,
+                                 &parallelScalableFlags,
+                                 &reductionScalableFlags);
 
   if (vecPreProcStrategy == VectorPreProcStrategy::None) {
     setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
@@ -1741,27 +1742,26 @@
     llvm::dbgs() << "]\n";
   });
 
-  // TODO (Groverkss): Flash Attention 2 (current algorithm we use for
-  // attention) was originally designed for GPUs. N, K1 are the head dimension
-  // and are usually very small (64, 128, See AttentionOpDetail docs for more
-  // detail). For larger sizes, fusing attention doesn't have many gains (as
-  // pointed out by the original author). We should explore if we should tile N
-  // and K1 dimensions on CPU and if it has any gains. On GPUs, we don't tile
-  // these dimensions as subgroups can hold much larger register sizes.
-
   // Batch, M and N (parallel dimensions) are distributed on workgroups.
   DistributionHeuristicConfig config;
-  SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
-      attnOp, DistributionHeuristicConfig{});
+  int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
+  config.maxTileSizes.resize(opInfo.getDomainRank(), clDefaultDistTileSize);
+  config.vectorSizeHints.resize(opInfo.getDomainRank(), vectorSize);
+  // Distribute batch dimensions completely on workgroups (tile_size = 1).
+  for (int batch : opInfo.getBatchDims()) {
+    config.maxTileSizes[batch] = 1;
+    config.vectorSizeHints[batch] = 1;
+  }
+  SmallVector<int64_t> distTileSizes =
+      getDefaultDistributedLevelTileSizes(attnOp, config);
 
   // Batch, M and N (parallel dimensions) are distributed on workgroups.
   SmallVector<int64_t> vecTileSizes(attnOp.getIterationDomainRank(), 1);
-  // Mark reduction dimensions not to distribute.
-  for (int64_t i :
-       llvm::concat<const int64_t>(opInfo.getK1Dims(), opInfo.getK2Dims())) {
+  // Due to the way attention works, K1 dimensions cannot be tiled. Mark k1
+  // reduction dimensions not to distribute.
+  for (int i : opInfo.getK1Dims()) {
     vecTileSizes[i] = 0;
   }
-  int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
   for (auto i : llvm::seq<unsigned>(0, vecTileSizes.size())) {
     // Do not tile reduction dimensions.
     if (vecTileSizes[i] == 0) {
@@ -1773,18 +1773,29 @@
         /*numElem=*/tileSize, vectorSize, vectorSize);
   }
 
-  // TODO (17467): Due to a bug in TileAndDecomposeAttention, N dimension
-  // cannot be tiled. Remove this once fixed.
-  for (int64_t i : opInfo.getNDims()) {
-    distTileSizes[i] = 0;
-    vecTileSizes[i] = 0;
+  // Tile the M dimension completely.
+  // TODO: This is a hack to prevent too large vector sizes. The largest vector
+  // generally produced is the Q vector, which is of shape: BATCH x M x K1.
+  // Since K1 cannot be tiled, the heuristics don't properly account for tiling
+  // M such that Q doesn't grow too large.
+  // Ideally, we should use something like limitVectorTileSizes, to fixup tile
+  // sizes. Currently, limitVectorTileSizes ignores static dimensions which are
+  // not tiled, which is why it's not currently used here.
+  for (int i : opInfo.getMDims()) {
+    vecTileSizes[i] = 1;
   }
 
-  TileSizesListType tileSizes = {distTileSizes, vecTileSizes};
+  SmallVector<int64_t> parallelTileSizes = vecTileSizes;
+  SmallVector<int64_t> reductionTileSizes;
+  splitParallelAndReductionTiles(attnOp, parallelTileSizes, reductionTileSizes);
 
-  // TODO: (Groverkss): Tile K2 here using reduction tiling interface once we
-  // have it. TileAndDecomposeAttention pass only tiles K2. I think it should
-  // be possible to tile K1 also, but need to explore it more.
+  LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (parallel): "
+                       << parallelTileSizes << "\n");
+  LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (reduction): "
+                       << reductionTileSizes << "\n");
+
+  TileSizesListType tileSizes = {distTileSizes, parallelTileSizes,
+                                 reductionTileSizes};
 
   return setOpConfigAndEntryPointFnTranslation(
       entryPointFn, attnOp, tileSizes,
@@ -1843,6 +1854,9 @@
   tileSizes.push_back(distTileSizes);
   SmallVector<int64_t> vecTileSizes(iterationRank, 1);
   tileSizes.push_back(vecTileSizes);
+  // Dummy tiling config for reduction level.
+  SmallVector<int64_t> reductionTileSizes(iterationRank, 0);
+  tileSizes.push_back(reductionTileSizes);
   return setOpConfigAndEntryPointFnTranslation(
       entryPointFn, winogradOp, tileSizes,
       DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 31fcead..6a4363f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -617,10 +617,13 @@
       createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
   // TODO: Remove the pass once we have PartialReductionOpInterface implemented
   // for AttentionOp.
-  funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
-  funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
+  funcPassManager.addPass(
+      IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
+  funcPassManager.addPass(
+      createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
   funcPassManager.addPass(
       IREE::LinalgExt::createDecomposeWinogradTransformPass());
+  funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
 
   {
     GenericVectorizationPassOptions options;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index 4df5605..43b3ddb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -1531,7 +1531,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @winograd_output_transform()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -1556,7 +1556,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @winograd_input_transform()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -1581,7 +1581,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1], [0, 0]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @winograd_filter_transform()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -1613,7 +1613,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 0], [20, 32, 0, 0, 0]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 0, 0, 64], [1, 1, 0, 0, 32], [0, 0, 0, 32, 0]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @attention()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
index 8675c43..a771ee0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -29,6 +29,7 @@
         "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
         "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
+        "@llvm-project//mlir:LinalgOpsTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:PDLDialectTdFiles",
         "@llvm-project//mlir:SideEffectInterfacesTdFiles",
@@ -159,7 +160,9 @@
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "LinalgExtOps.td",
-    deps = [":td_files"],
+    deps = [
+        ":td_files",
+    ],
 )
 
 iree_gentbl_cc_library(
@@ -212,5 +215,7 @@
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "LinalgExtOps.td",
-    deps = [":td_files"],
+    deps = [
+        ":td_files",
+    ],
 )
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c5c42ec..34ce4f4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -8,6 +8,7 @@
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -1315,6 +1316,9 @@
     for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
       AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
       int64_t pos = dim.getPosition();
+      if (ShapedType::isDynamic(valShape[i])) {
+        continue;
+      }
       if (!foundDims[pos]) {
         foundDims[pos] = true;
         shape[pos] = valShape[i];
@@ -1427,6 +1431,79 @@
   return results;
 }
 
+//===----------------------------------------------------------------------===//
+// OnlineAttentionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult OnlineAttentionOp::verify() {
+  OnlineAttentionOp attnOp = *this;
+
+  SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
+
+  // Check if indexing maps can represent attention.
+  FailureOr<AttentionOpDetail> maybeOpInfo =
+      AttentionOpDetail::get(indexingMaps);
+
+  // Check shape compatibility based on indexing maps.
+  SmallVector<int64_t> shape(getIterationDomainRank());
+  SmallVector<bool> foundDims(getIterationDomainRank(), false);
+  auto checkShape = [&shape, &foundDims,
+                     &attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
+                              AffineMap indexingMap) -> LogicalResult {
+    if (indexingMap.getNumResults() != valShape.size()) {
+      return attnOp->emitError("Rank Mismatch for ")
+             << operandName << ". Expected: " << indexingMap.getNumResults()
+             << " Got: " << valShape.size();
+    }
+    for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
+      AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
+      int64_t pos = dim.getPosition();
+      if (ShapedType::isDynamic(valShape[i])) {
+        continue;
+      }
+      if (!foundDims[pos]) {
+        foundDims[pos] = true;
+        shape[pos] = valShape[i];
+      }
+      if (shape[pos] != valShape[i]) {
+        return attnOp->emitError("Shape Mismatch for ")
+               << operandName << ". Expected: " << shape[pos]
+               << " Got: " << valShape[i];
+      }
+    }
+    return success();
+  };
+
+  if (failed(checkShape("Query", getQuery().getType().getShape(),
+                        getQueryMap())) ||
+      failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
+      failed(checkShape("Value", getValue().getType().getShape(),
+                        getValueMap())) ||
+      failed(checkShape("Output", getOutput().getType().getShape(),
+                        getOutputMap())) ||
+      failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) ||
+      failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) {
+    return failure();
+  }
+
+  return success();
+}
+
+MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
+  return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
+}
+
+LogicalResult OnlineAttentionOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return cast<LinalgExtOp>(getOperation())
+      .reifyResultShapes(b, reifiedReturnShapes);
+}
+
+SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
+  return SmallVector<AffineMap>(
+      getIndexingMaps().getAsValueRange<AffineMapAttr>());
+}
+
 #define DEFINE_OP_GET_EFFECTS(OP_NAME)                                         \
   void OP_NAME::getEffects(                                                    \
       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
@@ -1446,6 +1523,7 @@
 DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
 DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
 DEFINE_OP_GET_EFFECTS(AttentionOp)
+DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
index 97caaab..3d52ae6 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/IR/Attributes.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index bf9694d..0eebd2e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,6 +9,7 @@
 
 include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
 include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -678,6 +679,96 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// OnlineAttention
+//===----------------------------------------------------------------------===//
+
+def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DestinationStyleOpInterface, LinalgExtInterface,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
+     DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let summary = "Online Attention operator";
+  let description = [{
+    Traditional scaled dot product attention computes:
+
+    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
+
+    Online Attention on the other hand, uses an online normalizer instead of
+    softmax:
+
+    online_attention(Q, K, V, scale, running_max, running_sum)
+      = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V
+
+    The advantage of this online_normalizer is that it can be tiled along
+    it's reduction dimension, making the online_attention operator:
+      - Tilable along softmax reduction dimension
+      - Associative along softmax reduction dimension
+      - Commutative along softmax associative dimension
+
+    Note: The results of online_attention need to be combined after computing
+    it over the entire softmax reduction dimension by:
+      x, _, sum : results
+      x = (1 / sum) * x
+  }];
+
+  let arguments = (ins AnyShaped:$query,
+                       AnyShaped:$key,
+                       AnyShaped:$value,
+                       AnyFloat:$scale,
+                       AnyShaped:$output,
+                       AnyShaped:$max,
+                       AnyShaped:$sum,
+                       AffineMapArrayAttr:$indexing_maps
+  );
+
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
+    `outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
+    (`->` type($results)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    // Method to implement for specifying output range for
+    // DestinationStyleOpInterface
+    MutableOperandRange getDpsInitsMutable();
+
+    SmallVector<AffineMap> getIndexingMapsArray();
+
+    AffineMap getQueryMap() {
+      return getIndexingMapsArray()[0];
+    }
+    AffineMap getKeyMap() {
+      return getIndexingMapsArray()[1];
+    }
+    AffineMap getValueMap() {
+      return getIndexingMapsArray()[2];
+    }
+    AffineMap getOutputMap() {
+      return getIndexingMapsArray()[3];
+    }
+    AffineMap getMaxMap() {
+      return getIndexingMapsArray()[4];
+    }
+    AffineMap getSumMap() {
+      return getIndexingMapsArray()[5];
+    }
+
+    int64_t getIterationDomainRank() {
+      return getQueryMap().getNumDims();
+    }
+  }];
+}
+
 } // OpGroupNonStructuredOps
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
index 00bb383..923b30a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
@@ -51,5 +51,17 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure LinalgExt::ConvertToOnlineAttention::applyToOne(
+    transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  SmallVector<Operation *> ops;
+  LinalgExt::convertToOnlineAttention(attentionOp, ops, rewriter);
+  for (Operation *op : ops) {
+    results.push_back(op);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 #define GET_OP_CLASSES
 #include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
index c3a6310..84e588a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
@@ -86,4 +86,36 @@
   }];
 }
 
+def ConvertToOnlineAttention : Op<Transform_Dialect, "iree.convert_to_online_attention",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Target iree_linalg_ext.attention ops and decompose them.
+    This transform consumes the target handle and produces a result handle.
+  }];
+
+  let arguments = (
+      ins TransformHandleTypeInterface:$target
+  );
+  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
+
+  let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
+  let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // IREE_DIALECT_LINALGEXT_TRANSFORMOPS
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
new file mode 100644
index 0000000..b4370e7
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -0,0 +1,228 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+
+namespace mlir::iree_compiler::IREE::LinalgExt {
+
+static Value scaleValueInPlace(OpBuilder &builder, Location loc,
+                               AffineMap inputMap, AffineMap scaleMap,
+                               Value value, Value scale) {
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{inputMap, scaleMap});
+  inputMap = compressedMaps[0];
+  scaleMap = compressedMaps[1];
+
+  SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
+                                                 utils::IteratorType::parallel);
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, value.getType(), scale, value,
+      SmallVector<AffineMap>{scaleMap, inputMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // Convert scale to the same datatype as input.
+        Value scale = convertScalarToDtype(b, loc, args[0], args[1].getType(),
+                                           /*isUnsignedCast=*/false);
+        Value result = b.create<arith::MulFOp>(loc, scale, args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+template <typename T>
+static Value reduce(OpBuilder &builder, Location loc, AffineMap inputMap,
+                    AffineMap outputMap, Value input, Value output) {
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
+  inputMap = compressedMaps[0];
+  outputMap = compressedMaps[1];
+
+  // Dims not present in outputMap are reductionDims.
+  SmallVector<utils::IteratorType> iteratorTypes(
+      inputMap.getNumDims(), utils::IteratorType::reduction);
+  for (AffineExpr dim : outputMap.getResults()) {
+    int pos = cast<AffineDimExpr>(dim).getPosition();
+    iteratorTypes[pos] = utils::IteratorType::parallel;
+  }
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output,
+      SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // Convert input to the same datatype as acc.
+        Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
+                                        /*isUnsignedCast=*/false);
+        Value result = b.create<T>(loc, in, args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+
+  return genericOp.getResult(0);
+}
+
+static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap,
+                           AffineMap rhsMap, AffineMap accMap, Value lhs,
+                           Value rhs, Value acc) {
+
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{lhsMap, rhsMap, accMap});
+  lhsMap = compressedMaps[0];
+  rhsMap = compressedMaps[1];
+  accMap = compressedMaps[2];
+
+  // Dims not present in accMap are reduction dims.
+  SmallVector<utils::IteratorType> iteratorTypes(
+      accMap.getNumDims(), utils::IteratorType::reduction);
+  for (AffineExpr dim : accMap.getResults()) {
+    int pos = cast<AffineDimExpr>(dim).getPosition();
+    iteratorTypes[pos] = utils::IteratorType::parallel;
+  }
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, acc.getType(), SmallVector<Value>{lhs, rhs}, acc,
+      SmallVector<AffineMap>{lhsMap, rhsMap, accMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // Cast inputs to match output datatype.
+        Value lhs = convertScalarToDtype(b, loc, args[0], args[2].getType(),
+                                         /*isUnsignedCast=*/false);
+        Value rhs = convertScalarToDtype(b, loc, args[1], args[2].getType(),
+                                         /*isUnsignedCast=*/false);
+        Value mul = b.create<arith::MulFOp>(loc, lhs, rhs);
+        Value add = b.create<arith::AddFOp>(loc, mul, args[2]);
+        b.create<linalg::YieldOp>(loc, add);
+      });
+
+  return genericOp.getResult(0);
+}
+
+// Compute output = exp2(output - input)
+static Value computeSubAndExp2(OpBuilder &builder, Location loc,
+                               AffineMap inputMap, AffineMap outputMap,
+                               Value input, Value output) {
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
+  inputMap = compressedMaps[0];
+  outputMap = compressedMaps[1];
+
+  SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
+                                                 utils::IteratorType::parallel);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output,
+      SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        // Convert input to the same datatype as output.
+        Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
+                                        /*isUnsignedCast=*/false);
+        Value diff = b.create<arith::SubFOp>(loc, args[1], in);
+        Value weight = b.create<math::Exp2Op>(loc, diff);
+        b.create<linalg::YieldOp>(loc, weight);
+      });
+  return genericOp.getResult(0);
+}
+
+FailureOr<SmallVector<Value>>
+OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
+  Location loc = getLoc();
+  Value query = getQuery();
+  Value key = getKey();
+  Value value = getValue();
+  Value oldAcc = getOutput();
+  Value oldMax = getMax();
+  Value oldSum = getSum();
+  Type elementType = getQuery().getType().getElementType();
+
+  FailureOr<AttentionOpDetail> maybeOpInfo =
+      AttentionOpDetail::get(getIndexingMapsArray());
+  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+  AttentionOpDetail opInfo = maybeOpInfo.value();
+
+  SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+      getIterationDomain(b), [](Range x) { return x.size; });
+
+  // Since we use exp2 for attention instead of the original exp, we have to
+  // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
+  // have better support for exp2 (we verified that we gain some speedup on
+  // some GPUs).
+  Value scale = getScale();
+  Value log2e =
+      b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, M_LOG2E));
+  scale = b.create<arith::MulFOp>(loc, scale, log2e);
+
+  // In the original algorithm, the scaling is done after the softmax:
+  //        softmax(Q @ K.T * scale) @ V
+  //
+  // But, it is mathematically equivalent to do it on Q first and then multiply
+  // it by K.T. This just allows us to do the scaling once, instead of each
+  // iteration of the loop.
+  AffineMap qMap = getQueryMap();
+  AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
+                                      /*symbolCount=*/0, getContext());
+  query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
+
+  // ---- Matmul 1 ----
+
+  // Get sizes for S.
+  AffineMap sMap = opInfo.getSMap();
+  SmallVector<OpFoldResult> sSizes;
+  for (AffineExpr dimExpr : sMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    sSizes.push_back(sizes[dim]);
+  }
+
+  // S = Q @ K
+  // SMap = QMap @ KMap
+  Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
+  Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+  Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
+  s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
+
+  // TODO: This decomposition should be in a seperate op called
+  // "online softmax".
+  // ---- Online Softmax ----
+
+  // newMax = max(oldMax, rowMax(S))
+  AffineMap maxMap = getMaxMap();
+  Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);
+
+  // P = exp2(S - newMax)
+  // PMap = SMap
+  AffineMap pMap = sMap;
+  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
+
+  // norm = exp2(oldMax - newMax)
+  // normMap = maxMap
+  AffineMap normMap = getMaxMap();
+  Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax);
+
+  // normSum = norm * oldSum
+  AffineMap sumMap = getSumMap();
+  Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
+
+  // newSum = normSum + rowMax(P)
+  Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
+
+  // newAcc = norm * oldAcc
+  AffineMap accMap = getOutputMap();
+  Value newAcc = scaleValueInPlace(b, loc, accMap, normMap, oldAcc, norm);
+
+  // ---- Matmul 2 ----
+
+  // newAcc = P @ V + newAcc
+  newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
+
+  return SmallVector<Value>{newAcc, newMax, newSum};
+}
+
+} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
index 1d08da9..1a20a8f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
@@ -30,6 +30,7 @@
 iree_compiler_cc_library(
     name = "Transforms",
     srcs = [
+        "AggregatedOpInterfaceImpl.cpp",
         "ConvertConv2DToWinograd.cpp",
         "ConvertToLoops.cpp",
         "DecomposeAttention.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 668d28a..19c7522 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@
     "Passes.h"
     "Passes.h.inc"
   SRCS
+    "AggregatedOpInterfaceImpl.cpp"
     "ConvertConv2DToWinograd.cpp"
     "ConvertToLoops.cpp"
     "DecomposeAttention.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
index c70000f..2cd851d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
@@ -366,6 +366,16 @@
     SmallVector<Operation *> ops;
     decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize);
   });
+  getOperation().walk([&](OnlineAttentionOp onlineAtt) {
+    rewriter.setInsertionPoint(onlineAtt);
+    FailureOr<SmallVector<Value>> results =
+        onlineAtt.decomposeOperation(rewriter);
+    if (failed(results)) {
+      onlineAtt->emitOpError("Could not decompose online attention");
+      return signalPassFailure();
+    }
+    rewriter.replaceOp(onlineAtt, results.value());
+  });
 }
 
 std::unique_ptr<Pass> createDecomposeAttentionPass() {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 43d44a3..165430a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -56,12 +56,18 @@
                              RewriterBase &rewriter,
                              std::optional<uint64_t> tileSize = std::nullopt);
 
+void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
+                              SmallVectorImpl<Operation *> &ops,
+                              RewriterBase &rewriter);
+
 // Creates a pass to tile the attention op along the reduction dim.
 std::unique_ptr<Pass> createTileAttentionPass();
 
 // Creates a pass to convert the attention op into a sequence of linalg ops.
 std::unique_ptr<Pass> createDecomposeAttentionPass();
 
+std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass();
+
 //===---------------------------------------------------------------------===//
 // Codegen Strategy passes that are moved into IREE.
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index 77bb105..ee801b6 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -98,4 +98,12 @@
   ];
 }
 
+def ConvertAttentionToOnlineAttention :
+    InterfacePass<"iree-linalg-ext-convert-attention-to-online-attention",
+    "mlir::FunctionOpInterface"> {
+  let summary = "";
+  let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
+                    "createConvertAttentionToOnlineAttentionPass()";
+}
+
 #endif  // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index 4c862b5..df0d87e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -8,6 +8,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -158,6 +159,17 @@
   void runOnOperation() override;
 };
 
+struct ConvertAttentionToOnlineAttentionPass final
+    : ConvertAttentionToOnlineAttentionBase<
+          ConvertAttentionToOnlineAttentionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
+                linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+  void runOnOperation() override;
+};
+
 } // namespace
 
 /// Tile iree_linalg_ext.attention.
@@ -292,6 +304,103 @@
   return tiledAttentionOp;
 }
 
+void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
+                              SmallVectorImpl<Operation *> &ops,
+                              RewriterBase &rewriter) {
+  rewriter.setInsertionPoint(attnOp);
+
+  Location loc = attnOp.getLoc();
+  MLIRContext *ctx = attnOp.getContext();
+
+  FailureOr<AttentionOpDetail> maybeOpInfo =
+      AttentionOpDetail::get(attnOp.getIndexingMapsArray());
+  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+  AttentionOpDetail opInfo = maybeOpInfo.value();
+
+  // Create standard maps for max and sum: (batch, m)
+  int64_t rank = opInfo.getDomainRank();
+  AffineMap maxMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, ctx);
+  for (auto dim :
+       llvm::concat<const int64_t>(opInfo.getBatchDims(), opInfo.getMDims())) {
+    maxMap = maxMap.insertResult(rewriter.getAffineDimExpr(dim),
+                                 maxMap.getNumResults());
+  }
+  AffineMap sumMap = maxMap;
+
+  SmallVector<Range> sizes = attnOp.getIterationDomain(rewriter);
+
+  // Create fill for acc, max and sum.
+  // TODO: Acc should not need a fill. The attention op should get a filled
+  // input instead of an empty input.
+  Value zeroAcc = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(attnOp.getOutputType().getElementType()));
+  Value accFill =
+      rewriter
+          .create<linalg::FillOp>(loc, ValueRange{zeroAcc}, attnOp.getOutput())
+          .result();
+
+  SmallVector<OpFoldResult> rowRedSize =
+      llvm::map_to_vector(sizes, [](Range x) { return x.size; });
+  rowRedSize = applyPermutationMap<OpFoldResult>(maxMap, rowRedSize);
+
+  Type f32Type = rewriter.getF32Type();
+  Value rowRedEmpty =
+      rewriter.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
+
+  Value maxInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, rewriter,
+                              loc, /*useOnlyFiniteValue=*/true);
+  Value sumInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type,
+                                          rewriter, loc);
+
+  Value maxFill =
+      rewriter.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
+          .getResult(0);
+  Value sumFill =
+      rewriter.create<linalg::FillOp>(loc, ValueRange{sumInit}, rowRedEmpty)
+          .getResult(0);
+
+  // Create online attention op.
+  SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
+  indexingMaps.push_back(maxMap);
+  indexingMaps.push_back(sumMap);
+  OnlineAttentionOp onlineAttn = rewriter.create<OnlineAttentionOp>(
+      loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
+      attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
+      accFill, maxFill, sumFill, rewriter.getAffineMapArrayAttr(indexingMaps));
+  onlineAttn->setDiscardableAttrs(attnOp->getDiscardableAttrDictionary());
+  ops.push_back(onlineAttn);
+
+  Value x = onlineAttn.getResult(0);
+  Value sum = onlineAttn.getResult(2);
+
+  // Merge the outputs of online attention:
+  //  x = (1 / sum) * x
+
+  // Compress the indexing maps.
+  SmallVector<AffineMap> compressedMaps =
+      compressUnusedDims(SmallVector<AffineMap>{sumMap, attnOp.getOutputMap()});
+
+  SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
+                                                 utils::IteratorType::parallel);
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, x.getType(), sum, x, compressedMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value one = b.create<arith::ConstantOp>(
+            loc, b.getFloatAttr(args[0].getType(), 1.0));
+        Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
+        // Convert sum to the same datatype as x.
+        reciprocal = convertScalarToDtype(b, loc, reciprocal, args[1].getType(),
+                                          /*isUnsignedCast=*/false);
+        Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  ops.push_back(genericOp);
+
+  rewriter.replaceOp(attnOp, genericOp);
+}
+
 void TileAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
@@ -305,8 +414,21 @@
   });
 }
 
+void ConvertAttentionToOnlineAttentionPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  IRRewriter rewriter(context);
+  getOperation().walk([&](AttentionOp attnOp) {
+    SmallVector<Operation *> ops;
+    convertToOnlineAttention(attnOp, ops, rewriter);
+  });
+}
+
 std::unique_ptr<Pass> createTileAttentionPass() {
   return std::make_unique<TileAttentionPass>();
 }
 
+std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass() {
+  return std::make_unique<ConvertAttentionToOnlineAttentionPass>();
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index 6b291c7..2ecfd80 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -4,11 +4,9 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
-#include "llvm/ADT/TypeSwitch.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -16,7 +14,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
@@ -1605,16 +1603,16 @@
 }
 
 //===----------------------------------------------------------------------===//
-// AttentionOp
+// Attention Helpers
 //===----------------------------------------------------------------------===//
 
-SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
-  int64_t domainRank = getIterationDomainRank();
-
+static SmallVector<Range>
+getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank,
+                            ArrayRef<Value> values,
+                            ArrayRef<AffineMap> indexingMaps) {
   SmallVector<Range> loopBounds(domainRank);
-  Location loc = getLoc();
-  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  OpFoldResult zero = b.getIndexAttr(0);
+  OpFoldResult one = b.getIndexAttr(1);
 
   for (auto dim : llvm::seq<int64_t>(0, domainRank)) {
     loopBounds[dim].offset = zero;
@@ -1631,26 +1629,27 @@
         continue;
       }
       dimsFound[pos] = true;
-      loopBounds[pos].size = getDimValue(builder, loc, val, idx);
+      loopBounds[pos].size = getDimValue(b, loc, val, idx);
     }
   };
 
-  // Sizes can be found from Q, K, V alone.
-  fillSizes(getQuery(), getQueryMap());
-  fillSizes(getKey(), getKeyMap());
-  fillSizes(getValue(), getValueMap());
+  for (auto [val, indexingMap] : llvm::zip_equal(values, indexingMaps)) {
+    fillSizes(val, indexingMap);
+  }
 
   return loopBounds;
 }
 
-SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
+static SmallVector<utils::IteratorType>
+getAttentionIteratorTypes(int64_t domainRank,
+                          ArrayRef<AffineMap> indexingMaps) {
   FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(getIndexingMapsArray());
+      AttentionOpDetail::get(indexingMaps);
   assert(succeeded(maybeOpInfo) && "Failed to infer attention op details");
   AttentionOpDetail opInfo = maybeOpInfo.value();
 
   // All dimensions other than k1 and k2 are parallel.
-  SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
+  SmallVector<utils::IteratorType> iteratorTypes(domainRank,
                                                  utils::IteratorType::parallel);
 
   for (auto dim :
@@ -1661,6 +1660,42 @@
   return iteratorTypes;
 }
 
+static SmallVector<Range> getPermutedSlice(AffineMap permutation,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes) {
+  auto one = IntegerAttr::get(IndexType::get(permutation.getContext()), 1);
+  assert(permutation.isProjectedPermutation() &&
+         "Indexing map should be a projected permutation");
+  SmallVector<Range> output;
+  for (AffineExpr dimExpr : permutation.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    Range dimRange;
+    dimRange.offset = offsets[dim];
+    dimRange.size = sizes[dim];
+    dimRange.stride = one;
+    output.push_back(dimRange);
+  }
+  return output;
+}
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &b) {
+  // Attention shape can be determined from Q, K, V alone.
+  SmallVector<Value> shapedValues = {getQuery(), getKey(), getValue()};
+  SmallVector<AffineMap> indexingMaps = {getQueryMap(), getKeyMap(),
+                                         getValueMap()};
+  return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
+                                     shapedValues, indexingMaps);
+}
+
+SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
+  return getAttentionIteratorTypes(getIterationDomainRank(),
+                                   getIndexingMapsArray());
+}
+
 FailureOr<TilingResult>
 AttentionOp::getTiledImplementation(OpBuilder &builder,
                                     ArrayRef<OpFoldResult> offsets,
@@ -1669,59 +1704,36 @@
   assert(sizes.size() == getIterationDomainRank());
 
   Location loc = getLoc();
-  auto one = builder.getIndexAttr(1);
 
-  auto tileValue = [&](Value val, AffineMap indexingMap)
-      -> std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
-                    SmallVector<OpFoldResult>> {
-    assert(indexingMap.isProjectedPermutation() &&
-           "Indexing map should be a projected permutation");
-    SmallVector<OpFoldResult> outputOffsets;
-    SmallVector<OpFoldResult> outputSizes;
-    SmallVector<OpFoldResult> outputStrides(indexingMap.getNumResults(), one);
-    for (AffineExpr dimExpr : indexingMap.getResults()) {
-      int dim = cast<AffineDimExpr>(dimExpr).getPosition();
-      outputOffsets.push_back(offsets[dim]);
-      outputSizes.push_back(sizes[dim]);
-    }
-    return {outputOffsets, outputSizes, outputStrides};
-  };
-
-  auto [queryOffsets, querySizes, queryStrides] =
-      tileValue(getQuery(), getQueryMap());
-  auto [keyOffsets, keySizes, keyStrides] = tileValue(getKey(), getKeyMap());
-  auto [valueOffsets, valueSizes, valueStrides] =
-      tileValue(getValue(), getValueMap());
-  auto [outputOffsets, outputSizes, outputStrides] =
-      tileValue(getOutput(), getOutputMap());
+  SmallVector<Range> querySlice =
+      getPermutedSlice(getQueryMap(), offsets, sizes);
+  SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
+  SmallVector<Range> valueSlice =
+      getPermutedSlice(getValueMap(), offsets, sizes);
+  SmallVector<Range> outputSlice =
+      getPermutedSlice(getOutputMap(), offsets, sizes);
 
   Value scale = getScale();
 
   SmallVector<Value> tiledOperands;
-  tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), queryOffsets,
-                                      querySizes, queryStrides));
-  tiledOperands.emplace_back(
-      getSlice(builder, loc, getKey(), keyOffsets, keySizes, keyStrides));
-  tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueOffsets,
-                                      valueSizes, valueStrides));
+  tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
   tiledOperands.emplace_back(scale);
-  tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputOffsets,
-                                      outputSizes, outputStrides));
+  tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));
 
   std::optional<Value> max = getMax();
   if (max) {
-    auto [maxOffsets, maxSizes, maxStrides] =
-        tileValue(max.value(), *getMaxMap());
-    tiledOperands.emplace_back(
-        getSlice(builder, loc, max.value(), maxOffsets, maxSizes, maxStrides));
+    SmallVector<Range> maxSlice =
+        getPermutedSlice(*getMaxMap(), offsets, sizes);
+    tiledOperands.emplace_back(getSlice(builder, loc, max.value(), maxSlice));
   }
 
   std::optional<Value> sum = getMax();
   if (sum) {
-    auto [sumOffsets, sumSizes, sumStrides] =
-        tileValue(sum.value(), *getSumMap());
-    tiledOperands.emplace_back(
-        getSlice(builder, loc, sum.value(), sumOffsets, sumSizes, sumStrides));
+    SmallVector<Range> sumSlice =
+        getPermutedSlice(*getSumMap(), offsets, sizes);
+    tiledOperands.emplace_back(getSlice(builder, loc, sum.value(), sumSlice));
   }
 
   SmallVector<Type> resultTypes;
@@ -1771,4 +1783,93 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// OnlineAttentionOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<Range> OnlineAttentionOp::getIterationDomain(OpBuilder &b) {
+  // Attention shape can be determined from Q, K, V alone.
+  SmallVector<Value> shapedValues = {getQuery(), getKey(), getValue()};
+  SmallVector<AffineMap> indexingMaps = {getQueryMap(), getKeyMap(),
+                                         getValueMap()};
+  return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
+                                     shapedValues, indexingMaps);
+}
+
+SmallVector<utils::IteratorType> OnlineAttentionOp::getLoopIteratorTypes() {
+  return getAttentionIteratorTypes(getIterationDomainRank(),
+                                   getIndexingMapsArray());
+}
+
+FailureOr<TilingResult>
+OnlineAttentionOp::getTiledImplementation(OpBuilder &builder,
+                                          ArrayRef<OpFoldResult> offsets,
+                                          ArrayRef<OpFoldResult> sizes) {
+  assert(offsets.size() == getIterationDomainRank());
+  assert(sizes.size() == getIterationDomainRank());
+
+  Location loc = getLoc();
+
+  SmallVector<Range> querySlice =
+      getPermutedSlice(getQueryMap(), offsets, sizes);
+  SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
+  SmallVector<Range> valueSlice =
+      getPermutedSlice(getValueMap(), offsets, sizes);
+  SmallVector<Range> outputSlice =
+      getPermutedSlice(getOutputMap(), offsets, sizes);
+  SmallVector<Range> maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes);
+  SmallVector<Range> sumSlice = getPermutedSlice(getSumMap(), offsets, sizes);
+
+  Value scale = getScale();
+
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
+  tiledOperands.emplace_back(scale);
+  tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getMax(), maxSlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getSum(), sumSlice));
+
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(tiledOperands[4].getType());
+  resultTypes.push_back(tiledOperands[5].getType());
+  resultTypes.push_back(tiledOperands[6].getType());
+
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+LogicalResult OnlineAttentionOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  resultOffsets.clear();
+  resultSizes.clear();
+
+  AffineMap resultIndexingMap;
+  switch (resultNumber) {
+  case 0:
+    resultIndexingMap = getOutputMap();
+    break;
+  case 1:
+    resultIndexingMap = getMaxMap();
+    break;
+  case 2:
+    resultIndexingMap = getSumMap();
+    break;
+  default:
+    return failure();
+  }
+
+  for (AffineExpr dimExpr : resultIndexingMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    resultOffsets.push_back(offsets[dim]);
+    resultSizes.push_back(sizes[dim]);
+  }
+  return success();
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index 1917631..fe21d60 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@
             "conv2d_to_winograd.mlir",
             "convert_to_loops.mlir",
             "decompose_attention.mlir",
+            "decompose_online_attention.mlir",
             "decompose_winograd.mlir",
             "distribution.mlir",
             "pad_contraction_to_block_size.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 92abdf7..7ef5fd7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
     "conv2d_to_winograd.mlir"
     "convert_to_loops.mlir"
     "decompose_attention.mlir"
+    "decompose_online_attention.mlir"
     "decompose_winograd.mlir"
     "distribution.mlir"
     "pad_contraction_to_block_size.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
new file mode 100644
index 0000000..945cff8
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -0,0 +1,64 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-attention),canonicalize,cse)" %s | FileCheck %s
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @attention_f16(%query: tensor<192x1024x64xf16>,
+                         %key: tensor<192x1024x64xf16>,
+                         %value: tensor<192x1024x64xf16>,
+                         %output: tensor<192x1024x64xf32>,
+                         %max: tensor<192x1024xf32>,
+                         %sum: tensor<192x1024xf32>)
+                         -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+  %scale = arith.constant 1.0 : f16
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
+        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+
+// We just want to check if we are using the correct algorithm.
+// CHECK-LABEL: @attention_f16
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK:   arith.maximumf
+// CHECK:   linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK:   arith.subf
+// CHECK:   math.exp2
+// CHECK:   linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newSum = normSum + rowMax(P)
+// CHECK: linalg.generic
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
+// newAcc = norm * oldAcc
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   linalg.yield
+// newAcc = P @ V + newAcc
+// CHECK: linalg.generic
+// CHECK:   arith.mulf
+// CHECK:   arith.addf
+// CHECK:   linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index f92bdb8..a974f3a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -1536,3 +1536,66 @@
 // CHECK:        }
 // CHECK:        return
 // CHECK:      }
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
+  %scale = arith.constant 1.0 : f32
+
+  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
+  %row_red_empty = tensor.empty() : tensor<192x1024xf32>
+
+  %sum_ident = arith.constant 0.000000e+00 : f32
+  %max_ident = arith.constant -3.40282347E+38 : f32
+
+  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+
+  %out:3 = iree_linalg_ext.online_attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32)
+        outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+  return %out#0 : tensor<192x1024x64xf32>
+}
+
+// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-LABEL: @online_attention
+// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
+// CHECK-DAG:   %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
+// CHECK-DAG:   %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]])
+// CHECK-DAG:   %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]])
+// CHECK-DAG:  %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32>
+// CHECK-DAG:  %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32>
+// CHECK-DAG:  %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32>
+// CHECK-DAG:  %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32>
+// CHECK-DAG:  %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG:  %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG: iree_linalg_ext.online_attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP4]]]}
+// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}} : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32)
+// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
+// CHECK: scf.forall.in_parallel
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
index 7feee69..ac5f5c0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
@@ -27,7 +27,7 @@
 
 void AttentionOpDetail::inferFromIndexingMaps(
     ArrayRef<AffineMap> indexingMaps) {
-  assert(indexingMaps.size() == 4);
+  assert(indexingMaps.size() >= 4);
   AffineMap qMap = indexingMaps[0];
   AffineMap kMap = indexingMaps[1];
   AffineMap vMap = indexingMaps[2];
@@ -82,7 +82,23 @@
 
   AttentionOpDetail opInfo;
   opInfo.inferFromIndexingMaps(indexingMaps);
+  opInfo.maps = SmallVector<AffineMap>(indexingMaps);
   return opInfo;
 }
 
+AffineMap AttentionOpDetail::getSMap() const {
+  // We need to create an indexing map for the intermediate result of first
+  // matmul. There could be other options, but we choose to create a standard
+  // indexing map:
+  //   SMap = (batch, m, k1, k2, n) -> (batch, m, k2)
+  AffineMap sMap = AffineMap::get(/*dimCount=*/getDomainRank(),
+                                  /*symbolCount=*/0, getContext());
+  for (auto dim :
+       llvm::concat<const int64_t>(getBatchDims(), getMDims(), getK2Dims())) {
+    AffineExpr dimExpr = getAffineDimExpr(dim, getContext());
+    sMap = sMap.insertResult(dimExpr, sMap.getNumResults());
+  }
+  return sMap;
+}
+
 }; // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
index cfba4ff..e66bc86 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
@@ -42,20 +42,29 @@
 public:
   static FailureOr<AttentionOpDetail> get(ArrayRef<AffineMap> indexingMaps);
 
+  int64_t getDomainRank() const { return maps[0].getNumDims(); }
   ArrayRef<int64_t> getBatchDims() const { return batch; }
   ArrayRef<int64_t> getMDims() const { return m; }
   ArrayRef<int64_t> getK1Dims() const { return k1; }
   ArrayRef<int64_t> getK2Dims() const { return k2; }
   ArrayRef<int64_t> getNDims() const { return n; }
 
+  ArrayRef<AffineMap> getIndexingMaps() const { return maps; }
+
+  AffineMap getSMap() const;
+
 private:
   void inferFromIndexingMaps(ArrayRef<AffineMap> indexingMaps);
 
+  MLIRContext *getContext() const { return maps[0].getContext(); }
+
   SmallVector<int64_t> batch;
   SmallVector<int64_t> m;
   SmallVector<int64_t> k1;
   SmallVector<int64_t> k2;
   SmallVector<int64_t> n;
+
+  SmallVector<AffineMap> maps;
 };
 
 }; // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index d231322..e6c3548 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -44,6 +44,13 @@
       [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); });
 }
 
+Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef<Range> slice) {
+  return getSlice(b, loc, src,
+                  llvm::map_to_vector(slice, [](Range x) { return x.offset; }),
+                  llvm::map_to_vector(slice, [](Range x) { return x.size; }),
+                  llvm::map_to_vector(slice, [](Range x) { return x.stride; }));
+}
+
 Value getSlice(OpBuilder &b, Location loc, Value src,
                ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                ArrayRef<OpFoldResult> strides) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index 9b40a54..eec973f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -13,6 +13,10 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 
+namespace mlir {
+struct Range;
+}; // namespace mlir
+
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
 /// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
@@ -26,6 +30,7 @@
 
 /// Returns a `memref.subview` or a `tensor.extract_slice` based on the type of
 /// `src`.
+Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef<Range> slice);
 Value getSlice(OpBuilder &b, Location loc, Value src,
                ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                ArrayRef<OpFoldResult> strides);