Revert "[LinalgExt] Add online_attention op" (#17658)

Reverts iree-org/iree#17536

This caused `sdxl-scheduled-unet-3-tank` to hit timeouts when compiling
for cpu:
https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index cc07354..9ae470d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -890,7 +890,7 @@
 /// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
 /// reduction loops.
 static void splitParallelAndReductionTiles(
-    Operation *op, SmallVectorImpl<int64_t> &parallelSizes,
+    linalg::LinalgOp op, SmallVectorImpl<int64_t> &parallelSizes,
     SmallVectorImpl<int64_t> &reductionSizes,
     SmallVectorImpl<bool> *parallelScalableFlags = nullptr,
     SmallVectorImpl<bool> *reductionScalableFlags = nullptr) {
@@ -900,9 +900,8 @@
     reductionScalableFlags->assign(parallelScalableFlags->begin(),
                                    parallelScalableFlags->end());
   }
-  TilingInterface tilingOp = cast<TilingInterface>(op);
   for (auto [index, iteratorType] :
-       llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
+       llvm::enumerate(op.getIteratorTypesArray())) {
     if (iteratorType == utils::IteratorType::parallel) {
       reductionSizes[index] = 0;
       if (reductionScalableFlags)
@@ -1122,9 +1121,9 @@
   SmallVector<int64_t> parallelTileSizes = vecTileSizes;
   SmallVector<int64_t> reductionTileSizes;
   SmallVector<bool> reductionScalableFlags;
-  splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes,
-                                 &parallelScalableFlags,
-                                 &reductionScalableFlags);
+  splitParallelAndReductionTiles(
+      cast<linalg::LinalgOp>(op.getOperation()), parallelTileSizes,
+      reductionTileSizes, &parallelScalableFlags, &reductionScalableFlags);
 
   if (vecPreProcStrategy == VectorPreProcStrategy::None) {
     setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
@@ -1752,13 +1751,14 @@
 
   // Batch, M and N (parallel dimensions) are distributed on workgroups.
   DistributionHeuristicConfig config;
-  SmallVector<int64_t> distTileSizes =
-      getDefaultDistributedLevelTileSizes(attnOp, config);
+  SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
+      attnOp, DistributionHeuristicConfig{});
 
   // Batch, M and N (parallel dimensions) are distributed on workgroups.
   SmallVector<int64_t> vecTileSizes(attnOp.getIterationDomainRank(), 1);
-  // Mark k1 reduction dimensions not to distribute.
-  for (int i : opInfo.getK1Dims()) {
+  // Mark reduction dimensions not to distribute.
+  for (int64_t i :
+       llvm::concat<const int64_t>(opInfo.getK1Dims(), opInfo.getK2Dims())) {
     vecTileSizes[i] = 0;
   }
   int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
@@ -1773,17 +1773,18 @@
         /*numElem=*/tileSize, vectorSize, vectorSize);
   }
 
-  SmallVector<int64_t> parallelTileSizes = vecTileSizes;
-  SmallVector<int64_t> reductionTileSizes;
-  splitParallelAndReductionTiles(attnOp, parallelTileSizes, reductionTileSizes);
+  // TODO (17467): Due to a bug in TileAndDecomposeAttention, N dimension
+  // cannot be tiled. Remove this once fixed.
+  for (int64_t i : opInfo.getNDims()) {
+    distTileSizes[i] = 0;
+    vecTileSizes[i] = 0;
+  }
 
-  LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (parallel): "
-                       << parallelTileSizes << "\n");
-  LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (reduction): "
-                       << reductionTileSizes << "\n");
+  TileSizesListType tileSizes = {distTileSizes, vecTileSizes};
 
