[Codegen] Remove old attention transformations (#18740)

These transformations on iree_linalg_ext.attention have been replaced by
transformations on iree_linalg_ext.online_attention. This path of
attention transformations has been deprecated for a long time and it's
time to delete it.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 925106a..522404d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -21,8 +21,6 @@
             "amdgpu_chained_matmul.mlir",
             "amdgpu_contraction_distribution.mlir",
             "amdgpu_set_anchor_layouts.mlir",
-            "attention.mlir",
-            "attention_mfma.mlir",
             "conv_pipeline_test_cuda.mlir",
             "conv_pipeline_test_rocm.mlir",
             "convert_to_nvvm.mlir",
@@ -82,8 +80,6 @@
         # tensor_dialect_*_spec is a an MLIR file that specifies a
         # transformation, it needs to be included as data.
         exclude = [
-            "attention_mfma_transform_spec.mlir",
-            "attention_transform_spec.mlir",
             "transform_dialect_codegen_bufferize_spec.mlir",
             "transform_dialect_codegen_foreach_to_gpu_spec.mlir",
             "transform_dialect_codegen_vector_distribution_spec.mlir",
@@ -92,8 +88,6 @@
     ),
     cfg = "//compiler:lit.cfg.py",
     data = [
-        "attention_mfma_transform_spec.mlir",
-        "attention_transform_spec.mlir",
         "transform_dialect_codegen_bufferize_spec.mlir",
         "transform_dialect_codegen_foreach_to_gpu_spec.mlir",
         "transform_dialect_codegen_vector_distribution_spec.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 0c27964..b707dfb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -17,8 +17,6 @@
     "amdgpu_chained_matmul.mlir"
     "amdgpu_contraction_distribution.mlir"
     "amdgpu_set_anchor_layouts.mlir"
-    "attention.mlir"
-    "attention_mfma.mlir"
     "cast_address_space_function.mlir"
     "cast_type_to_fit_mma.mlir"
     "config_matvec.mlir"
@@ -77,8 +75,6 @@
     FileCheck
     iree-opt
   DATA
-    attention_mfma_transform_spec.mlir
-    attention_transform_spec.mlir
     transform_dialect_codegen_bufferize_spec.mlir
     transform_dialect_codegen_foreach_to_gpu_spec.mlir
     transform_dialect_codegen_vector_distribution_spec.mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
deleted file mode 100644
index 3f2c090..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ /dev/null
@@ -1,175 +0,0 @@
-// RUN: iree-opt %s --pass-pipeline='builtin.module(iree-transform-dialect-interpreter{library-file-name=%p/attention_transform_spec.mlir})' \
-// RUN:   --iree-gpu-test-target=sm_60 | \
-// RUN: FileCheck --check-prefix=CHECK %s
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-func.func @_attention_dispatch_0() {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 1.250000e-01 : f16
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>>
-  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
-  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
-  %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
-  %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
-  %7 = tensor.empty() : tensor<192x1024x64xf16>
-  %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                    affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                    ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
-  flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
-  return
-}
-
-// CHECK-DAG:  #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 128)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s2 * 32 + ((s0 + s1 * 4) floordiv 32) * 32)>
-// CHECK-DAG:  #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG:  #[[MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG:  #[[MAP5:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG:  #[[MAP6:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-DAG:  #[[MAP7:.+]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG:  #[[TRANSLATION:.+]] = #iree_codegen.translation_info<None workgroup_size = [4, 8, 4] subgroup_size = 32>
-// CHECK:      func.func @_attention_dispatch_0()
-// CHECK-SAME:     translation_info = #[[TRANSLATION]]
-// CHECK-DAG:    %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
-// CHECK-DAG:    %[[CST_0:.+]] = arith.constant dense<-1.000000e+30> : vector<32xf32>
-// CHECK-DAG:    %[[CST_1:.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
-// CHECK-DAG:    %[[CST_2:.+]] = arith.constant dense<0.000000e+00> : vector<32x128xf32>
-// CHECK-DAG:    %[[CST_3:.+]] = arith.constant dense<1.000000e+00> : vector<64x32xf32>
-// CHECK-DAG:    %[[CST_4:.+]] = arith.constant 0.000000e+00 : f16
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C128:.+]] = arith.constant 128 : index
-// CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK-DAG:    %[[CST_5:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-dAG:    %[[CST_6:.+]] = arith.constant dense<1.802980e-01> : vector<128x64xf16>
-// CHECK:        %[[D0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64)
-// CHECK-SAME:     offset(%[[C0]]) flags(ReadOnly) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        memref.assume_alignment %[[D0]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        %[[D1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64)
-// CHECK-SAME:     offset(%[[C0]]) flags(ReadOnly) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        memref.assume_alignment %[[D1]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        %[[D2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64)
-// CHECK-SAME:     offset(%[[C0]]) flags(ReadOnly) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        memref.assume_alignment %[[D2]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        %[[D3:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(3) alignment(64)
-// CHECK-SAME:     offset(%[[C0]]) : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        memref.assume_alignment %[[D3]], 64 : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK:        %[[WORKGROUP_ID_X:.+]] = hal.interface.workgroup.id[0] : index
-// CHECK:        %[[WORKGROUP_ID_Y:.+]] = hal.interface.workgroup.id[1] : index
-// CHECK-DAG:    %[[D4:.+]] = affine.apply #[[MAP]]()[%[[WORKGROUP_ID_Y]]]
-// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[D0]][%[[WORKGROUP_ID_X]], %[[D4]], 0] [1, 128, 64] [1, 1, 1] :
-// CHECK-SAME:     memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK:        %[[SUBVIEW_6:.+]] = memref.subview %[[D3]][%[[WORKGROUP_ID_X]], %[[D4]], 0] [1, 128, 64] [1, 1, 1] :
-// CHECK-SAME:     memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK:        %[[ALLOC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<1x128x64xf16,
-// CHECK-SAME:     #[[GPU:.+]].address_space<workgroup>>
-// CHECK:        gpu.barrier
-// CHECK:        linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[SUBVIEW]] : memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-// CHECK-SAME:     outs(%[[ALLOC]] : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK:          linalg.yield %[[IN]] : f16
-// CHECK:        }
-// CHECK:        gpu.barrier
-// CHECK:        %[[ALLOC_7:.+]] = memref.alloc() {alignment = 64 : i64} : memref<1x128x64xf16,
-// CHECK-SAME:     #[[GPU]].address_space<workgroup>>
-// CHECK:        gpu.barrier
-// CHECK:        linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[SUBVIEW_6]] : memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
-// CHECK-SAME:     outs(%[[ALLOC_7]] : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK:          linalg.yield %[[IN]] : f16
-// CHECK:        }
-// CHECK:        gpu.barrier
-// CHECK-DAG:    %[[D5:.+]] = gpu.thread_id  x
-// CHECK-DAG:    %[[D6:.+]] = gpu.thread_id  y
-// CHECK-DAG:    %[[D7:.+]] = gpu.thread_id  z
-// CHECK-DAG:    %[[D8:.+]] = affine.apply #[[MAP2]]()[%[[D5]], %[[D6]], %[[D7]]]
-// CHECK:        %[[D9:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D8]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
-// CHECK-SAME:     true]} : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>, vector<32x64xf16>
-// CHECK:        %[[D11:.+]]:3 = scf.for %[[ARG0:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C128]]
-// CHECK-SAME:     iter_args(%[[ARG1:[a-zA-Z0-9_]+]] = %[[CST_0]], %[[ARG2:[a-zA-Z0-9_]+]] = %[[CST_1]],
-// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<32xf32>, vector<32xf32>, vector<32x64xf32>) {
-// CHECK:          %[[SUBVIEW_8:.+]] = memref.subview %[[D1]][%[[WORKGROUP_ID_X]], %[[ARG0]], 0] [1, 128, 64] [1, 1, 1]
-// CHECK-SAME:       : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK:          %[[SUBVIEW_9:.+]] = memref.subview %[[D2]][%[[WORKGROUP_ID_X]], %[[ARG0]], 0] [1, 128, 64] [1, 1, 1]
-// CHECK-SAME:       : memref<192x1024x64xf16, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
-// CHECK:          %[[ALLOC_12:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16, #gpu.address_space<workgroup>>
-// CHECK:          vector.transfer_write %[[CST_6:.+]], %[[ALLOC_12]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<128x64xf16>, memref<128x64xf16, #gpu.address_space<workgroup>>
-// CHECK:          %[[ALLOC_10:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16,
-// CHECK-SAME:       #[[GPU]].address_space<workgroup>>
-// CHECK:          gpu.barrier
-// CHECK:          linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME:       ins(%[[SUBVIEW_8]] : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%[[ALLOC_10]] :
-// CHECK-SAME:       memref<128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK:            linalg.yield %[[IN]] : f16
-// CHECK:          }
-// CHECK:          gpu.barrier
-// CHECK:          %[[ALLOC_11:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf16,
-// CHECK-SAME:       #[[GPU]].address_space<workgroup>>
-// CHECK:          gpu.barrier
-// CHECK:          linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP3]]], iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME:       ins(%[[SUBVIEW_9]] : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%[[ALLOC_11]] :
-// CHECK-SAME:       memref<128x64xf16, #[[GPU]].address_space<workgroup>>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK:            linalg.yield %[[IN]] : f16
-// CHECK:          }
-// CHECK:          gpu.barrier
-// CHECK:          %[[READ:.+]] = vector.transfer_read %[[ALLOC_12]][%[[D8]], %[[C0]]], %{{.+}} : memref<128x64xf16, #gpu.address_space<workgroup>>, vector<32x64xf16>
-// CHECK:          %[[MUL:.+]] = arith.mulf %[[D9]], %[[READ]] : vector<32x64xf16>
-// CHECK:          %[[D13:.+]] = vector.transfer_read %[[ALLOC_10]][%[[C0]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
-// CHECK-SAME:       true]} : memref<128x64xf16, #[[GPU]].address_space<workgroup>>, vector<128x64xf16>
-// CHECK:          %[[D10:.+]] = arith.extf %[[MUL]] : vector<32x64xf16> to vector<32x64xf32>
-// CHECK:          %[[D14:.+]] = arith.extf %[[D13]] : vector<128x64xf16> to vector<128x64xf32>
-// CHECK:          %[[D15:.+]] = vector.contract {indexing_maps = [#[[MAP4]], #[[MAP5]], #[[MAP6]]], iterator_types =
-// CHECK-SAME:       ["parallel", "parallel", "reduction"], kind = #[[VECTOR:.+]].kind<add>} %[[D10]], %[[D14]],
-// CHECK-SAME:       %[[CST_2]] : vector<32x64xf32>, vector<128x64xf32> into vector<32x128xf32>
-// CHECK:          %[[D16:.+]] = vector.multi_reduction <maximumf>, %[[D15]], %[[ARG1]] [1] : vector<32x128xf32> to
-// CHECK-SAME:       vector<32xf32>
-// CHECK:          %[[D17:.+]] = vector.broadcast %[[D16]] : vector<32xf32> to vector<128x32xf32>
-// CHECK:          %[[D18:.+]] = vector.transpose %[[D17]], [1, 0] : vector<128x32xf32> to vector<32x128xf32>
-// CHECK:          %[[D19:.+]] = arith.subf %[[D15]], %[[D18]] : vector<32x128xf32>
-// CHECK:          %[[D20:.+]] = math.exp2 %[[D19]] : vector<32x128xf32>
-// CHECK:          %[[D21:.+]] = arith.subf %[[ARG1]], %[[D16]] : vector<32xf32>
-// CHECK:          %[[D22:.+]] = math.exp2 %[[D21]] : vector<32xf32>
-// CHECK:          %[[D23:.+]] = arith.mulf %[[D22]], %[[ARG2]] : vector<32xf32>
-// CHECK:          %[[D24:.+]] = vector.multi_reduction <add>, %[[D20]], %[[D23]] [1] : vector<32x128xf32> to
-// CHECK-SAME:       vector<32xf32>
-// CHECK:          %[[D29:.+]] = arith.truncf %[[D20]] : vector<32x128xf32> to vector<32x128xf16>
-// CHECK:          %[[D31:.+]] = vector.broadcast %[[D22]] : vector<32xf32> to vector<64x32xf32>
-// CHECK:          %[[D33:.+]] = vector.transpose %[[D31]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
-// CHECK:          %[[D34:.+]] = arith.mulf %[[D33]], %[[ARG3]] : vector<32x64xf32>
-// CHECK:          %[[D36:.+]] = arith.extf %[[D29]] : vector<32x128xf16> to vector<32x128xf32>
-// CHECK:          %[[D35:.+]] = vector.transfer_read %[[ALLOC_11]][%[[C0]], %[[C0]]], %[[CST_4]]
-// CHECK-SAME:       {in_bounds = [true, true], permutation_map = #[[MAP7]]} : memref<128x64xf16, #[[GPU]].address_space<workgroup>>, vector<64x128xf16>
-// CHECK:          %[[D37:.+]] = arith.extf %[[D35]] : vector<64x128xf16> to vector<64x128xf32>
-// CHECK:          %[[D39:.+]] = vector.contract {indexing_maps = [#[[MAP4]], #[[MAP5]], #[[MAP6]]], iterator_types =
-// CHECK-SAME:       ["parallel", "parallel", "reduction"], kind = #[[VECTOR]].kind<add>} %[[D36]], %[[D37]], %[[D34]] :
-// CHECK-SAME:       vector<32x128xf32>, vector<64x128xf32> into vector<32x64xf32>
-// CHECK:          scf.yield %[[D16]], %[[D24]], %[[D39]] : vector<32xf32>, vector<32xf32>, vector<32x64xf32>
-// CHECK:        }
-// CHECK:        %[[DSCALE1:.+]] = vector.broadcast %[[D11]]#1 : vector<32xf32> to vector<64x32xf32>
-// CHECK:        %[[DSCALE2:.+]] = arith.divf %[[CST_3]], %[[DSCALE1]] : vector<64x32xf32>
-// CHECK:        %[[DSCALE3:.+]] = vector.transpose %[[DSCALE2]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
-// CHECK:        %[[DSCALE4:.+]] = arith.mulf %[[DSCALE3]], %[[D11]]#2 : vector<32x64xf32>
-// CHECK:        %[[D12:.+]] = arith.truncf %[[DSCALE4]] : vector<32x64xf32> to vector<32x64xf16>
-// CHECK:        vector.transfer_write %[[D12]], %[[ALLOC_7]][%[[C0]], %[[D8]], %[[C0]]] {in_bounds = [true, true]} :
-// CHECK-SAME:     vector<32x64xf16>, memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>
-// CHECK:        gpu.barrier
-// CHECK:        linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[ALLOC_7]] : memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>)
-// CHECK-SAME:     outs(%[[SUBVIEW_6]] : memref<1x128x64xf16, strided<[65536, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f16, %[[OUT:.+]]: f16):
-// CHECK:          linalg.yield %[[IN]] : f16
-// CHECK:        }
-// CHECK:        gpu.barrier
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
deleted file mode 100644
index 109b107..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
+++ /dev/null
@@ -1,37 +0,0 @@
-// RUN: iree-opt %s --pass-pipeline='builtin.module(iree-transform-dialect-interpreter{library-file-name=%p/attention_mfma_transform_spec.mlir})' \
-// RUN:   --iree-gpu-test-target=gfx908 | \
-// RUN: FileCheck --check-prefix=CHECK %s
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-func.func @attention_dispatch_0_attention_16x16384x128xf16() {
-  %c0 = arith.constant 0 : index
-  %scale = arith.constant 0.08838834764 : f16
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>>
-  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
-  %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
-  %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
-  %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
-  %7 = tensor.empty() : tensor<16x16384x128xf16>
-  %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                    affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                    ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
-  flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
-  return
-}
-    // CHECK-NOT: vector.contract
-    // CHECK-NOT: iree_vector_ext.to_simd
-    // CHECK-NOT: iree_vector_ext.to_simt
-    // CHECK-COUNT-8: vector.load {{.*}} : memref<16x16384x128xf16, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
-    // CHECK: scf.for {{.*}} = %c0 to %c16384 step %c64 {{.*}} -> (vector<2xf32>, vector<2xf32>, vector<8x2x4xf32>)
-    // CHECK-COUNT-16: vector.load {{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, vector<8xf16>
-    // CHECK-COUNT-128: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir
deleted file mode 100644
index 9261f2e..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir
+++ /dev/null
@@ -1,203 +0,0 @@
-#layout = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-
-module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%variant_op: !transform.any_op) {
-    // Get attention op
-    // ==========================================
-    %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-
-    // Tile and distribute to workgroups
-    // ==========================================
-    %tiled_attention, %forall_grid =
-    transform.structured.tile_using_forall %attention tile_sizes [1, 128]
-      ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    // transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
-
-    // Tile batch dimensions of attention
-    // ==========================================
-    %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %top_level_func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %top_level_func : !transform.any_op
-
-    // Promote query and output operands
-    // ==========================================
-    //%attention3 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    //%promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 3]
-    //  : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
-    // Tile and decompose attention
-    // ==========================================
-    %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 {tile_size = 64} :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
-        = transform.iree.decompose_tiled_attention %blocked_attention {tile_size = 64} :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
-                              !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
-    // Promote key and value operands
-    // ==========================================
-    %promoted_first_matmul, %alloc0 = transform.iree.promote_operands %first_matmul [1]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %promoted_second_matmul, %alloc1 = transform.iree.promote_operands %second_matmul [1]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    // Tile and fuse attention ops
-    // ==========================================
-    %tiled_matmul, %forall = transform.structured.tile_using_forall %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-
-    %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.apply_cse to %func : !transform.any_op
-
-    %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.apply_cse to %func : !transform.any_op
-
-    %f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f8, %loop8 = transform.structured.fuse_into_containing_op %promoted_first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-    // Distribute fills
-    // ==========================================
-
-    // Get all fills that haven't been distributed to warps.
-    %fills = transform.include @get_undistributed_fills failures(propagate) (%variant_op)  : (!transform.any_op) -> !transform.any_op
-
-    %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    // Distribute last_truncate and fuse final_scaling into it
-    // ==========================================
-    %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    // Vectorize function
-    // ==========================================
-    transform.apply_patterns to %func {
-      transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
-      transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
-      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
-    } : !transform.any_op
-    %func_3 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> (!transform.any_op)
-
-    // Bufferization
-    // ==========================================
-    transform.apply_patterns to %func_3 {
-      transform.apply_patterns.tensor.reassociative_reshape_folding
-      transform.apply_patterns.canonicalization
-      transform.apply_patterns.iree.fold_fill_into_pad
-      transform.apply_patterns.linalg.tiling_canonicalization
-      transform.apply_patterns.scf.for_loop_canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func_3 : !transform.any_op
-    transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-    transform.apply_patterns to %func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op
-    %func_4 = transform.iree.bufferize { target_gpu } %func_3 : (!transform.any_op) -> (!transform.any_op)
-
-    // Step 5. Pre-process the contract and transfer ops to put it in the right form.
-    // ===========================================================================
-    %func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func_2 {
-      transform.apply_patterns.vector.fold_arith_extension
-    } : !transform.any_op
-
-    // Step 6. Post-bufferization vector distribution
-    // ===========================================================================
-    %func_7 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
-    transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> ()
-
-    transform.apply_patterns to %func_7 {
-      transform.apply_patterns.memref.fold_memref_alias_ops
-    } : !transform.any_op
-    transform.iree.apply_licm %func_7 : !transform.any_op
-    transform.apply_patterns to %func_7 {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func_7 : !transform.any_op
-    %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7
-    : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func_8 {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func_8 : !transform.any_op
-    transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> ()
-
-    // Apply chained matmul optimization.
-    transform.apply_registered_pass "iree-amdgpu-prepare-chained-matmul" to %func_8 : (!transform.any_op) -> (!transform.any_op)
-
-    // Get the vector.contract ops.
-    %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op :  (!transform.any_op) -> !transform.any_op
-    %contract1, %contract2 = transform.split_handle %contracts : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    %layout16x16x16 = transform.param.constant #layout -> !transform.any_param
-    transform.iree.set_contraction_layout_attributes %contract1, %layout16x16x16 { read_layout_indices = array<i64: 0, 1> } : !transform.any_op, !transform.any_param
-    transform.iree.set_contraction_layout_attributes %contract2, %layout16x16x16 : !transform.any_op, !transform.any_param
-
-    %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %distribute_func_2 = transform.iree.amdgpu_distribute_vectors %distribute_func test_conversion : (!transform.any_op) -> !transform.any_op
-
-    transform.apply_patterns to %distribute_func_2 {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %distribute_func_2 : !transform.any_op
-
-    // Distribute shared memory copies
-    // ==========================================
-    %func_10 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.gpu_distribute_shared_memory_copy %func_10 : (!transform.any_op) -> ()
-    transform.apply_patterns to %func_10 {
-        transform.apply_patterns.memref.fold_memref_alias_ops
-        transform.apply_patterns.canonicalization
-        transform.apply_patterns.linalg.tiling_canonicalization
-      } : !transform.any_op
-    transform.apply_cse to %func_10 : !transform.any_op
-
-    transform.yield
-  }
-  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
-    transform.match.operation_name %arg0 ["linalg.fill"] : !transform.any_op
-    %0 = transform.get_parent_op %arg0 {allow_empty_results, nth_parent = 2 : i64, op_name = "scf.forall"} : (!transform.any_op) -> !transform.any_op
-    transform.match.operation_empty %0 : !transform.any_op
-    transform.yield %arg0 : !transform.any_op
-  }
-  transform.named_sequence @get_undistributed_fills(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
-    %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.yield %0 : !transform.any_op
-  }
-} ////  module
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
deleted file mode 100644
index 2fc82bc..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
+++ /dev/null
@@ -1,158 +0,0 @@
-module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%variant_op: !transform.any_op) {
-    // Get attention op
-    // ==========================================
-    %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-
-    // Tile and distribute to workgroups
-    // ==========================================
-    %tiled_attention, %forall_grid =
-    transform.structured.tile_using_forall %attention tile_sizes [1, 128]
-      ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    // transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
-
-    // Tile batch dimensions of attention
-    // ==========================================
-    %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %top_level_func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %top_level_func : !transform.any_op
-
-    // Promote query and output operands
-    // ==========================================
-    %attention3 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 4]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
-    // Tile and decompose attention
-    // ==========================================
-    %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
-        = transform.iree.decompose_tiled_attention %blocked_attention :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
-                              !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
-    // Promote key and value operands
-    // ==========================================
-    %promoted_first_matmul, %alloc0 = transform.iree.promote_operands %first_matmul [1]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %promoted_second_matmul, %alloc1 = transform.iree.promote_operands %second_matmul [1]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    // Tile and fuse attention ops
-    // ==========================================
-    %tiled_matmul, %forall = transform.structured.tile_using_forall %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-
-    %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.apply_cse to %func : !transform.any_op
-
-    %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.apply_cse to %func : !transform.any_op
-
-    %f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f8, %loop8 = transform.structured.fuse_into_containing_op %promoted_first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    // Distribute fills
-    // ==========================================
-    %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill : !transform.any_op
-    %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    // Distribute last_truncate and fuse final_scaling into it
-    // ==========================================
-    %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func : !transform.any_op
-
-    // Vectorize function
-    // ==========================================
-    transform.apply_patterns to %func {
-      transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
-      transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
-      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
-    } : !transform.any_op
-    %func_3 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> (!transform.any_op)
-
-    // Bufferization
-    // ==========================================
-    transform.apply_patterns to %func_3 {
-      transform.apply_patterns.tensor.reassociative_reshape_folding
-      transform.apply_patterns.canonicalization
-      transform.apply_patterns.iree.fold_fill_into_pad
-      transform.apply_patterns.linalg.tiling_canonicalization
-      transform.apply_patterns.scf.for_loop_canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func_3 : !transform.any_op
-    transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-    transform.apply_patterns to %func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op
-    %func_4 = transform.iree.bufferize { target_gpu } %func_3 : (!transform.any_op) -> (!transform.any_op)
-
-    // Step 5. Pre-process the contract and transfer ops to put it in the right form.
-    // ===========================================================================
-    %func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func_2 {
-      transform.apply_patterns.iree.prepare_vector_to_mma
-    } : !transform.any_op
-
-    // Step 6. Post-bufferization vector distribution
-    // ===========================================================================
-    %func_7 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
-    transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [4, 8, 4] subgroup_size = 32 sync_after_distribution = false : (!transform.any_op) -> ()
-
-    transform.apply_patterns to %func_7 {
-      transform.apply_patterns.memref.fold_memref_alias_ops
-    } : !transform.any_op
-    transform.iree.apply_licm %func_7 : !transform.any_op
-    transform.apply_patterns to %func_7 {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func_7 : !transform.any_op
-    %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7
-    : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func_8 {
-      transform.apply_patterns.canonicalization
-    } : !transform.any_op
-    transform.apply_cse to %func_8 : !transform.any_op
-    transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> ()
-    transform.yield
-  }
-} ////  module
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
index ef169fd..78e3ea5 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
@@ -43,7 +43,8 @@
   funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
   funcPassManager.addPass(createConcretizePadResultShapePass());
-  funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
+  funcPassManager.addPass(
+      IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
   funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
   funcPassManager.addPass(
       IREE::LinalgExt::createDecomposeWinogradTransformPass());
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
index b350053..46c24e4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
@@ -27,30 +27,6 @@
 // TileAndDecomposeAttention
 //===---------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure LinalgExt::TileAttentionOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  SmallVector<Operation *> ops;
-  LinalgExt::tileAttention(attentionOp, ops, rewriter, getTileSize());
-  for (auto op : ops) {
-    results.push_back(op);
-  }
-  return DiagnosedSilenceableFailure::success();
-}
-
-DiagnosedSilenceableFailure LinalgExt::DecomposeTiledAttentionOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  SmallVector<Operation *> ops;
-  LinalgExt::decomposeTiledAttention(attentionOp, ops, rewriter, getTileSize());
-  for (auto op : ops) {
-    results.push_back(op);
-  }
-  return DiagnosedSilenceableFailure::success();
-}
-
 DiagnosedSilenceableFailure LinalgExt::DecomposeAggregateOp::applyToOne(
     transform::TransformRewriter &rewriter,
     linalg::AggregatedOpInterface aggregateOp,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
index 9e761c0..4417c80 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td
@@ -12,81 +12,6 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
-def TileAttentionOp : Op<Transform_Dialect, "iree.tile_attention",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Target iree_linalg_ext.attention ops and tile them.
-    This transform consumes the target handle and produces a result handle.
-  }];
-
-  let arguments = (
-      ins TransformHandleTypeInterface:$target,
-          OptionalAttr<I64Attr>:$tile_size
-  );
-  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
-
-  let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
-  let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-
-  let builders = [
-    OpBuilder<(ins "Value":$target)>
-  ];
-
-  let assemblyFormat = [{
-    $target attr-dict `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
-def DecomposeTiledAttentionOp : Op<Transform_Dialect, "iree.decompose_tiled_attention",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Target iree_linalg_ext.attention ops and decompose them.
-    This transform consumes the target handle and produces a result handle.
-  }];
-
-  let arguments = (
-      ins TransformHandleTypeInterface:$target,
-          OptionalAttr<I64Attr>:$tile_size
-  );
-  let results = (outs Variadic<TransformHandleTypeInterface>:$result);
-
-  let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
-  let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-
-  let builders = [
-    OpBuilder<(ins "Value":$target)>
-  ];
-
-  let assemblyFormat = [{
-    $target attr-dict `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
-
 def DecomposeAggregateOp : Op<Transform_Dialect, "iree.decompose_aggregate_op",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
index b1ff512..9f08e2a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
@@ -7,9 +7,6 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 
@@ -19,335 +16,14 @@
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
 
 namespace {
-
-// Computes a reduction along the rows of a 2d tensor of shape MxN
-// to produce a tensor of shape M
-template <typename T>
-static Value computeRowwiseReduction(Value a, Value output, Location loc,
-                                     OpBuilder &builder,
-                                     SmallVectorImpl<Operation *> &ops) {
-  SmallVector<utils::IteratorType> iteratorTypes{
-      utils::IteratorType::parallel, utils::IteratorType::reduction};
-  AffineMap id = AffineMap::getMultiDimIdentityMap(2, builder.getContext());
-  AffineExpr d0, d1;
-  bindDims(builder.getContext(), d0, d1);
-  // (d0, d1) -> (d0)
-  auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{id, rowMap};
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, output.getType(), a, output, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<T>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static Value computePartialSoftmax(Value qkTranspose, Value currentMax,
-                                   Location loc, OpBuilder &builder,
-                                   SmallVectorImpl<Operation *> &ops) {
-  AffineMap identityMap =
-      AffineMap::getMultiDimIdentityMap(2, builder.getContext());
-  AffineExpr d0, d1;
-  bindDims(builder.getContext(), d0, d1);
-  // (d0, d1) -> (d0)
-  auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
-  SmallVector<utils::IteratorType> iteratorTypes(2,
-                                                 utils::IteratorType::parallel);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, qkTranspose.getType(), ValueRange{currentMax}, qkTranspose,
-      indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
-        Value result = b.create<math::Exp2Op>(loc, diff);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-/// Return the scale factor for the new softmax maximum and add the generic to
-/// the provided list of operations.
-static Value computeScaleFactor(Value oldMax, Value newMax, Location loc,
-                                OpBuilder &builder,
-                                SmallVectorImpl<Operation *> &ops) {
-  SmallVector<utils::IteratorType> iteratorTypes(1,
-                                                 utils::IteratorType::parallel);
-  auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
-  SmallVector<AffineMap> indexingMaps(2, identityMap);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, oldMax.getType(), newMax, oldMax, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
-        Value weight = b.create<math::Exp2Op>(loc, diff);
-        b.create<linalg::YieldOp>(loc, weight);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static Value updateAndScale(Value scaleFactor, Value oldSum, Location loc,
-                            OpBuilder &builder,
-                            SmallVectorImpl<Operation *> &ops) {
-  SmallVector<utils::IteratorType> iteratorTypes(1,
-                                                 utils::IteratorType::parallel);
-  auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
-  SmallVector<AffineMap> indexingMaps(2, identityMap);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, oldSum.getType(), scaleFactor, oldSum, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value scaledOldSum = b.create<arith::MulFOp>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, scaledOldSum);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static Value scalePartialSoftmax(Value softmax, Value inverseNewSum,
-                                 Location loc, OpBuilder &builder,
-                                 SmallVectorImpl<Operation *> &ops) {
-  AffineMap identityMap =
-      AffineMap::getMultiDimIdentityMap(2, builder.getContext());
-  AffineExpr d0, d1;
-  bindDims(builder.getContext(), d0, d1);
-  // (d0, d1) -> (d0)
-  auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
-  SmallVector<utils::IteratorType> iteratorTypes(2,
-                                                 utils::IteratorType::parallel);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, softmax.getType(), ValueRange{inverseNewSum}, softmax, indexingMaps,
-      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<arith::MulFOp>(loc, args[1], args[0]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static Value scaleAccumulator(Value accumulator, Value scaleFactor,
-                              Location loc, OpBuilder &builder,
-                              SmallVectorImpl<Operation *> &ops) {
-  AffineMap identityMap =
-      AffineMap::getMultiDimIdentityMap(2, builder.getContext());
-  AffineExpr d0, d1;
-  bindDims(builder.getContext(), d0, d1);
-  // (d0, d1) -> (d0)
-  auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
-  SmallVector<utils::IteratorType> iteratorTypes(2,
-                                                 utils::IteratorType::parallel);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, accumulator.getType(), scaleFactor, accumulator, indexingMaps,
-      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<arith::MulFOp>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static Value computeQKTranspose(Value query, Value key, Value output,
-                                Value zero, Location loc, OpBuilder &builder,
-                                SmallVectorImpl<Operation *> &ops) {
-  auto fillOp = builder.create<linalg::FillOp>(loc, ValueRange{zero}, output);
-  ops.push_back(fillOp);
-  Value acc = fillOp.result();
-  auto matmulOp = builder.create<linalg::MatmulTransposeBOp>(
-      loc, output.getType(), ValueRange{query, key}, acc);
-  ops.push_back(matmulOp);
-  return matmulOp.getResult(0);
-}
-
-static Value truncateToF16(Value input, Value output,
-                           SmallVectorImpl<Operation *> &ops,
-                           OpBuilder &builder, Location loc) {
-  AffineMap identityMap =
-      AffineMap::getMultiDimIdentityMap(2, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{identityMap, identityMap};
-  SmallVector<utils::IteratorType> iteratorTypes(2,
-                                                 utils::IteratorType::parallel);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, output.getType(), ValueRange{input}, output, indexingMaps,
-      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<arith::TruncFOp>(loc, b.getF16Type(), args[0]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static std::tuple<Value, Value, Value>
-createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
-                    Value outputSlice, Value maxSlice, Value sumSlice,
-                    OpFoldResult sequenceTileLength,
-                    OpFoldResult keyValueTileLength, OpFoldResult headDimension,
-                    Type elementType, SmallVectorImpl<Operation *> &ops,
-                    bool transposeV, Location loc, OpBuilder &builder) {
-
-  Type f32Type = builder.getF32Type();
-  // Compute matmul(q, transpose(k))
-  Value zero =
-      builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(f32Type));
-  SmallVector<OpFoldResult> resultShape{sequenceTileLength, keyValueTileLength};
-  Value emptySquare =
-      builder.create<tensor::EmptyOp>(loc, resultShape, f32Type);
-  Value qkTranspose = computeQKTranspose(querySlice, keySlice, emptySquare,
-                                         zero, loc, builder, ops);
-
-  // Compute current statistics
-  Value newMax = computeRowwiseReduction<arith::MaximumFOp>(
-      qkTranspose, maxSlice, loc, builder, ops);
-  Value partialSoftmax =
-      computePartialSoftmax(qkTranspose, newMax, loc, builder, ops);
-  Value scaleFactor = computeScaleFactor(maxSlice, newMax, loc, builder, ops);
-  Value scaledOldSum = updateAndScale(scaleFactor, sumSlice, loc, builder, ops);
-  Value newSum = computeRowwiseReduction<arith::AddFOp>(
-      partialSoftmax, scaledOldSum, loc, builder, ops);
-  if (elementType.isF16()) {
-    Value empty =
-        builder.create<tensor::EmptyOp>(loc, resultShape, builder.getF16Type());
-    partialSoftmax = truncateToF16(partialSoftmax, empty, ops, builder, loc);
-  }
-
-  // Update accumulator
-  Value scaledAcc =
-      scaleAccumulator(outputSlice, scaleFactor, loc, builder, ops);
-
-  // Compute matmul(softmax, v)
-  Operation *matmulOp;
-  if (transposeV) {
-    matmulOp = builder.create<linalg::MatmulTransposeBOp>(
-        loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
-        scaledAcc);
-  } else {
-    matmulOp = builder.create<linalg::MatmulOp>(
-        loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
-        scaledAcc);
-  }
-  ops.push_back(matmulOp);
-  Value result = matmulOp->getResult(0);
-  return std::make_tuple(result, newMax, newSum);
-}
-
-static Value scaleQuery(Value querySlice, Value scale, RewriterBase &rewriter) {
-  ShapedType queryType = cast<ShapedType>(querySlice.getType());
-  Location loc = querySlice.getLoc();
-
-  // Create a fill op for scale.
-  SmallVector<OpFoldResult> queryDims =
-      tensor::getMixedSizes(rewriter, loc, querySlice);
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, queryDims,
-                                                 queryType.getElementType());
-  auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{scale}, empty)
-                    .getResult(0);
-
-  // Create a generic op to multiply the query by the scale.
-  SmallVector<utils::IteratorType> iteratorTypes(2,
-                                                 utils::IteratorType::parallel);
-  auto identityMap =
-      AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
-  SmallVector<AffineMap> indexingMaps(2, identityMap);
-  auto scaleOp = rewriter.create<linalg::GenericOp>(
-      loc, TypeRange{fillOp.getType()}, ValueRange{querySlice},
-      ValueRange{fillOp}, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<arith::MulFOp>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return scaleOp.getResult(0);
-}
-
-} // namespace
-
-/// This is an implementation of flash attention which
-/// is a tiled and fused implementation of the attention operator.
-/// The attention operator computes:
-/// matmul(softmax(matmul(Q, transpose(K))), V)
-/// where: Q is the query matrix [B x N x d]
-///        K is the key matrix   [B x S x d]
-///        V is the value matrix [B x S x d]
-///
-/// The core algorithm is as follows:
-/// For each element in B,
-/// 1. Load a tile from the Q matrix of size T x d -> q
-/// 2. Initialize statistics: running_sum, running_max
-/// 3. for i = 0 to S with step T
-///    a. Load a tile from the K matrix of size T x d -> k
-///    b. Load a tile from the V matrix of size T x d -> v
-///    c. Compute matmul_transpose_b(q, k) -> qkT
-///    d. Compute max(max(qkT) along rows, old_max) -> new_max
-///    e. Compute curent estimate of softmax: exp(qKT - current_max) -> s
-///    f. Compute product of fixup and old_sum -> fsum
-///    g. Compute sum(sum(qkT) along rows, fsum) -> new_sum
-///    h. Compute 1.0 / new_sum -> inv_new_sum
-///    i. Compute softmax = softmax * inv_new_sum
-///    j. Truncate softmax to fp16
-///    k. Compute fsum  * inv_new_sum * accumulator -> new_accumulator
-///    j. Compute matmul(s, v) and add new_accumulator
-///
-/// Decompose tiled iree_linalg_ext.attention op.
-/// TODO: Adopt decomposeOperation with this.
-void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
-                             SmallVectorImpl<Operation *> &ops,
-                             RewriterBase &rewriter,
-                             std::optional<uint64_t> tileSize) {
-  Location loc = tiledAttnOp.getLoc();
-  Value keySlice = tiledAttnOp.getKey();
-  Value valueSlice = tiledAttnOp.getValue();
-  Value querySlice = tiledAttnOp.getQuery();
-  Value tiledResult = tiledAttnOp.getOutput();
-  Value max = *tiledAttnOp.getMax();
-  Value sum = *tiledAttnOp.getSum();
-
-  OpBuilder::InsertionGuard withinScfLoop(rewriter);
-  rewriter.setInsertionPointAfter(tiledAttnOp);
-  SmallVector<OpFoldResult> queryDimValues =
-      tensor::getMixedSizes(rewriter, loc, querySlice);
-  OpFoldResult headDimension = queryDimValues[1];
-  OpFoldResult sequenceTileLength = queryDimValues[0];
-  OpFoldResult keyValueTileLength =
-      tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
-
-  Type elementType = tiledAttnOp.getQueryType().getElementType();
-
-  // Since we use exp2 for attention instead of the original exp, we have to
-  // multiply the scale by log2(e). We use exp2 instead of exp as most GPUs
-  // have better support for exp2.
-  Value scale = tiledAttnOp.getScale();
-  Value log2e = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getFloatAttr(elementType, M_LOG2E));
-  scale = rewriter.create<arith::MulFOp>(loc, scale, log2e);
-
-  // In the original algorithm, the scaling is done after the softmax:
-  //        softmax(Q @ K.T * scale) @ V
-  //
-  // But, it is mathematically equivalent to do it on Q first and then multiply
-  // it by K.T. This just allows us to do the scaling once, instead of each
-  // iteration of the loop.
-  querySlice = scaleQuery(querySlice, scale, rewriter);
-  ops.push_back(querySlice.getDefiningOp());
-
-  auto [result, newMax, newSum] = createAttentionBody(
-      keySlice, valueSlice, querySlice, tiledResult, max, sum,
-      sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
-      tiledAttnOp.isTransposeV(), loc, rewriter);
-
-  rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
-}
-
-namespace {
 struct DecomposeAttentionPass final
     : impl::DecomposeAttentionPassBase<DecomposeAttentionPass> {
   using impl::DecomposeAttentionPassBase<
       DecomposeAttentionPass>::DecomposeAttentionPassBase;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<
-        affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
-        linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
+    registry.insert<IREE::LinalgExt::IREELinalgExtDialect,
+                    linalg::LinalgDialect, tensor::TensorDialect>();
   }
   void runOnOperation() override;
 };
@@ -356,14 +32,6 @@
 void DecomposeAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
-  std::optional<uint64_t> optionalTileSize{std::nullopt};
-  if (tileSize.hasValue()) {
-    optionalTileSize = tileSize.getValue();
-  }
-  getOperation().walk([&](AttentionOp attnOp) {
-    SmallVector<Operation *> ops;
-    decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize);
-  });
   getOperation().walk([&](OnlineAttentionOp onlineAtt) {
     rewriter.setInsertionPoint(onlineAtt);
     FailureOr<SmallVector<Value>> results =
@@ -375,4 +43,5 @@
     rewriter.replaceOp(onlineAtt, results.value());
   });
 }
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index af7716c..1e858df 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -34,16 +34,6 @@
     RewritePatternSet &patterns,
     std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);
 
-IREE::LinalgExt::AttentionOp
-tileAttention(IREE::LinalgExt::AttentionOp attnOp,
-              SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
-              std::optional<uint64_t> tileSize = std::nullopt);
-
-void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
-                             SmallVectorImpl<Operation *> &ops,
-                             RewriterBase &rewriter,
-                             std::optional<uint64_t> tileSize = std::nullopt);
-
 void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
                               SmallVectorImpl<Operation *> &ops,
                               RewriterBase &rewriter);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index 6454c5a..3995fba 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -83,24 +83,10 @@
   ];
 }
 
