[LinalgExt] Do not decompose attention op with manual analysis. (#16525)
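
Re-enables IREE::LinalgExt::createTileAndDecomposeAttentionPass in the
LLVMCPU and LLVMGPU pipelines, which had been commented out because
decomposition was failing during bufferization (#16421). As a
workaround, computeScaleFactor in TileAndDecomposeAttention.cpp now
creates a tensor.empty for the out operand of its scale-factor
linalg.generic instead of reusing oldMax as the init, and the LLVMGPU
and LinalgExt test expectations are updated for the new IR.

For illustration, a minimal sketch of the scale-factor IR after this
change, distilled from the updated tile_and_decompose_attention.mlir
expectations for the dynamic-shape case (%new_max stands in for the
newMax value from the surrounding decomposition; value and map names
are illustrative, not verbatim from the tests):

    #map1d = affine_map<(d0) -> (d0)>
    %c0 = arith.constant 0 : index
    // Fresh init sized from the new max, replacing the reuse of the
    // old max accumulator as the out operand of the generic.
    %dim  = tensor.dim %new_max, %c0 : tensor<?xf32>
    %init = tensor.empty(%dim) : tensor<?xf32>
    %scale = linalg.generic
        {indexing_maps = [#map1d, #map1d], iterator_types = ["parallel"]}
        ins(%new_max : tensor<?xf32>) outs(%init : tensor<?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %diff = arith.subf %out, %in : f32
      %weight = math.exp2 %diff : f32
      linalg.yield %weight : f32
    } -> tensor<?xf32>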

diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 4107c55..4bff8f0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -100,10 +100,8 @@
       createFuseTensorPadWithConsumerPass());
   nestedModulePM.addNestedPass<func::FuncOp>(
       createConcretizePadResultShapePass());
-  // TODO(#16421) Decomposition is failing during bufferization.
-  // Disable till fixed
-  // nestedModulePM.addNestedPass<func::FuncOp>(
-  //    IREE::LinalgExt::createTileAndDecomposeAttentionPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      IREE::LinalgExt::createTileAndDecomposeAttentionPass());
   nestedModulePM.addNestedPass<func::FuncOp>(
       IREE::LinalgExt::createTileAndDecomposeWinogradTransformPass());
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index cff5a3e..ded6cb0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -113,9 +113,8 @@
   nestedModulePM.addNestedPass<func::FuncOp>(
       createConvertToDestinationPassingStylePass(
           useWARForCooperativeMatrixCodegen));
-  // TODO(#16421): Disable decomposition due to failure in bufferization.
-  // nestedModulePM.addNestedPass<func::FuncOp>(
-  //     IREE::LinalgExt::createTileAndDecomposeAttentionPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      IREE::LinalgExt::createTileAndDecomposeAttentionPass());
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index ce450ba..48bd221 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -132,7 +132,10 @@
 // CHECK:          %[[D18:.+]] = vector.transpose %[[D17]], [1, 0] : vector<128x32xf32> to vector<32x128xf32>
 // CHECK:          %[[D19:.+]] = arith.subf %[[D15]], %[[D18]] : vector<32x128xf32>
 // CHECK:          %[[D20:.+]] = math.exp2 %[[D19]] : vector<32x128xf32>
-// CHECK:          %[[D21:.+]] = arith.subf %[[ARG1]], %[[D16]] : vector<32xf32>
+// CHECK:          %[[ALLOC_12:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32xf32,
+// CHECK-SAME:       #[[GPU]].address_space<workgroup>>
+// CHECK:          %[[READ_ALLOC_12:.+]] = vector.transfer_read %[[ALLOC_12]]
+// CHECK:          %[[D21:.+]] = arith.subf %[[READ_ALLOC_12]], %[[D16]] : vector<32xf32>
 // CHECK:          %[[D22:.+]] = math.exp2 %[[D21]] : vector<32xf32>
 // CHECK:          %[[D23:.+]] = arith.mulf %[[D22]], %[[ARG2]] : vector<32xf32>
 // CHECK:          %[[D24:.+]] = vector.multi_reduction <add>, %[[D20]], %[[D23]] [1] : vector<32x128xf32> to
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 539f20b..c2232c8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/Debug.h"
+#include "mlir/IR/TypeUtilities.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -79,12 +80,20 @@
 static Value computeScaleFactor(Value oldMax, Value newMax, Location loc,
                                 OpBuilder &builder,
                                 SmallVectorImpl<Operation *> &ops) {
+  // Create a tensor.empty op for the out operand instead of reusing oldMax.
+  // This is a workaround for a failure in bufferization. For more details,
+  // see https://github.com/openxla/iree/issues/16421
+  SmallVector<OpFoldResult> initShape =
+      tensor::getMixedSizes(builder, loc, newMax);
+  Value init = builder.create<tensor::EmptyOp>(
+      loc, initShape, getElementTypeOrSelf(oldMax.getType()));
+
   SmallVector<utils::IteratorType> iteratorTypes(1,
                                                  utils::IteratorType::parallel);
   auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
   SmallVector<AffineMap> indexingMaps(2, identityMap);
   auto genericOp = builder.create<linalg::GenericOp>(
-      loc, oldMax.getType(), newMax, oldMax, indexingMaps, iteratorTypes,
+      loc, init.getType(), newMax, init, indexingMaps, iteratorTypes,
       [&](OpBuilder &b, Location loc, ValueRange args) {
         Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
         Value weight = b.create<math::Exp2Op>(loc, diff);
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 3339eb5..3531fb1 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -53,7 +53,7 @@
 // TILESIZE:            linalg.yield %[[D19]] : f32
 // TILESIZE:          } -> tensor<1024x32xf32>
 // TILESIZE:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME:       ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE-SAME:       ins(%[[D11]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
 // TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
 // TILESIZE:            %[[D19]] = math.exp2 %[[D18]] : f32
@@ -183,7 +183,7 @@
 // CHECK:            linalg.yield %[[D19]] : f32
 // CHECK:          } -> tensor<1024x1024xf32>
 // CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK-SAME:       ins(%[[D11]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D18:.+]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
@@ -280,8 +280,10 @@
 // TILESIZE:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
 // TILESIZE:            linalg.yield %[[D19]] : f32
 // TILESIZE:          } -> tensor<?x32xf32>
+// TILESIZE:          %[[DIM6:.+]] = tensor.dim %[[D11]], %[[C0]]
+// TILESIZE:          %[[INIT:.+]] = tensor.empty(%[[DIM6]])
 // TILESIZE:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME:       ins(%[[D11]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// TILESIZE-SAME:       ins(%[[D11]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>) {
 // TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
 // TILESIZE:            %[[D19]] = math.exp2 %[[D18]] : f32
@@ -413,8 +415,10 @@
 // CHECK:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
 // CHECK:            linalg.yield %[[D19]] : f32
 // CHECK:          } -> tensor<?x?xf32>
+// CHECK:          %[[DIM6:.+]] = tensor.dim %[[D11]], %[[C0]]
+// CHECK:          %[[INIT:.+]] = tensor.empty(%[[DIM6]])
 // CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D10]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// CHECK-SAME:       ins(%[[D10]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D19]] = math.exp2 %[[D18]] : f32
@@ -511,7 +515,7 @@
 // TILESIZE:            linalg.yield %[[D22]] : f32
 // TILESIZE:          } -> tensor<1024x32xf32>
 // TILESIZE:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
 // TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
 // TILESIZE:            %[[D22]] = math.exp2 %[[D21]] : f32
@@ -663,7 +667,7 @@
 // CHECK:            linalg.yield %[[D22]] : f32
 // CHECK:          } -> tensor<1024x1024xf32>
 // CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D22]] = math.exp2 %[[D21]] : f32
@@ -774,7 +778,7 @@
 // TILESIZE:            linalg.yield %[[D22]] : f32
 // TILESIZE:          } -> tensor<1024x32xf32>
 // TILESIZE:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// TILESIZE-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
 // TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // TILESIZE:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
 // TILESIZE:            %[[D22]] = math.exp2 %[[D21]] : f32
@@ -926,7 +930,7 @@
 // CHECK:            linalg.yield %[[D22]] : f32
 // CHECK:          } -> tensor<1024x1024xf32>
 // CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[D3]] : tensor<1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D22]] = math.exp2 %[[D21]] : f32