-  TileSizesListType tileSizes = {distTileSizes, parallelTileSizes,
-                                 reductionTileSizes};
+  // TODO: (Groverkss): Tile K2 here using reduction tiling interface once we
+  // have it. TileAndDecomposeAttention pass only tiles K2. I think it should
+  // be possible to tile K1 also, but need to explore it more.
 
   return setOpConfigAndEntryPointFnTranslation(
       entryPointFn, attnOp, tileSizes,
@@ -1842,9 +1843,6 @@
   tileSizes.push_back(distTileSizes);
   SmallVector<int64_t> vecTileSizes(iterationRank, 1);
   tileSizes.push_back(vecTileSizes);
-  // Dummy tiling config for reduction level.
-  SmallVector<int64_t> reductionTileSizes(iterationRank, 0);
-  tileSizes.push_back(reductionTileSizes);
   return setOpConfigAndEntryPointFnTranslation(
       entryPointFn, winogradOp, tileSizes,
       DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 6a4363f..31fcead 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -617,13 +617,10 @@
       createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
   // TODO: Remove the pass once we have PartialReductionOpInterface implemented
   // for AttentionOp.
-  funcPassManager.addPass(
-      IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
-  funcPassManager.addPass(
-      createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
+  funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
+  funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
   funcPassManager.addPass(
       IREE::LinalgExt::createDecomposeWinogradTransformPass());
-  funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
 
   {
     GenericVectorizationPassOptions options;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index 6777e97..4df5605 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -1531,7 +1531,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @winograd_output_transform()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -1556,7 +1556,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @winograd_input_transform()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -1581,7 +1581,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1], [0, 0]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @winograd_filter_transform()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -1613,7 +1613,7 @@
     return
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 64], [20, 32, 0, 0, 32], [0, 0, 0, 32, 0]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 0], [20, 32, 0, 0, 0]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
 //      CHECK: func.func @attention()
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
index a771ee0..8675c43 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -29,7 +29,6 @@
         "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
         "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
-        "@llvm-project//mlir:LinalgOpsTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:PDLDialectTdFiles",
         "@llvm-project//mlir:SideEffectInterfacesTdFiles",
@@ -160,9 +159,7 @@
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "LinalgExtOps.td",
-    deps = [
-        ":td_files",
-    ],
+    deps = [":td_files"],
 )
 
 iree_gentbl_cc_library(
@@ -215,7 +212,5 @@
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "LinalgExtOps.td",
-    deps = [
-        ":td_files",
-    ],
+    deps = [":td_files"],
 )
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 34ce4f4..c5c42ec 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -8,7 +8,6 @@
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
-#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -1316,9 +1315,6 @@
     for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
       AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
       int64_t pos = dim.getPosition();
-      if (ShapedType::isDynamic(valShape[i])) {
-        continue;
-      }
       if (!foundDims[pos]) {
         foundDims[pos] = true;
         shape[pos] = valShape[i];
@@ -1431,79 +1427,6 @@
   return results;
 }
 
-//===----------------------------------------------------------------------===//
-// OnlineAttentionOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult OnlineAttentionOp::verify() {
-  OnlineAttentionOp attnOp = *this;
-
-  SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
-
-  // Check if indexing maps can represent attention.
-  FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(indexingMaps);
-
-  // Check shape compatibility based on indexing maps.
-  SmallVector<int64_t> shape(getIterationDomainRank());
-  SmallVector<bool> foundDims(getIterationDomainRank(), false);
-  auto checkShape = [&shape, &foundDims,
-                     &attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
-                              AffineMap indexingMap) -> LogicalResult {
-    if (indexingMap.getNumResults() != valShape.size()) {
-      return attnOp->emitError("Rank Mismatch for ")
-             << operandName << ". Expected: " << indexingMap.getNumResults()
-             << " Got: " << valShape.size();
-    }
-    for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
-      AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
-      int64_t pos = dim.getPosition();
-      if (ShapedType::isDynamic(valShape[i])) {
-        continue;
-      }
-      if (!foundDims[pos]) {
-        foundDims[pos] = true;
-        shape[pos] = valShape[i];
-      }
-      if (shape[pos] != valShape[i]) {
-        return attnOp->emitError("Shape Mismatch for ")
-               << operandName << ". Expected: " << shape[pos]
-               << " Got: " << valShape[i];
-      }
-    }
-    return success();
-  };
-
-  if (failed(checkShape("Query", getQuery().getType().getShape(),
-                        getQueryMap())) ||
-      failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
-      failed(checkShape("Value", getValue().getType().getShape(),
-                        getValueMap())) ||
-      failed(checkShape("Output", getOutput().getType().getShape(),
-                        getOutputMap())) ||
-      failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) ||
-      failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) {
-    return failure();
-  }
-
-  return success();
-}
-
-MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
-  return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
-}
-
-LogicalResult OnlineAttentionOp::reifyResultShapes(
-    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  return cast<LinalgExtOp>(getOperation())
-      .reifyResultShapes(b, reifiedReturnShapes);
-}
-
-SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
-  return SmallVector<AffineMap>(
-      getIndexingMaps().getAsValueRange<AffineMapAttr>());
-}
-
 #define DEFINE_OP_GET_EFFECTS(OP_NAME)                                         \
   void OP_NAME::getEffects(                                                    \
       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
@@ -1523,7 +1446,6 @@
 DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
 DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
 DEFINE_OP_GET_EFFECTS(AttentionOp)
-DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
index 3d52ae6..97caaab 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -7,7 +7,6 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
 
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/IR/Attributes.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 0eebd2e..bf9694d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,7 +9,6 @@
 
 include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
 include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
-include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -679,96 +678,6 @@
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// OnlineAttention
-//===----------------------------------------------------------------------===//
-
-def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     DestinationStyleOpInterface, LinalgExtInterface,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
-     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
-     DeclareOpInterfaceMethods<TilingInterface,
-      ["getIterationDomain",
-       "getLoopIteratorTypes",
-       "getResultTilePosition",
-       "getTiledImplementation"]>]> {
-  let summary = "Online Attention operator";
-  let description = [{
-    Traditional scaled dot product attention computes:
-
-    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
-
-    Online Attention on the other hand, uses an online normalizer instead of
-    softmax:
-
-    online_attention(Q, K, V, scale, running_max, running_sum)
-      = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V
-
-    The advantage of this online_normalizer is that it can be tiled along
-    it's reduction dimension, making the online_attention operator:
-      - Tilable along softmax reduction dimension
-      - Associative along softmax reduction dimension
-      - Commutative along softmax associative dimension
-
-    Note: The results of online_attention need to be combined after computing
-    it over the entire softmax reduction dimension by:
-      x, _, sum : results
-      x = (1 / sum) * x
-  }];
-
-  let arguments = (ins AnyShaped:$query,
-                       AnyShaped:$key,
-                       AnyShaped:$value,
-                       AnyFloat:$scale,
-                       AnyShaped:$output,
-                       AnyShaped:$max,
-                       AnyShaped:$sum,
-                       AffineMapArrayAttr:$indexing_maps
-  );
-
-  let results = (outs Variadic<AnyRankedTensor>:$results);
-  let hasVerifier = 1;
-  let hasCustomAssemblyFormat = 1;
-  let assemblyFormat = [{
-    attr-dict
-    `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
-    `outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
-    (`->` type($results)^)?
-  }];
-
-  let extraClassDeclaration = [{
-    // Method to implement for specifying output range for
-    // DestinationStyleOpInterface
-    MutableOperandRange getDpsInitsMutable();
-
-    SmallVector<AffineMap> getIndexingMapsArray();
-
-    AffineMap getQueryMap() {
-      return getIndexingMapsArray()[0];
-    }
-    AffineMap getKeyMap() {
-      return getIndexingMapsArray()[1];
-    }
-    AffineMap getValueMap() {
-      return getIndexingMapsArray()[2];
-    }
-    AffineMap getOutputMap() {
-      return getIndexingMapsArray()[3];
-    }
-    AffineMap getMaxMap() {
-      return getIndexingMapsArray()[4];
-    }
-    AffineMap getSumMap() {
-      return getIndexingMapsArray()[5];
-    }
-
-    int64_t getIterationDomainRank() {
-      return getQueryMap().getNumDims();
-    }
-  }];
-}
-
 } // OpGroupNonStructuredOps
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
index 923b30a..00bb383 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
@@ -51,17 +51,5 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilenceableFailure LinalgExt::ConvertToOnlineAttention::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  SmallVector<Operation *> ops;
-  LinalgExt::convertToOnlineAttention(attentionOp, ops, rewriter);
-  for (Operation *op : ops) {
-    results.push_back(op);
-  }
-  return DiagnosedSilenceableFailure::success();
-}
-
 #define GET_OP_CLASSES
 #include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