-def TileAttentionPass :
-    InterfacePass<"iree-linalg-ext-tile-attention", "mlir::FunctionOpInterface"> {
-  let summary =
-      "Tile the attention op along the reduction dimension";
-  let options = [
-    Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
-           "Tile size for sequential for loop in attention">,
-  ];
-}
-
 def DecomposeAttentionPass :
     InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> {
   let summary =
       "Decomposes attention op into a sequence of linalg ops";
-  let options = [
-    Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
-           "Tile size for sequential for loop in attention">,
-  ];
 }
 
 def ConvertAttentionToOnlineAttentionPass :
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index 8089cd0..472a073 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -16,196 +16,11 @@
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
-#define GEN_PASS_DEF_TILEATTENTIONPASS
 #define GEN_PASS_DEF_CONVERTATTENTIONTOONLINEATTENTIONPASS
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
 
 namespace {
 
-static Value truncateToF16(Value input, Value output,
-                           SmallVectorImpl<Operation *> &ops,
-                           OpBuilder &builder, Location loc) {
-  AffineMap identityMap =
-      AffineMap::getMultiDimIdentityMap(2, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{identityMap, identityMap};
-  SmallVector<utils::IteratorType> iteratorTypes(2,
-                                                 utils::IteratorType::parallel);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, output.getType(), ValueRange{input}, output, indexingMaps,
-      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<arith::TruncFOp>(loc, b.getF16Type(), args[0]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static Value applyFinalScaling(Value result, Value newSum, Location loc,
-                               OpBuilder &builder,
-                               SmallVectorImpl<Operation *> &ops) {
-  AffineMap identityMap =
-      AffineMap::getMultiDimIdentityMap(2, builder.getContext());
-  AffineExpr d0, d1;
-  bindDims(builder.getContext(), d0, d1);
-  // (d0, d1) -> (d0)
-  auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
-  SmallVector<AffineMap> indexingMaps = {rowMap, identityMap};
-  SmallVector<utils::IteratorType> iteratorTypes(2,
-                                                 utils::IteratorType::parallel);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, result.getType(), newSum, result, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value one = b.create<arith::ConstantOp>(
-            loc, b.getFloatAttr(args[0].getType(), 1.0));
-        Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
-        Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  ops.push_back(genericOp);
-  return genericOp.getResult(0);
-}
-
-static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
-                                    Value step, Value ub, ValueRange args,
-                                    Location loc, OpBuilder &builder) {
-  SmallVector<Value> lbs{lb};
-  SmallVector<Value> steps{step};
-  SmallVector<Value> ubs{ub};
-  scf::LoopNest loopNest = scf::buildLoopNest(
-      builder, loc, lbs, ubs, steps, args,
-      [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
-          ValueRange iterArgs) -> scf::ValueVector { return iterArgs; });
-  for (scf::ForOp loop : loopNest.loops) {
-    ivs.push_back(loop.getInductionVar());
-  }
-  return loopNest;
-}
-
-static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
-                          ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
-                          OpFoldResult headDimension, Type elementType,
-                          Location loc, OpBuilder &builder,
-                          bool swapLastTwoDims = false) {
-  auto one = builder.getIndexAttr(1);
-  auto zero = builder.getIndexAttr(0);
-  SmallVector<OpFoldResult> strides(keyShape.size(), one);
-  SmallVector<OpFoldResult> sizes(keyShape.size(), one);
-  SmallVector<OpFoldResult> offsets(keyShape.size(), zero);
-  sizes[1] = keyValueTileLength;
-  sizes[2] = headDimension;
-  if (!ivs.empty()) {
-    offsets[1] = ivs[0];
-  }
-  SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
-  if (swapLastTwoDims) {
-    std::swap(sizes[1], sizes[2]);
-    std::swap(offsets[1], offsets[2]);
-    std::swap(tensorShape[0], tensorShape[1]);
-  }
-  auto tensorType = RankedTensorType::get(tensorShape, elementType);
-  Value keySlice = builder.create<tensor::ExtractSliceOp>(
-      loc, tensorType, key, offsets, sizes, strides);
-  return keySlice;
-}
-
-static Value extractOrInsertOutputSlice(Value src, Value dst,
-                                        ArrayRef<int64_t> queryShape,
-                                        OpFoldResult sequenceTileLength,
-                                        OpFoldResult headDimension,
-                                        Location loc, OpBuilder &builder) {
-  auto one = builder.getIndexAttr(1);
-  auto zero = builder.getIndexAttr(0);
-  SmallVector<OpFoldResult> strides(3, one);
-  SmallVector<OpFoldResult> sizes = {one, sequenceTileLength, headDimension};
-  SmallVector<OpFoldResult> offsets(3, zero);
-  Value slice;
-  if (!dst) {
-    SmallVector<int64_t> accShape{queryShape[1], queryShape[2]};
-    Type elementType = cast<ShapedType>(src.getType()).getElementType();
-    auto tensorType = RankedTensorType::get(accShape, elementType);
-    slice = builder.create<tensor::ExtractSliceOp>(loc, tensorType, src,
-                                                   offsets, sizes, strides);
-  } else {
-    slice = builder.create<tensor::InsertSliceOp>(loc, src, dst, offsets, sizes,
-                                                  strides);
-  }
-  return slice;
-}
-
-static Value extractOutputSlice(Value src, ArrayRef<int64_t> queryShape,
-                                OpFoldResult sequenceTileLength,
-                                OpFoldResult headDimension, Location loc,
-                                OpBuilder &builder) {
-  return extractOrInsertOutputSlice(src, {}, queryShape, sequenceTileLength,
-                                    headDimension, loc, builder);
-}
-
-static Value insertOutputSlice(Value src, Value dst,
-                               OpFoldResult sequenceTileLength,
-                               OpFoldResult headDimension, Location loc,
-                               OpBuilder &builder) {
-  return extractOrInsertOutputSlice(src, dst, {}, sequenceTileLength,
-                                    headDimension, loc, builder);
-}
-
-static SmallVector<AffineMap>
-getTileAttentionIndexingMaps(RewriterBase &rewriter, int64_t tiledInputRank,
-                             bool transposeV) {
-  MLIRContext *ctx = rewriter.getContext();
-  AffineExpr m, k1, k2, n;
-  bindDims(ctx, m, k1, k2, n);
-
-  AffineMap qMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
-  AffineMap kMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
-  AffineMap vMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
-  AffineMap sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {}, ctx);
-  AffineMap rMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
-  AffineMap maxMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m}, ctx);
-  AffineMap sumMap =
-      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m}, ctx);
-
-  if (transposeV) {
-    SmallVector<AffineExpr> vDims(vMap.getResults());
-    std::swap(vDims[0], vDims[1]);
-    vMap = AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), vDims, ctx);
-  }
-
-  SmallVector<AffineMap> attentionMaps = {qMap, kMap,   vMap,  sMap,
-                                          rMap, maxMap, sumMap};
-  // Add batches to standard attention indexing maps.
-  int64_t numBatches = tiledInputRank - 2;
-  for (AffineMap &map : attentionMaps) {
-    map = map.shiftDims(numBatches);
-    for (int batch : llvm::seq<int>(numBatches)) {
-      map = map.insertResult(rewriter.getAffineDimExpr(batch), batch);
-    }
-  }
-
-  return attentionMaps;
-}
-
-struct TileAttentionPass final
-    : impl::TileAttentionPassBase<TileAttentionPass> {
-  using impl::TileAttentionPassBase<TileAttentionPass>::TileAttentionPassBase;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<
-        affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
-        linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
-  }
-  TileAttentionPass() = default;
-  TileAttentionPass(bool onlyTile, uint64_t tileSize) {
-    this->tileSize = tileSize;
-  }
-  TileAttentionPass(const TileAttentionPass &pass) { tileSize = pass.tileSize; }
-  void runOnOperation() override;
-};
-
 struct ConvertAttentionToOnlineAttentionPass final
     : impl::ConvertAttentionToOnlineAttentionPassBase<
           ConvertAttentionToOnlineAttentionPass> {
@@ -219,140 +34,6 @@
 
 } // namespace
 
-/// Tile iree_linalg_ext.attention.
-/// TODO: Adopt getTiledImplementation with this.
-IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
-                                           SmallVectorImpl<Operation *> &ops,
-                                           RewriterBase &rewriter,
-                                           std::optional<uint64_t> tileSize) {
-  Location loc = attnOp.getLoc();
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(attnOp);
-
-  Value query = attnOp.getQuery();
-  ShapedType queryType = attnOp.getQueryType();
-  Type elementType = queryType.getElementType();
-  ArrayRef<int64_t> queryShape = queryType.getShape();
-  SmallVector<OpFoldResult> queryDimValues =
-      tensor::getMixedSizes(rewriter, loc, query);
-  OpFoldResult headDimension = queryDimValues[2];
-  OpFoldResult sequenceTileLength = queryDimValues[1];
-  OpFoldResult keyValueTileLength = sequenceTileLength;
-  SmallVector<int64_t> keyShape{queryShape};
-  if (tileSize) {
-    keyValueTileLength = rewriter.getIndexAttr(tileSize.value());
-    for (auto [idx, val] : llvm::enumerate(attnOp.getKeyType().getShape())) {
-      keyShape[idx] = idx == 1 ? tileSize.value() : val;
-    }
-  }
-
-  Value key = attnOp.getKey();
-  Value value = attnOp.getValue();
-  SmallVector<OpFoldResult> keyDimValues =
-      tensor::getMixedSizes(rewriter, loc, key);
-  OpFoldResult sequenceLength = keyDimValues[1];
-
-  // Create output accumulator
-  Value output = attnOp.getOutput();
-  Type f32Type = rewriter.getF32Type();
-  SmallVector<OpFoldResult> accShape{queryDimValues[1], queryDimValues[2]};
-  Value accumulatorF32 =
-      rewriter.create<tensor::EmptyOp>(loc, accShape, f32Type);
-
-  // Create accumulator, max and sum statistics
-  Value outputSlice = extractOutputSlice(output, queryShape, sequenceTileLength,
-                                         headDimension, loc, rewriter);
-  Value zeroF32 =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(f32Type));
-  auto accumulatorFill =
-      rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32}, accumulatorF32);
-  accumulatorF32 = accumulatorFill.result();
-  ops.push_back(accumulatorFill);
-
-  Value largeNegativeF32 = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getFloatAttr(f32Type, -1.0e+30));
-  SmallVector<OpFoldResult> dims{sequenceTileLength};
-  Value max = rewriter.create<tensor::EmptyOp>(loc, dims, f32Type);
-  auto maxFill =
-      rewriter.create<linalg::FillOp>(loc, ValueRange{largeNegativeF32}, max);
-  Value negativeMax = maxFill.result();
-  ops.push_back(maxFill);
-  Value sum = rewriter.create<tensor::EmptyOp>(loc, dims, f32Type);
-  auto sumFill = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32}, sum);
-  Value zeroSum = sumFill.result();
-  ops.push_back(sumFill);
-
-  // Construct sequential loop
-  SmallVector<Value> ivs;
-  Value zeroValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  scf::LoopNest loopNest = createLoopNest(
-      ivs, zeroValue,
-      getValueOrCreateConstantIndexOp(rewriter, loc, keyValueTileLength),
-      getValueOrCreateConstantIndexOp(rewriter, loc, sequenceLength),
-      ValueRange({accumulatorF32, negativeMax, zeroSum}), loc, rewriter);
-  ops.push_back(loopNest.loops.back());
-
-  Value iterArgResult = loopNest.loops.back().getRegionIterArg(0);
-  Value iterArgMax = loopNest.loops.back().getRegionIterArg(1);
-  Value iterArgSum = loopNest.loops.back().getRegionIterArg(2);
-
-  OpBuilder::InsertionGuard guardSecondLoop(rewriter);
-  rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
-
-  // Extract slices
-  Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
-                                headDimension, elementType, loc, rewriter);
-  Value valueSlice =
-      extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension,
-                   elementType, loc, rewriter, attnOp.isTransposeV());
-  Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
-                                  headDimension, elementType, loc, rewriter);
-
-  Value scale = attnOp.getScale();
-
-  int64_t tiledInputRank = cast<ShapedType>(querySlice.getType()).getRank();
-  SmallVector<AffineMap> tiledIndexingMaps = getTileAttentionIndexingMaps(
-      rewriter, tiledInputRank, attnOp.isTransposeV());
-
-  auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
-      attnOp.getLoc(),
-      SmallVector<Type>{accumulatorF32.getType(), sum.getType(), max.getType()},
-      querySlice, keySlice, valueSlice, scale,
-      SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum},
-      rewriter.getAffineMapArrayAttr(tiledIndexingMaps));
-
-  Value tiledResult = tiledAttentionOp.getResult(0);
-  Value newMax = tiledAttentionOp.getResult(1);
-  Value newSum = tiledAttentionOp.getResult(2);
-
-  if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
-          loopNest.loops.back().getBody()->getTerminator())) {
-    OpBuilder::InsertionGuard yieldGuard(rewriter);
-    rewriter.setInsertionPoint(yieldOp);
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(
-        yieldOp, ValueRange{tiledResult, newMax, newSum});
-  }
-
-  OpBuilder::InsertionGuard yieldGuard(rewriter);
-  rewriter.setInsertionPointAfter(loopNest.loops.back());
-
-  loopNest.results[0] = applyFinalScaling(
-      loopNest.results[0], loopNest.results[2], loc, rewriter, ops);
-
-  if (elementType.isF16()) {
-    loopNest.results[0] =
-        truncateToF16(loopNest.results[0], outputSlice, ops, rewriter, loc);
-  }
-  loopNest.results[0] =
-      insertOutputSlice(loopNest.results[0], output, sequenceTileLength,
-                        headDimension, loc, rewriter);
-
-  rewriter.replaceOp(attnOp, loopNest.results[0]);
-  ops.push_back(tiledAttentionOp);
-
-  return tiledAttentionOp;
-}
-
 void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
                               SmallVectorImpl<Operation *> &ops,
                               RewriterBase &rewriter) {
@@ -462,19 +143,6 @@
   rewriter.replaceOp(attnOp, genericOp);
 }
 