index 84e588a..c3a6310 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
@@ -86,36 +86,4 @@
   }];
 }
 
-def ConvertToOnlineAttention : Op<Transform_Dialect, "iree.convert_to_online_attention",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Target iree_linalg_ext.attention ops and decompose them.
-    This transform consumes the target handle and produces a result handle.
-  }];
-
-  let arguments = (
-      ins TransformHandleTypeInterface:$target
-  );
-  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
-
-  let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
-  let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-
-  let assemblyFormat = [{
-    $target attr-dict `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
 #endif // IREE_DIALECT_LINALGEXT_TRANSFORMOPS
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
deleted file mode 100644
index b4370e7..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-
-namespace mlir::iree_compiler::IREE::LinalgExt {
-
-static Value scaleValueInPlace(OpBuilder &builder, Location loc,
-                               AffineMap inputMap, AffineMap scaleMap,
-                               Value value, Value scale) {
-  SmallVector<AffineMap> compressedMaps =
-      compressUnusedDims(SmallVector<AffineMap>{inputMap, scaleMap});
-  inputMap = compressedMaps[0];
-  scaleMap = compressedMaps[1];
-
-  SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
-                                                 utils::IteratorType::parallel);
-
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, value.getType(), scale, value,
-      SmallVector<AffineMap>{scaleMap, inputMap}, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        // Convert scale to the same datatype as input.
-        Value scale = convertScalarToDtype(b, loc, args[0], args[1].getType(),
-                                           /*isUnsignedCast=*/false);
-        Value result = b.create<arith::MulFOp>(loc, scale, args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return genericOp.getResult(0);
-}
-
-template <typename T>
-static Value reduce(OpBuilder &builder, Location loc, AffineMap inputMap,
-                    AffineMap outputMap, Value input, Value output) {
-  SmallVector<AffineMap> compressedMaps =
-      compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
-  inputMap = compressedMaps[0];
-  outputMap = compressedMaps[1];
-
-  // Dims not present in outputMap are reductionDims.
-  SmallVector<utils::IteratorType> iteratorTypes(
-      inputMap.getNumDims(), utils::IteratorType::reduction);
-  for (AffineExpr dim : outputMap.getResults()) {
-    int pos = cast<AffineDimExpr>(dim).getPosition();
-    iteratorTypes[pos] = utils::IteratorType::parallel;
-  }
-
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, output.getType(), input, output,
-      SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        // Convert input to the same datatype as acc.
-        Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
-                                        /*isUnsignedCast=*/false);
-        Value result = b.create<T>(loc, in, args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-
-  return genericOp.getResult(0);
-}
-
-static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap,
-                           AffineMap rhsMap, AffineMap accMap, Value lhs,
-                           Value rhs, Value acc) {
-
-  SmallVector<AffineMap> compressedMaps =
-      compressUnusedDims(SmallVector<AffineMap>{lhsMap, rhsMap, accMap});
-  lhsMap = compressedMaps[0];
-  rhsMap = compressedMaps[1];
-  accMap = compressedMaps[2];
-
-  // Dims not present in accMap are reduction dims.
-  SmallVector<utils::IteratorType> iteratorTypes(
-      accMap.getNumDims(), utils::IteratorType::reduction);
-  for (AffineExpr dim : accMap.getResults()) {
-    int pos = cast<AffineDimExpr>(dim).getPosition();
-    iteratorTypes[pos] = utils::IteratorType::parallel;
-  }
-
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, acc.getType(), SmallVector<Value>{lhs, rhs}, acc,
-      SmallVector<AffineMap>{lhsMap, rhsMap, accMap}, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        // Cast inputs to match output datatype.
-        Value lhs = convertScalarToDtype(b, loc, args[0], args[2].getType(),
-                                         /*isUnsignedCast=*/false);
-        Value rhs = convertScalarToDtype(b, loc, args[1], args[2].getType(),
-                                         /*isUnsignedCast=*/false);
-        Value mul = b.create<arith::MulFOp>(loc, lhs, rhs);
-        Value add = b.create<arith::AddFOp>(loc, mul, args[2]);
-        b.create<linalg::YieldOp>(loc, add);
-      });
-
-  return genericOp.getResult(0);
-}
-
-// Compute output = exp2(output - input)
-static Value computeSubAndExp2(OpBuilder &builder, Location loc,
-                               AffineMap inputMap, AffineMap outputMap,
-                               Value input, Value output) {
-  SmallVector<AffineMap> compressedMaps =
-      compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
-  inputMap = compressedMaps[0];
-  outputMap = compressedMaps[1];
-
-  SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
-                                                 utils::IteratorType::parallel);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, output.getType(), input, output,
-      SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        // Convert input to the same datatype as output.
-        Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
-                                        /*isUnsignedCast=*/false);
-        Value diff = b.create<arith::SubFOp>(loc, args[1], in);
-        Value weight = b.create<math::Exp2Op>(loc, diff);
-        b.create<linalg::YieldOp>(loc, weight);
-      });
-  return genericOp.getResult(0);
-}
-
-FailureOr<SmallVector<Value>>
-OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
-  Location loc = getLoc();
-  Value query = getQuery();
-  Value key = getKey();
-  Value value = getValue();
-  Value oldAcc = getOutput();
-  Value oldMax = getMax();
-  Value oldSum = getSum();
-  Type elementType = getQuery().getType().getElementType();
-
-  FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(getIndexingMapsArray());
-  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
-  AttentionOpDetail opInfo = maybeOpInfo.value();
-
-  SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
-      getIterationDomain(b), [](Range x) { return x.size; });
-
-  // Since we use exp2 for attention instead of the original exp, we have to
-  // multiply the scale by log2(e). We use exp2 instead of exp as most platforms
-  // have better support for exp2 (we verified that we gain some speedup on
-  // some GPUs).
-  Value scale = getScale();
-  Value log2e =
-      b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, M_LOG2E));
-  scale = b.create<arith::MulFOp>(loc, scale, log2e);
-
-  // In the original algorithm, the scaling is done after the softmax:
-  //        softmax(Q @ K.T * scale) @ V
-  //
-  // But, it is mathematically equivalent to do it on Q first and then multiply
-  // it by K.T. This just allows us to do the scaling once, instead of each
-  // iteration of the loop.
-  AffineMap qMap = getQueryMap();
-  AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
-                                      /*symbolCount=*/0, getContext());
-  query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
-
-  // ---- Matmul 1 ----
-
-  // Get sizes for S.
-  AffineMap sMap = opInfo.getSMap();
-  SmallVector<OpFoldResult> sSizes;
-  for (AffineExpr dimExpr : sMap.getResults()) {
-    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    sSizes.push_back(sizes[dim]);
-  }
-
-  // S = Q @ K
-  // SMap = QMap @ KMap
-  Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
-  Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
-  Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
-  s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
-
-  // TODO: This decomposition should be in a seperate op called
-  // "online softmax".
-  // ---- Online Softmax ----
-
-  // newMax = max(oldMax, rowMax(S))
-  AffineMap maxMap = getMaxMap();
-  Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);
-
-  // P = exp2(S - newMax)
-  // PMap = SMap
-  AffineMap pMap = sMap;
-  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
-
-  // norm = exp2(oldMax - newMax)
-  // normMap = maxMap
-  AffineMap normMap = getMaxMap();
-  Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax);
-
-  // normSum = norm * oldSum
-  AffineMap sumMap = getSumMap();
-  Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
-
-  // newSum = normSum + rowMax(P)
-  Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
-
-  // newAcc = norm * oldAcc
-  AffineMap accMap = getOutputMap();
-  Value newAcc = scaleValueInPlace(b, loc, accMap, normMap, oldAcc, norm);
-
-  // ---- Matmul 2 ----
-
-  // newAcc = P @ V + newAcc
-  newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
-
-  return SmallVector<Value>{newAcc, newMax, newSum};
-}
-
-} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
index 1a20a8f..1d08da9 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
@@ -30,7 +30,6 @@
 iree_compiler_cc_library(
     name = "Transforms",
     srcs = [
-        "AggregatedOpInterfaceImpl.cpp",
         "ConvertConv2DToWinograd.cpp",
         "ConvertToLoops.cpp",
         "DecomposeAttention.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 19c7522..668d28a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@
     "Passes.h"
     "Passes.h.inc"
   SRCS
-    "AggregatedOpInterfaceImpl.cpp"
     "ConvertConv2DToWinograd.cpp"
     "ConvertToLoops.cpp"
     "DecomposeAttention.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
index 2cd851d..c70000f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
@@ -366,16 +366,6 @@
     SmallVector<Operation *> ops;
     decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize);
   });
-  getOperation().walk([&](OnlineAttentionOp onlineAtt) {
-    rewriter.setInsertionPoint(onlineAtt);
-    FailureOr<SmallVector<Value>> results =
-        onlineAtt.decomposeOperation(rewriter);
-    if (failed(results)) {
-      onlineAtt->emitOpError("Could not decompose online attention");
-      return signalPassFailure();
-    }
-    rewriter.replaceOp(onlineAtt, results.value());
-  });
 }
 
 std::unique_ptr<Pass> createDecomposeAttentionPass() {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 165430a..43d44a3 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -56,18 +56,12 @@
                              RewriterBase &rewriter,
                              std::optional<uint64_t> tileSize = std::nullopt);
 
-void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
-                              SmallVectorImpl<Operation *> &ops,
-                              RewriterBase &rewriter);
-
 // Creates a pass to tile the attention op along the reduction dim.
 std::unique_ptr<Pass> createTileAttentionPass();
 
 // Creates a pass to convert the attention op into a sequence of linalg ops.
 std::unique_ptr<Pass> createDecomposeAttentionPass();
 
-std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass();
-
 //===---------------------------------------------------------------------===//
 // Codegen Strategy passes that are moved into IREE.
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index ee801b6..77bb105 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -98,12 +98,4 @@
   ];
 }
 
-def ConvertAttentionToOnlineAttention :
-    InterfacePass<"iree-linalg-ext-convert-attention-to-online-attention",
-    "mlir::FunctionOpInterface"> {
-  let summary = "";
-  let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
-                    "createConvertAttentionToOnlineAttentionPass()";
-}
-
 #endif  // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index df0d87e..4c862b5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -8,7 +8,6 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -159,17 +158,6 @@
   void runOnOperation() override;
 };
 
-struct ConvertAttentionToOnlineAttentionPass final
-    : ConvertAttentionToOnlineAttentionBase<
-          ConvertAttentionToOnlineAttentionPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
-                linalg::LinalgDialect, tensor::TensorDialect>();
-  }
-  void runOnOperation() override;
-};
-
 } // namespace
 
 /// Tile iree_linalg_ext.attention.
@@ -304,103 +292,6 @@
   return tiledAttentionOp;
 }
 
-void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
-                              SmallVectorImpl<Operation *> &ops,
-                              RewriterBase &rewriter) {
-  rewriter.setInsertionPoint(attnOp);
-
-  Location loc = attnOp.getLoc();
-  MLIRContext *ctx = attnOp.getContext();
-
-  FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(attnOp.getIndexingMapsArray());
-  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
-  AttentionOpDetail opInfo = maybeOpInfo.value();
-
-  // Create standard maps for max and sum: (batch, m)
-  int64_t rank = opInfo.getDomainRank();
-  AffineMap maxMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, ctx);
-  for (auto dim :
-       llvm::concat<const int64_t>(opInfo.getBatchDims(), opInfo.getMDims())) {
-    maxMap = maxMap.insertResult(rewriter.getAffineDimExpr(dim),
-                                 maxMap.getNumResults());
-  }
-  AffineMap sumMap = maxMap;
-
-  SmallVector<Range> sizes = attnOp.getIterationDomain(rewriter);
-
-  // Create fill for acc, max and sum.
-  // TODO: Acc should not need a fill. The attention op should get a filled
-  // input instead of an empty input.
-  Value zeroAcc = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getZeroAttr(attnOp.getOutputType().getElementType()));
-  Value accFill =
-      rewriter
-          .create<linalg::FillOp>(loc, ValueRange{zeroAcc}, attnOp.getOutput())
-          .result();
-
-  SmallVector<OpFoldResult> rowRedSize =
-      llvm::map_to_vector(sizes, [](Range x) { return x.size; });
-  rowRedSize = applyPermutationMap<OpFoldResult>(maxMap, rowRedSize);
-
-  Type f32Type = rewriter.getF32Type();
-  Value rowRedEmpty =
-      rewriter.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
-
-  Value maxInit =
-      arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, rewriter,
-                              loc, /*useOnlyFiniteValue=*/true);
-  Value sumInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type,
-                                          rewriter, loc);
-
-  Value maxFill =
-      rewriter.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
-          .getResult(0);
-  Value sumFill =
-      rewriter.create<linalg::FillOp>(loc, ValueRange{sumInit}, rowRedEmpty)
-          .getResult(0);
-
-  // Create online attention op.
-  SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
-  indexingMaps.push_back(maxMap);
-  indexingMaps.push_back(sumMap);
-  OnlineAttentionOp onlineAttn = rewriter.create<OnlineAttentionOp>(
-      loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
-      attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
-      accFill, maxFill, sumFill, rewriter.getAffineMapArrayAttr(indexingMaps));
-  onlineAttn->setDiscardableAttrs(attnOp->getDiscardableAttrDictionary());
-  ops.push_back(onlineAttn);
-
-  Value x = onlineAttn.getResult(0);
-  Value sum = onlineAttn.getResult(2);
-
-  // Merge the outputs of online attention:
-  //  x = (1 / sum) * x
-
-  // Compress the indexing maps.
-  SmallVector<AffineMap> compressedMaps =
-      compressUnusedDims(SmallVector<AffineMap>{sumMap, attnOp.getOutputMap()});
-
-  SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
-                                                 utils::IteratorType::parallel);
-
-  auto genericOp = rewriter.create<linalg::GenericOp>(
-      loc, x.getType(), sum, x, compressedMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value one = b.create<arith::ConstantOp>(
-            loc, b.getFloatAttr(args[0].getType(), 1.0));
-        Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
-        // Convert sum to the same datatype as x.
-        reciprocal = convertScalarToDtype(b, loc, reciprocal, args[1].getType(),
-                                          /*isUnsignedCast=*/false);
-        Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-
-  rewriter.replaceOp(attnOp, genericOp);
-}
-
 void TileAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
@@ -414,21 +305,8 @@
   });
 }
 
-void ConvertAttentionToOnlineAttentionPass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  IRRewriter rewriter(context);
-  getOperation().walk([&](AttentionOp attnOp) {
-    SmallVector<Operation *> ops;
-    convertToOnlineAttention(attnOp, ops, rewriter);
-  });
-}
-
 std::unique_ptr<Pass> createTileAttentionPass() {
   return std::make_unique<TileAttentionPass>();
 }
 
-std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass() {
-  return std::make_unique<ConvertAttentionToOnlineAttentionPass>();
-}
-
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index 2ecfd80..6b291c7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -4,9 +4,11 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -14,7 +16,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
@@ -1603,16 +1605,16 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Attention Helpers
+// AttentionOp
 //===----------------------------------------------------------------------===//
 
-static SmallVector<Range>
-getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank,
-                            ArrayRef<Value> values,
-                            ArrayRef<AffineMap> indexingMaps) {
+SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
+  int64_t domainRank = getIterationDomainRank();
+
   SmallVector<Range> loopBounds(domainRank);
-  OpFoldResult zero = b.getIndexAttr(0);
-  OpFoldResult one = b.getIndexAttr(1);
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
 
   for (auto dim : llvm::seq<int64_t>(0, domainRank)) {
     loopBounds[dim].offset = zero;
@@ -1629,27 +1631,26 @@
         continue;
       }
       dimsFound[pos] = true;
-      loopBounds[pos].size = getDimValue(b, loc, val, idx);
+      loopBounds[pos].size = getDimValue(builder, loc, val, idx);
     }
   };
 
-  for (auto [val, indexingMap] : llvm::zip_equal(values, indexingMaps)) {
-    fillSizes(val, indexingMap);
-  }
+  // Sizes can be found from Q, K, V alone.
+  fillSizes(getQuery(), getQueryMap());
+  fillSizes(getKey(), getKeyMap());
+  fillSizes(getValue(), getValueMap());
 
   return loopBounds;
 }
 
-static SmallVector<utils::IteratorType>
-getAttentionIteratorTypes(int64_t domainRank,
-                          ArrayRef<AffineMap> indexingMaps) {
+SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
   FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(indexingMaps);
+      AttentionOpDetail::get(getIndexingMapsArray());
   assert(succeeded(maybeOpInfo) && "Failed to infer attention op details");
   AttentionOpDetail opInfo = maybeOpInfo.value();
 
   // All dimensions other than k1 and k2 are parallel.
-  SmallVector<utils::IteratorType> iteratorTypes(domainRank,
+  SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
                                                  utils::IteratorType::parallel);
 
   for (auto dim :
@@ -1660,42 +1661,6 @@
   return iteratorTypes;
 }
 
-static SmallVector<Range> getPermutedSlice(AffineMap permutation,
-                                           ArrayRef<OpFoldResult> offsets,
-                                           ArrayRef<OpFoldResult> sizes) {
-  auto one = IntegerAttr::get(IndexType::get(permutation.getContext()), 1);
-  assert(permutation.isProjectedPermutation() &&
-         "Indexing map should be a projected permutation");
-  SmallVector<Range> output;
-  for (AffineExpr dimExpr : permutation.getResults()) {
-    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    Range dimRange;
-    dimRange.offset = offsets[dim];
-    dimRange.size = sizes[dim];
-    dimRange.stride = one;
-    output.push_back(dimRange);
-  }
-  return output;
-}
-
-//===----------------------------------------------------------------------===//
-// AttentionOp
-//===----------------------------------------------------------------------===//
-
-SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &b) {
-  // Attention shape can be determined from Q, K, V alone.
-  SmallVector<Value> shapedValues = {getQuery(), getKey(), getValue()};
-  SmallVector<AffineMap> indexingMaps = {getQueryMap(), getKeyMap(),
-                                         getValueMap()};
-  return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
-                                     shapedValues, indexingMaps);
-}
-
-SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
-  return getAttentionIteratorTypes(getIterationDomainRank(),
-                                   getIndexingMapsArray());
-}
-
 FailureOr<TilingResult>
 AttentionOp::getTiledImplementation(OpBuilder &builder,
                                     ArrayRef<OpFoldResult> offsets,
@@ -1704,36 +1669,59 @@
   assert(sizes.size() == getIterationDomainRank());
 
   Location loc = getLoc();
+  auto one = builder.getIndexAttr(1);
 
-  SmallVector<Range> querySlice =
-      getPermutedSlice(getQueryMap(), offsets, sizes);
-  SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
-  SmallVector<Range> valueSlice =
-      getPermutedSlice(getValueMap(), offsets, sizes);
-  SmallVector<Range> outputSlice =
-      getPermutedSlice(getOutputMap(), offsets, sizes);
+  auto tileValue = [&](Value val, AffineMap indexingMap)
+      -> std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+                    SmallVector<OpFoldResult>> {
+    assert(indexingMap.isProjectedPermutation() &&
+           "Indexing map should be a projected permutation");
+    SmallVector<OpFoldResult> outputOffsets;
+    SmallVector<OpFoldResult> outputSizes;
+    SmallVector<OpFoldResult> outputStrides(indexingMap.getNumResults(), one);
+    for (AffineExpr dimExpr : indexingMap.getResults()) {
+      int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+      outputOffsets.push_back(offsets[dim]);
+      outputSizes.push_back(sizes[dim]);
+    }
+    return {outputOffsets, outputSizes, outputStrides};
+  };
+
+  auto [queryOffsets, querySizes, queryStrides] =
+      tileValue(getQuery(), getQueryMap());
+  auto [keyOffsets, keySizes, keyStrides] = tileValue(getKey(), getKeyMap());
+  auto [valueOffsets, valueSizes, valueStrides] =
+      tileValue(getValue(), getValueMap());
+  auto [outputOffsets, outputSizes, outputStrides] =
+      tileValue(getOutput(), getOutputMap());
 
   Value scale = getScale();
 
   SmallVector<Value> tiledOperands;
-  tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice));
-  tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
-  tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), queryOffsets,
+                                      querySizes, queryStrides));
+  tiledOperands.emplace_back(
+      getSlice(builder, loc, getKey(), keyOffsets, keySizes, keyStrides));
+  tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueOffsets,
+                                      valueSizes, valueStrides));
   tiledOperands.emplace_back(scale);
-  tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));
+  tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputOffsets,
+                                      outputSizes, outputStrides));
 
   std::optional<Value> max = getMax();
   if (max) {
-    SmallVector<Range> maxSlice =
-        getPermutedSlice(*getMaxMap(), offsets, sizes);
-    tiledOperands.emplace_back(getSlice(builder, loc, max.value(), maxSlice));
+    auto [maxOffsets, maxSizes, maxStrides] =
+        tileValue(max.value(), *getMaxMap());
+    tiledOperands.emplace_back(
+        getSlice(builder, loc, max.value(), maxOffsets, maxSizes, maxStrides));
   }
 
   std::optional<Value> sum = getMax();
   if (sum) {
-    SmallVector<Range> sumSlice =
-        getPermutedSlice(*getSumMap(), offsets, sizes);
-    tiledOperands.emplace_back(getSlice(builder, loc, sum.value(), sumSlice));
+    auto [sumOffsets, sumSizes, sumStrides] =
+        tileValue(sum.value(), *getSumMap());
+    tiledOperands.emplace_back(
+        getSlice(builder, loc, sum.value(), sumOffsets, sumSizes, sumStrides));
   }
 
   SmallVector<Type> resultTypes;
@@ -1783,93 +1771,4 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// OnlineAttentionOp
-//===----------------------------------------------------------------------===//
-
-SmallVector<Range> OnlineAttentionOp::getIterationDomain(OpBuilder &b) {
-  // Attention shape can be determined from Q, K, V alone.
-  SmallVector<Value> shapedValues = {getQuery(), getKey(), getValue()};
-  SmallVector<AffineMap> indexingMaps = {getQueryMap(), getKeyMap(),
-                                         getValueMap()};
-  return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
-                                     shapedValues, indexingMaps);
-}
-
-SmallVector<utils::IteratorType> OnlineAttentionOp::getLoopIteratorTypes() {
-  return getAttentionIteratorTypes(getIterationDomainRank(),
-                                   getIndexingMapsArray());
-}
-
-FailureOr<TilingResult>
-OnlineAttentionOp::getTiledImplementation(OpBuilder &builder,
-                                          ArrayRef<OpFoldResult> offsets,
-                                          ArrayRef<OpFoldResult> sizes) {
-  assert(offsets.size() == getIterationDomainRank());
-  assert(sizes.size() == getIterationDomainRank());
-
-  Location loc = getLoc();
-
-  SmallVector<Range> querySlice =
-      getPermutedSlice(getQueryMap(), offsets, sizes);
-  SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
-  SmallVector<Range> valueSlice =
-      getPermutedSlice(getValueMap(), offsets, sizes);
-  SmallVector<Range> outputSlice =
-      getPermutedSlice(getOutputMap(), offsets, sizes);
-  SmallVector<Range> maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes);
-  SmallVector<Range> sumSlice = getPermutedSlice(getSumMap(), offsets, sizes);
-
-  Value scale = getScale();
-
-  SmallVector<Value> tiledOperands;
-  tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice));
-  tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice));
-  tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice));
-  tiledOperands.emplace_back(scale);
-  tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice));
-  tiledOperands.emplace_back(getSlice(builder, loc, getMax(), maxSlice));
-  tiledOperands.emplace_back(getSlice(builder, loc, getSum(), sumSlice));
-
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(tiledOperands[4].getType());
-  resultTypes.push_back(tiledOperands[5].getType());
-  resultTypes.push_back(tiledOperands[6].getType());
-
-  Operation *tiledOp =
-      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
-}
-
-LogicalResult OnlineAttentionOp::getResultTilePosition(
-    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
-    SmallVector<OpFoldResult> &resultSizes) {
-  resultOffsets.clear();
-  resultSizes.clear();
-
-  AffineMap resultIndexingMap;
-  switch (resultNumber) {
-  case 0:
-    resultIndexingMap = getOutputMap();
-    break;
-  case 1:
-    resultIndexingMap = getMaxMap();
-    break;
-  case 2:
-    resultIndexingMap = getSumMap();
-    break;
-  default:
-    return failure();
-  }
-
-  for (AffineExpr dimExpr : resultIndexingMap.getResults()) {
-    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    resultOffsets.push_back(offsets[dim]);
-    resultSizes.push_back(sizes[dim]);
-  }
-  return success();
-}
-
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index fe21d60..1917631 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -19,7 +19,6 @@
             "conv2d_to_winograd.mlir",
             "convert_to_loops.mlir",
             "decompose_attention.mlir",
-            "decompose_online_attention.mlir",
             "decompose_winograd.mlir",
             "distribution.mlir",
             "pad_contraction_to_block_size.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 7ef5fd7..92abdf7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -17,7 +17,6 @@
     "conv2d_to_winograd.mlir"
     "convert_to_loops.mlir"
     "decompose_attention.mlir"
-    "decompose_online_attention.mlir"
     "decompose_winograd.mlir"
     "distribution.mlir"
     "pad_contraction_to_block_size.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
deleted file mode 100644
index 945cff8..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ /dev/null
@@ -1,64 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-attention),canonicalize,cse)" %s | FileCheck %s
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @attention_f16(%query: tensor<192x1024x64xf16>,
-                         %key: tensor<192x1024x64xf16>,
-                         %value: tensor<192x1024x64xf16>,
-                         %output: tensor<192x1024x64xf32>,
-                         %max: tensor<192x1024xf32>,
-                         %sum: tensor<192x1024xf32>)
-                         -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
-  %scale = arith.constant 1.0 : f16
-
-  %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
-        ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
-        outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
-        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
-  return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
-}
-
-// We just want to check if we are using the correct algorithm.
-// CHECK-LABEL: @attention_f16
-// S = Q @ K
-// CHECK: linalg.generic
-// CHECK:   arith.mulf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-// newMax = max(oldMax, rowMax(S))
-// CHECK: linalg.generic
-// CHECK:   arith.maximumf
-// CHECK:   linalg.yield
-// P = exp2(S - newMax)
-// CHECK: linalg.generic
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// norm = exp2(oldMax - newMax)
-// CHECK: linalg.generic
-// CHECK:   arith.subf
-// CHECK:   math.exp2
-// CHECK:   linalg.yield
-// normSum = norm * oldSum
-// CHECK: linalg.generic
-// CHECK:   arith.mulf
-// CHECK:   linalg.yield
-// newSum = normSum + rowMax(P)
-// CHECK: linalg.generic
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
-// newAcc = norm * oldAcc
-// CHECK: linalg.generic
-// CHECK:   arith.mulf
-// CHECK:   linalg.yield
-// newAcc = P @ V + newAcc
-// CHECK: linalg.generic
-// CHECK:   arith.mulf
-// CHECK:   arith.addf
-// CHECK:   linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index a974f3a..f92bdb8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -1536,66 +1536,3 @@
 // CHECK:        }
 // CHECK:        return
 // CHECK:      }
-
-// -----
-
-#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
-#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
-#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
-#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
-#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
-
-func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
-  %scale = arith.constant 1.0 : f32
-
-  %output_empty = tensor.empty() : tensor<192x1024x64xf32>
-  %row_red_empty = tensor.empty() : tensor<192x1024xf32>
-
-  %sum_ident = arith.constant 0.000000e+00 : f32
-  %max_ident = arith.constant -3.40282347E+38 : f32
-
-  %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-  %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
-  %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
-
-  %out:3 = iree_linalg_ext.online_attention
-        { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
-        ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32)
-        outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
-        -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
-
-  return %out#0 : tensor<192x1024x64xf32>
-}
-
-// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)>
-// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)>
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
-// CHECK-LABEL: @online_attention
-// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
-// CHECK-DAG:   %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
-// CHECK-DAG:   %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]])
-// CHECK-DAG:   %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]])
-// CHECK-DAG:  %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32>
-// CHECK-DAG:  %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32>
-// CHECK-DAG:  %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32>
-// CHECK-DAG:  %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32>
-// CHECK-DAG:  %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
-// CHECK-DAG:  %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
-// CHECK-DAG: iree_linalg_ext.online_attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP4]]]}
-// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}} : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32)
-// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
-// CHECK: scf.forall.in_parallel
-
-module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
-    %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
index ac5f5c0..7feee69 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
@@ -27,7 +27,7 @@
 
 void AttentionOpDetail::inferFromIndexingMaps(
     ArrayRef<AffineMap> indexingMaps) {
-  assert(indexingMaps.size() >= 4);
+  assert(indexingMaps.size() == 4);
   AffineMap qMap = indexingMaps[0];
   AffineMap kMap = indexingMaps[1];
   AffineMap vMap = indexingMaps[2];
@@ -82,23 +82,7 @@
 
   AttentionOpDetail opInfo;
   opInfo.inferFromIndexingMaps(indexingMaps);
-  opInfo.maps = SmallVector<AffineMap>(indexingMaps);
   return opInfo;
 }
 
-AffineMap AttentionOpDetail::getSMap() const {
-  // We need to create an indexing map for the intermediate result of first
-  // matmul. There could be other options, but we choose to create a standard
-  // indexing map:
-  //   SMap = (batch, m, k1, k2, n) -> (batch, m, k2)
-  AffineMap sMap = AffineMap::get(/*dimCount=*/getDomainRank(),
-                                  /*symbolCount=*/0, getContext());
-  for (auto dim :
-       llvm::concat<const int64_t>(getBatchDims(), getMDims(), getK2Dims())) {
-    AffineExpr dimExpr = getAffineDimExpr(dim, getContext());
-    sMap = sMap.insertResult(dimExpr, sMap.getNumResults());
-  }
-  return sMap;
-}
-
 }; // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
index e66bc86..cfba4ff 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h
@@ -42,29 +42,20 @@
 public:
   static FailureOr<AttentionOpDetail> get(ArrayRef<AffineMap> indexingMaps);
 
-  int64_t getDomainRank() const { return maps[0].getNumDims(); }
   ArrayRef<int64_t> getBatchDims() const { return batch; }
   ArrayRef<int64_t> getMDims() const { return m; }
   ArrayRef<int64_t> getK1Dims() const { return k1; }
   ArrayRef<int64_t> getK2Dims() const { return k2; }
   ArrayRef<int64_t> getNDims() const { return n; }
 
-  ArrayRef<AffineMap> getIndexingMaps() const { return maps; }
-
-  AffineMap getSMap() const;
-
 private:
   void inferFromIndexingMaps(ArrayRef<AffineMap> indexingMaps);
 
-  MLIRContext *getContext() const { return maps[0].getContext(); }
-
   SmallVector<int64_t> batch;
   SmallVector<int64_t> m;
   SmallVector<int64_t> k1;
   SmallVector<int64_t> k2;
   SmallVector<int64_t> n;
-
-  SmallVector<AffineMap> maps;
 };
 
 }; // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index e6c3548..d231322 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -44,13 +44,6 @@
       [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); });
 }
 
-Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef<Range> slice) {
-  return getSlice(b, loc, src,
-                  llvm::map_to_vector(slice, [](Range x) { return x.offset; }),
-                  llvm::map_to_vector(slice, [](Range x) { return x.size; }),
-                  llvm::map_to_vector(slice, [](Range x) { return x.stride; }));
-}
-
 Value getSlice(OpBuilder &b, Location loc, Value src,
                ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                ArrayRef<OpFoldResult> strides) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index eec973f..9b40a54 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -13,10 +13,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 
-namespace mlir {
-struct Range;
-}; // namespace mlir
-
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
 /// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
@@ -30,7 +26,6 @@
 
 /// Returns a `memref.subview` or a `tensor.extract_slice` based on the type of
 /// `src`.
-Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef<Range> slice);
 Value getSlice(OpBuilder &b, Location loc, Value src,
                ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                ArrayRef<OpFoldResult> strides);