-void TileAttentionPass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  IRRewriter rewriter(context);
-  std::optional<uint64_t> optionalTileSize{std::nullopt};
-  if (tileSize.hasValue()) {
-    optionalTileSize = tileSize.getValue();
-  }
-  getOperation().walk([&](AttentionOp attnOp) {
-    SmallVector<Operation *> ops;
-    tileAttention(attnOp, ops, rewriter, optionalTileSize);
-  });
-}
-
 void ConvertAttentionToOnlineAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index 0d88661..efe463a6 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -21,14 +21,12 @@
             "convert_to_loops.mlir",
             "convert_to_online_attention.mlir",
             "decompose_aggregate_op.mlir",
-            "decompose_attention.mlir",
             "decompose_im2col.mlir",
             "decompose_online_attention.mlir",
             "decompose_winograd.mlir",
             "distribution.mlir",
             "pad_contraction_to_block_size.mlir",
             "split_reduction.mlir",
-            "tile_attention.mlir",
             "tiling.mlir",
         ],
         include = ["*.mlir"],
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 1268444..3288c14 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -19,14 +19,12 @@
     "convert_to_loops.mlir"
     "convert_to_online_attention.mlir"
     "decompose_aggregate_op.mlir"
-    "decompose_attention.mlir"
     "decompose_im2col.mlir"
     "decompose_online_attention.mlir"
     "decompose_winograd.mlir"
     "distribution.mlir"
     "pad_contraction_to_block_size.mlir"
     "split_reduction.mlir"
-    "tile_attention.mlir"
     "tiling.mlir"
   TOOLS
     FileCheck
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
deleted file mode 100644
index 19d6a6b..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
+++ /dev/null
@@ -1,350 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-tile-attention{tileSize=32},iree-linalg-ext-decompose-attention{tileSize=32}),cse)" %s | FileCheck %s
-
-func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
-  %0 = tensor.empty() : tensor<1x1024x64xf32>
-  %scale = arith.constant 0.05 : f32
-  %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
-  return %1 : tensor<1x1024x64xf32>
-}
-
-// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @attention
-// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<1x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
-// CHECK:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
-// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
-// CHECK-SAME:     tensor<1024x64xf32>
-// CHECK-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
-// CHECK:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK:        %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
-// CHECK-SAME:     iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
-// CHECK:          %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME:       tensor<1x1024x64xf32> to tensor<32x64xf32>
-// CHECK:          %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME:       tensor<1x1024x64xf32> to tensor<32x64xf32>
-// CHECK:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
-// CHECK:          %[[SCALE_Q:.+]] = linalg.generic {{.+}} ins(%[[EXTRACTED_SLICE_2]] : tensor<1024x64xf32>)
-// CHECK:          %[[D8:.+]] = tensor.empty() : tensor<1024x32xf32>
-// CHECK:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x32xf32>) ->
-// CHECK-SAME:       tensor<1024x32xf32>
-// CHECK:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
-// CHECK-SAME:       tensor<1024x64xf32>, tensor<32x64xf32>) outs(%[[D9]] : tensor<1024x32xf32>) -> tensor<1024x32xf32>
-// CHECK:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D10]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x32xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
-// CHECK:            linalg.yield %[[D19]] : f32
-// CHECK:          } -> tensor<1024x32xf32>
-// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK:            %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK:            linalg.yield %[[D19]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D12]] : tensor<1024x32xf32>) outs(%[[D14]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<1024x64xf32>
-// CHECK:          %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_1]] : tensor<1024x32xf32>,
-// CHECK-SAME:       tensor<32x64xf32>) outs(%[[D16]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK:          scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
-// CHECK:        }
-// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
-// CHECK-SAME:     {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:          %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// CHECK:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
-// CHECK:          linalg.yield %[[D9]] : f32
-// CHECK:        } -> tensor<1024x64xf32>
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME:     tensor<1024x64xf32> into tensor<1x1024x64xf32>
-// CHECK:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
-// CHECK:      }
-
-// -----
-
-func.func @attention(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<?x?x?xf32> {
-  %0 = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
-  %scale = arith.constant 0.05 : f32
-  %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
-}
-
-// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL:      func.func @attention
-// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index) -> tensor<?x?x?xf32> {
-// CHECK:        %[[D0:.+]] = tensor.empty(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?xf32>
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK:        %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK-DAG:    %[[C2:.+]] = arith.constant 2 : index
-// CHECK:        %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK:        %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK:        %[[D1:.+]] = tensor.empty(%[[DIM]], %[[DIM_0]]) : tensor<?x?xf32>
-// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-DAG:    %[[CST_2:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK:        %[[D3:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
-// CHECK:        %[[D4:.+]] = linalg.fill ins(%[[CST_2]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK-DAG:    %[[C32:.+]] = arith.constant 32 : index
-// CHECK:        %[[D6:.+]]:3 = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[DIM_1]] step %[[C32]]
-// CHECK-SAME:     iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG8:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME:     %[[ARG9:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>) {
-// CHECK:          %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1, 1,
-// CHECK-SAME:       1] : tensor<?x?x?xf32> to tensor<32x?xf32>
-// CHECK:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1,
-// CHECK-SAME:       1, 1] : tensor<?x?x?xf32> to tensor<32x?xf32>
-// CHECK:          %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1,
-// CHECK-SAME:       1] : tensor<?x?x?xf32> to tensor<?x?xf32>
-// CHECK:          %[[DIM_5:.+]] = tensor.dim %[[EXTRACTED_SLICE_4]], %[[C0]] : tensor<?x?xf32>
-// CHECK:          %[[SCALE_Q:.+]] = linalg.generic
-// CHECK:          %[[D8:.+]] = tensor.empty(%[[DIM_5]]) : tensor<?x32xf32>
-// CHECK:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<?x32xf32>) -> tensor<?x32xf32>
-// CHECK:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE]] :
-// CHECK-SAME:       tensor<?x?xf32>, tensor<32x?xf32>) outs(%[[D9]] : tensor<?x32xf32>) -> tensor<?x32xf32>
-// CHECK:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D10]] : tensor<?x32xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D11]] : tensor<?xf32>) outs(%[[D10]] : tensor<?x32xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
-// CHECK:            linalg.yield %[[D19]] : f32
-// CHECK:          } -> tensor<?x32xf32>
-// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D11]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK:            %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK:            linalg.yield %[[D19]] : f32
-// CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG9]] : tensor<?xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D12]] : tensor<?x32xf32>) outs(%[[D14]] : tensor<?xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG7]] : tensor<?x?xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<?x?xf32>
-// CHECK:          %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_3]] : tensor<?x32xf32>,
-// CHECK-SAME:       tensor<32x?xf32>) outs(%[[D16]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:          scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
-// CHECK:        }
-// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG:      %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:          %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]] : f32
-// CHECK:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
-// CHECK:          linalg.yield %[[D9]] : f32
-// CHECK:        } -> tensor<?x?xf32>
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
-// CHECK-SAME:     [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
-// CHECK:        return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
-// CHECK:      }
-
-// -----
-
-func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
-  %0 = tensor.empty() : tensor<1x1024x64xf16>
-  %scale = arith.constant 0.05 : f16
-  %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
-  return %1 : tensor<1x1024x64xf16>
-}
-
-// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL:  func.func @attention_f16
-// CHECK-SAME:   (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
-// CHECK:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
-// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME:     tensor<1x1024x64xf16> to tensor<1024x64xf16>
-// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
-// CHECK-SAME:     tensor<1024x64xf32>
-// CHECK-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
-// CHECK:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK:        %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
-// CHECK-SAME:     iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
-// CHECK:          %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME:       tensor<1x1024x64xf16> to tensor<32x64xf16>
-// CHECK:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
-// CHECK-SAME:       tensor<1x1024x64xf16> to tensor<32x64xf16>
-// CHECK:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
-// CHECK:          %[[SCALE_Q:.+]] = linalg.generic
-// CHECK:          %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
-// CHECK:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
-// CHECK-SAME:       tensor<1024x32xf32>
-// CHECK:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[SCALE_Q]], %[[EXTRACTED_SLICE_1]] :
-// CHECK-SAME:       tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
-// CHECK-SAME:       tensor<1024x32xf32>
-// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D11]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D21]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x32xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK:            %[[D22:.+]] = math.exp2 %[[D21]] : f32
-// CHECK:            linalg.yield %[[D22]] : f32
-// CHECK:          } -> tensor<1024x32xf32>
-// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
-// CHECK:            %[[D22]] = math.exp2 %[[D21]] : f32
-// CHECK:            linalg.yield %[[D22]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D21]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D15]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D21]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D17:.+]] = tensor.empty() : tensor<1024x32xf16>
-// CHECK:          %[[D18:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D17]] : tensor<1024x32xf16>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// CHECK:            %[[D21]] = arith.truncf %[[IN]] : f32 to f16
-// CHECK:            linalg.yield %[[D21]] : f16
-// CHECK:          } -> tensor<1024x32xf16>
-// CHECK:          %[[D19:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D21]] : f32
-// CHECK:          } -> tensor<1024x64xf32>
-// CHECK:          %[[D20:.+]] = linalg.matmul ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x32xf16>,
-// CHECK-SAME:       tensor<32x64xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK:          scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
-// CHECK:        }
-// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
-// CHECK-SAME:     {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:          %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
-// CHECK:          %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
-// CHECK:          linalg.yield %[[D10]] : f32
-// CHECK:        } -> tensor<1024x64xf32>
-// CHECK:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// CHECK:          %[[D9]] = arith.truncf %[[IN]] : f32 to f16
-// CHECK:          linalg.yield %[[D9]] : f16
-// CHECK:        } -> tensor<1024x64xf16>
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
-// CHECK-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
-// CHECK:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
-// CHECK:      }
-
-// -----
-
-// transpose_V is detected through indexingMap.
-
-func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
-  %0 = tensor.empty() : tensor<1x1024x64xf16>
-  %scale = arith.constant 0.05 : f16
-  %1 = iree_linalg_ext.attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
-  return %1 : tensor<1x1024x64xf16>
-}
-
-// CHECK-LABEL:  func.func @attention_transpose_v
-// There should be two matmul_transpose_b for tranpose_v variant instead of
-// only 1.
-// CHECK: linalg.matmul_transpose_b
-// CHECK-NOT: linalg.matmul
-// CHECK: linalg.matmul_transpose_b
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
deleted file mode 100644
index 51d3c51..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
+++ /dev/null
@@ -1,165 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-tile-attention),cse)" %s | FileCheck %s --check-prefix=CHECK
-
-// TODO: These tests should be moved to tiling.mlir when PartialReductionOpInterface is implemented for attention op.
-
-func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
-  %0 = tensor.empty() : tensor<1x1024x64xf32>
-  %scale = arith.constant 0.05 : f32
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
-  return %1 : tensor<1x1024x64xf32>
-}
-
-// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-LABEL:      func.func @attention
-// CHECK-SAME: (%[[QUERY:.+]]: tensor<1x1024x64xf32>, %[[KEY:.+]]: tensor<1x1024x64xf32>, %[[VALUE:.+]]: tensor<1x1024x64xf32>)
-// CHECK-DAG:    %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
-// CHECK-DAG:    %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
-// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:    %[[CST_1:.+]] = arith.constant 5.000000e-02 : f32
-// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>)
-// CHECK-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
-// CHECK:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>)
-// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>)
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
-// CHECK:        %[[D6:.+]]:3 = scf.for %[[ARG3:.+]] = %[[C0]] to %[[C1024]] step %[[C1024]]
-// CHECK-SAME:     iter_args(%[[ARG4:.+]] = %[[D2]], %[[ARG5:.+]] = %[[D4]], %[[ARG6:.+]] = %[[D5]])
-// CHECK:          %[[K_S:.+]] = tensor.extract_slice %[[KEY]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1]
-// CHECK:          %[[V_S:.+]] = tensor.extract_slice %[[VALUE]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1]
-// CHECK:          %[[Q_S:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, 1024, 64] [1, 1, 1]
-// CHECK:          %[[ATT:.+]]:3 = iree_linalg_ext.attention
-// CHECK-SAME:                                           ins(%[[Q_S]], %[[K_S]], %[[V_S]], %[[CST_1]]
-// CHECK-SAME:                                           outs(%[[ARG4]], %[[ARG5]], %[[ARG6]]
-// CHECK:          scf.yield %[[ATT]]#0, %[[ATT]]#1, %[[ATT]]#2
-// CHECK:        }
-// CHECK:        %[[D7:.+]] = linalg.generic
-// CHECK-SAME:                {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
-// CHECK-SAME:                 iterator_types = ["parallel", "parallel"]}
-// CHECK-SAME:               ins(%[[D6]]#2 : tensor<1024xf32>)
-// CHECK-SAME:               outs(%[[D6]]#0 : tensor<1024x64xf32>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00
-// CHECK:          %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]]
-// CHECK:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]]
-// CHECK:          linalg.yield %[[D9]]
-// CHECK:        } -> tensor<1024x64xf32>
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]]
-// CHECK:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
-// CHECK:      }
-
-// -----
-
-func.func @attention(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %dim0: index, %dim1: index, %dim2: index) -> tensor<?x?x?xf32> {
-  %0 = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
-  %scale = arith.constant 0.05 : f32
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
-}
-
-// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-LABEL:      func.func @attention(
-// CHECK-SAME:  %[[QUERY:.+]]: tensor<?x?x?xf32>, %[[KEY:.+]]: tensor<?x?x?xf32>, %[[VALUE:.+]]: tensor<?x?x?xf32>,
-// CHECK-SAME:  %[[ARG3:[a-zA-Z0-9_]+]]: index, %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index)
-// CHECK:        %[[D0:.+]] = tensor.empty(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?xf32>
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK:        %[[DIM:.+]] = tensor.dim %[[QUERY]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK-DAG:    %[[C2:.+]] = arith.constant 2 : index
-// CHECK:        %[[DIM_0:.+]] = tensor.dim %[[QUERY]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK:        %[[DIM_1:.+]] = tensor.dim %[[KEY]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK:        %[[D1:.+]] = tensor.empty(%[[DIM]], %[[DIM_0]]) : tensor<?x?xf32>
-// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-DAG:    %[[CST_2:.+]] = arith.constant -1.000000e+30 : f32
-// CHECK:        %[[D3:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
-// CHECK:        %[[D4:.+]] = linalg.fill ins(%[[CST_2]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
-// CHECK:        %[[D6:.+]]:3 = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[DIM_1]] step %[[DIM]]
-// CHECK-SAME:     iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG8:[a-zA-Z0-9_]+]] = %[[D4]],
-// CHECK-SAME:     %[[ARG9:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>) {
-// CHECK:          %[[K_S:.+]] = tensor.extract_slice %[[KEY]][0, %[[ARG6]], 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1]
-// CHECK:          %[[V_S:.+]] = tensor.extract_slice %[[VALUE]][0, %[[ARG6]], 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1]
-// CHECK:          %[[Q_S:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1]
-// CHECK:          %[[ATT:.+]]:3 = iree_linalg_ext.attention
-// CHECK-SAME:                      ins(%[[Q_S]], %[[K_S]], %[[V_S]], %{{[a-z0-1]+}}
-// CHECK-SAME:                      outs(%[[ARG7]], %[[ARG8]], %[[ARG9]]
-// CHECK:          scf.yield %[[ATT]]#0, %[[ATT]]#1, %[[ATT]]#2
-// CHECK:        }
-// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
-// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK-DAG:      %[[CST_3:.+]] = arith.constant 1.000000e+00
-// CHECK:          %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]]
-// CHECK:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]]
-// CHECK:          linalg.yield %[[D9]]
-// CHECK:        } -> tensor<?x?xf32>
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
-// CHECK-SAME:     [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
-// CHECK:        return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
-// CHECK:      }
-
-// -----
-
-func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
-  %0 = tensor.empty() : tensor<1x1024x64xf16>
-  %scale = arith.constant 0.05 : f16
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
-  return %1 : tensor<1x1024x64xf16>
-}
-
-// CHECK-LABEL:  @attention_f16
-
-// CHECK:        scf.for
-// CHECK:          iree_linalg_ext.attention
-// CHECK-SAME:                                           ins(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<1024x64xf16>, tensor<1024x64xf16>, tensor<1024x64xf16>, f16
-// CHECK-SAME:                                           outs(%{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME:                                           -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
-// CHECK:        scf.yield
-
-// CHECK:        linalg.generic
-
-// CHECK:        %[[TRUNCED:.+]] = linalg.generic
-// CHECK:          arith.truncf
-// CHECK:        } -> tensor<1024x64xf16>
-
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[TRUNCED]]
-// CHECK:        return %[[INSERTED_SLICE]]
-// CHECK:      }
-
-// -----
-
-// transpose_V is detected through indexingMap.
-
-func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
-  %0 = tensor.empty() : tensor<1x1024x64xf16>
-  %scale = arith.constant 0.05 : f16
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
-                     affine_map<(d0, d1, d2, d3, d4) -> ()>,
-                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
-                     ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
-  return %1 : tensor<1x1024x64xf16>
-}
-// CHECK-LABEL:  func.func @attention_transpose_v
-// CHECK:          scf.for
-// CHECK:            iree_linalg_ext.attention
-// CHECK:          scf.yield