Add FlashAttention v2 (#15527)

This patch adds the algorithmic modifications introduced in FlashAttention-2
(FA2) to the tile-and-decompose attention pass: inside the K/V loop the running
sum and the accumulator are now rescaled by a single per-row factor
exp2(oldMax - newMax) instead of being multiplied by the reciprocal of the new
sum, and the normalization by the running softmax sum is deferred to one final
scaling emitted after the loop.
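
For context, a minimal scalar sketch of the update order this patch moves to,
assuming a single query row and hypothetical variable names (`runningMax`,
`runningSum`, `acc`) that only loosely mirror the `maxSlice`/`sumSlice`/
accumulator tiles handled by the pass; the pass itself emits `linalg` generics
and matmuls on tiles (using `math.exp2`), not a scalar loop:

```cpp
// Illustrative sketch only; not the pass code. Plain std::exp is used where
// the pass emits math.exp2, and the block handling is simplified.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // One query row against a sequence processed in blocks of size kBlock.
  const std::vector<float> scores = {0.3f, 1.2f, -0.7f, 2.1f, 0.5f, -1.4f};
  const std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  const std::size_t kBlock = 2;

  float runningMax = -INFINITY; // maxSlice in the pass
  float runningSum = 0.0f;      // sumSlice in the pass
  float acc = 0.0f;             // accumulator, left unnormalized in FA2

  for (std::size_t start = 0; start < scores.size(); start += kBlock) {
    // Fold the block maximum into the running maximum.
    float newMax = runningMax;
    for (std::size_t i = start; i < start + kBlock; ++i)
      newMax = std::max(newMax, scores[i]);

    // FA2: one scale factor exp(oldMax - newMax) rescales both the running
    // sum and the accumulator; no reciprocal or division in the loop.
    float scaleFactor = std::exp(runningMax - newMax);
    runningSum *= scaleFactor;
    acc *= scaleFactor;

    for (std::size_t i = start; i < start + kBlock; ++i) {
      float p = std::exp(scores[i] - newMax); // partial softmax numerator
      runningSum += p;
      acc += p * values[i]; // matmul(partialSoftmax, v) on the tile
    }
    runningMax = newMax;
  }

  // Deferred normalization: divide by the running sum once, after the loop,
  // which is what applyFinalScaling adds in this patch.
  float result = acc / runningSum;
  std::printf("attention output for the row: %f\n", result);
  return 0;
}
```

The effect of this ordering is that the inner loop touches the accumulator only
through one multiplication by the scale factor (see `computeScaleFactor` and
`scaleAccumulator` in the diff below); the reciprocal of the softmax sum and the
division of the output happen once, outside the loop, in `applyFinalScaling`.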

---------

Co-authored-by: Harsh Menon <harsh@nod-labs.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index 367809b..2005453 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -40,7 +40,7 @@
 // CHECK-DAG:    %[[CST_0:.+]] = arith.constant dense<-1.000000e+30> : vector<32xf32>
 // CHECK-DAG:    %[[CST_1:.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
 // CHECK-DAG:    %[[CST_2:.+]] = arith.constant dense<0.000000e+00> : vector<32x128xf32>
-// CHECK-DAG:    %[[CST_3:.+]] = arith.constant dense<1.000000e+00> : vector<32xf32>
+// CHECK-DAG:    %[[CST_3:.+]] = arith.constant dense<1.000000e+00> : vector<64x32xf32>
 // CHECK-DAG:    %[[CST_4:.+]] = arith.constant 0.000000e+00 : f16
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:    %[[C128:.+]] = arith.constant 128 : index
@@ -139,13 +139,8 @@
 // CHECK:          %[[D23:.+]] = arith.mulf %[[D22]], %[[ARG2]] : vector<32xf32>
 // CHECK:          %[[D24:.+]] = vector.multi_reduction <add>, %[[D20]], %[[D23]] [1] : vector<32x128xf32> to
 // CHECK-SAME:       vector<32xf32>
-// CHECK:          %[[D25:.+]] = arith.divf %[[CST_3]], %[[D24]] : vector<32xf32>
-// CHECK:          %[[D26:.+]] = vector.broadcast %[[D25]] : vector<32xf32> to vector<128x32xf32>
-// CHECK:          %[[D27:.+]] = vector.transpose %[[D26]], [1, 0] : vector<128x32xf32> to vector<32x128xf32>
-// CHECK:          %[[D28:.+]] = arith.mulf %[[D20]], %[[D27]] : vector<32x128xf32>
-// CHECK:          %[[D29:.+]] = arith.truncf %[[D28]] : vector<32x128xf32> to vector<32x128xf16>
-// CHECK:          %[[D30:.+]] = arith.mulf %[[D23]], %[[D25]] : vector<32xf32>
-// CHECK:          %[[D31:.+]] = vector.broadcast %[[D30]] : vector<32xf32> to vector<64x32xf32>
+// CHECK:          %[[D29:.+]] = arith.truncf %[[D20]] : vector<32x128xf32> to vector<32x128xf16>
+// CHECK:          %[[D31:.+]] = vector.broadcast %[[D22]] : vector<32xf32> to vector<64x32xf32>
 // CHECK:          %[[D33:.+]] = vector.transpose %[[D31]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
 // CHECK:          %[[D34:.+]] = arith.mulf %[[D33]], %[[ARG3]] : vector<32x64xf32>
 // CHECK:          %[[D35:.+]] = vector.transfer_read %[[ALLOC_11]][%[[C0]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
@@ -159,7 +154,11 @@
 // CHECK:          gpu.barrier
 // CHECK:          scf.yield %[[D16]], %[[D24]], %[[D39]] : vector<32xf32>, vector<32xf32>, vector<32x64xf32>
 // CHECK:        }
-// CHECK:        %[[D12:.+]] = arith.truncf %[[D11]]#[[D2:.+]] : vector<32x64xf32> to vector<32x64xf16>
+// CHECK:        %[[DSCALE1:.+]] = vector.broadcast %[[D11]]#1 : vector<32xf32> to vector<64x32xf32>
+// CHECK:        %[[DSCALE2:.+]] = arith.divf %[[CST_3]], %[[DSCALE1]] : vector<64x32xf32>
+// CHECK:        %[[DSCALE3:.+]] = vector.transpose %[[DSCALE2]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
+// CHECK:        %[[DSCALE4:.+]] = arith.mulf %[[DSCALE3]], %[[D11]]#2 : vector<32x64xf32>
+// CHECK:        %[[D12:.+]] = arith.truncf %[[DSCALE4]] : vector<32x64xf32> to vector<32x64xf16>
 // CHECK:        vector.transfer_write %[[D12]], %[[ALLOC_7]][%[[C0]], %[[D8]], %[[C0]]] {in_bounds = [true, true]} :
 // CHECK-SAME:     vector<32x64xf16>, memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>
 // CHECK:        gpu.barrier
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
index 03e1cd6..e33e07e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
@@ -1,4 +1,3 @@
-
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.consumed}) {
     // Get attention op
@@ -31,11 +30,11 @@
     // Tile and decompose attention
     // ==========================================
     %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %acc_fill, %max_fill, %sum_fill, %inner_loop, %last_truncate, %blocked_attention = transform.tile_attention %attention4 :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %truncate, %scale_acc, %second_matmul
+    %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.tile_attention %attention4 :
+      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
         = transform.decompose_tiled_attention %blocked_attention :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 
     // Promote key and value operands
     // ==========================================
@@ -47,32 +46,54 @@
     // Tile and fuse attention ops
     // ==========================================
     %tiled_matmul, %forall = transform.structured.tile_using_forall %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
 
     %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f2, %loop2 = transform.structured.fuse_into_containing_op %softmax into %loop1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
 
     %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.apply_cse %func : !transform.any_op
 
-    %f3, %loop3 = transform.structured.fuse_into_containing_op %reciprocal_sum into %loop2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f4, %loop4 = transform.structured.fuse_into_containing_op %reduce_sum into %loop3 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
     transform.iree.apply_cse %func : !transform.any_op
 
-    %f5, %loop5 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.iree.apply_cse %func : !transform.any_op
+
+    %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.iree.apply_cse %func : !transform.any_op
 
     %f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %f8, %loop8 = transform.structured.fuse_into_containing_op %promoted_first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
     transform.iree.apply_cse %func : !transform.any_op
 
-    // Distribute fills and last truncate
+    %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.iree.apply_cse %func : !transform.any_op
+
+    // Distribute fills
     // ==========================================
-    %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill, %last_truncate : !transform.any_op
+    %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill : !transform.any_op
     %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
+    // Distribute last_truncate and fuse final_scaling into it
+    // ==========================================
+    %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.iree.apply_cse %func : !transform.any_op
+
     // Vectorize function
     // ==========================================
     transform.apply_patterns to %func {
@@ -126,4 +147,4 @@
     transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> ()
     transform.yield 
   }
-} // module
+} // module
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 430320a..778aeb0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -74,22 +74,38 @@
   return genericOp.getResult(0);
 }
 
-static Value updateAndScale(Value oldMax, Value newMax, Value oldSum,
-                            Location loc, OpBuilder &builder,
+/// Return the scale factor for the new softmax maximum and add the generic to
+/// the provided list of operations.
+static Value computeScaleFactor(Value oldMax, Value newMax, Location loc,
+                                OpBuilder &builder,
+                                SmallVectorImpl<Operation *> &ops) {
+  SmallVector<utils::IteratorType> iteratorTypes(1,
+                                                 utils::IteratorType::parallel);
+  auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
+  SmallVector<AffineMap> indexingMaps(2, identityMap);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, oldMax.getType(), newMax, oldMax, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
+        Value weight = b.create<math::Exp2Op>(loc, diff);
+        b.create<linalg::YieldOp>(loc, weight);
+      });
+  ops.push_back(genericOp);
+  return genericOp.getResult(0);
+}
+
+static Value updateAndScale(Value scaleFactor, Value oldSum, Location loc,
+                            OpBuilder &builder,
                             SmallVectorImpl<Operation *> &ops) {
   SmallVector<utils::IteratorType> iteratorTypes(1,
                                                  utils::IteratorType::parallel);
   auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
-  SmallVector<AffineMap> indexingMaps(3, identityMap);
-  SmallVector<Type> resultTypes{oldSum.getType()};
+  SmallVector<AffineMap> indexingMaps(2, identityMap);
   auto genericOp = builder.create<linalg::GenericOp>(
-      loc, resultTypes, ValueRange{oldMax, newMax}, ValueRange{oldSum},
-      indexingMaps, iteratorTypes,
+      loc, oldSum.getType(), scaleFactor, oldSum, indexingMaps, iteratorTypes,
       [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
-        Value weight = b.create<math::Exp2Op>(loc, diff);
-        Value scaledOldSum = b.create<arith::MulFOp>(loc, weight, args[2]);
-        b.create<linalg::YieldOp>(loc, ValueRange{scaledOldSum});
+        Value scaledOldSum = b.create<arith::MulFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, scaledOldSum);
       });
   ops.push_back(genericOp);
   return genericOp.getResult(0);
@@ -117,28 +133,33 @@
   return genericOp.getResult(0);
 }
 
-static Value computeReciprocal(Value x, Location loc, OpBuilder &builder,
+static Value applyFinalScaling(Value result, Value newSum, Location loc,
+                               OpBuilder &builder,
                                SmallVectorImpl<Operation *> &ops) {
   AffineMap identityMap =
-      AffineMap::getMultiDimIdentityMap(1, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{identityMap};
-  SmallVector<utils::IteratorType> iteratorTypes(1,
+      AffineMap::getMultiDimIdentityMap(2, builder.getContext());
+  AffineExpr d0, d1;
+  bindDims(builder.getContext(), d0, d1);
+  // (d0, d1) -> (d0)
+  auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
+  SmallVector<AffineMap> indexingMaps = {rowMap, identityMap};
+  SmallVector<utils::IteratorType> iteratorTypes(2,
                                                  utils::IteratorType::parallel);
   auto genericOp = builder.create<linalg::GenericOp>(
-      loc, x.getType(), ValueRange{}, x, indexingMaps, iteratorTypes,
+      loc, result.getType(), newSum, result, indexingMaps, iteratorTypes,
       [&](OpBuilder &b, Location loc, ValueRange args) {
         Value one = b.create<arith::ConstantOp>(
             loc, b.getFloatAttr(args[0].getType(), 1.0));
-        Value result = b.create<arith::DivFOp>(loc, one, args[0]);
+        Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
+        Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
         b.create<linalg::YieldOp>(loc, result);
       });
   ops.push_back(genericOp);
   return genericOp.getResult(0);
 }
 
-static Value scaleAccumulator(Value accumulator, Value scaledOldSum,
-                              Value inverseNewSum, Location loc,
-                              OpBuilder &builder,
+static Value scaleAccumulator(Value accumulator, Value scaleFactor,
+                              Location loc, OpBuilder &builder,
                               SmallVectorImpl<Operation *> &ops) {
   AffineMap identityMap =
       AffineMap::getMultiDimIdentityMap(2, builder.getContext());
@@ -146,15 +167,13 @@
   bindDims(builder.getContext(), d0, d1);
   // (d0, d1) -> (d0)
   auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
-  SmallVector<AffineMap> indexingMaps{rowMap, rowMap, identityMap};
+  SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
   SmallVector<utils::IteratorType> iteratorTypes(2,
                                                  utils::IteratorType::parallel);
   auto genericOp = builder.create<linalg::GenericOp>(
-      loc, accumulator.getType(), ValueRange{scaledOldSum, inverseNewSum},
-      accumulator, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value ratio = b.create<arith::MulFOp>(loc, args[0], args[1]);
-        Value result = b.create<arith::MulFOp>(loc, ratio, args[2]);
+      loc, accumulator.getType(), scaleFactor, accumulator, indexingMaps,
+      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::MulFOp>(loc, args[0], args[1]);
         b.create<linalg::YieldOp>(loc, result);
       });
   ops.push_back(genericOp);
@@ -254,26 +273,24 @@
       qkTranspose, maxSlice, loc, builder, ops);
   Value partialSoftmax =
       computePartialSoftmax(qkTranspose, newMax, loc, builder, ops);
-  Value scaledOldSum =
-      updateAndScale(maxSlice, newMax, sumSlice, loc, builder, ops);
+  Value scaleFactor = computeScaleFactor(maxSlice, newMax, loc, builder, ops);
+  Value scaledOldSum = updateAndScale(scaleFactor, sumSlice, loc, builder, ops);
   Value newSum = computeRowwiseReduction<arith::AddFOp>(
       partialSoftmax, scaledOldSum, loc, builder, ops);
-  Value inverseNewSum = computeReciprocal(newSum, loc, builder, ops);
-  Value softmax =
-      scalePartialSoftmax(partialSoftmax, inverseNewSum, loc, builder, ops);
   if (elementType.isF16()) {
     Value empty =
         builder.create<tensor::EmptyOp>(loc, resultShape, builder.getF16Type());
-    softmax = truncateToF16(softmax, empty, ops, builder, loc);
+    partialSoftmax = truncateToF16(partialSoftmax, empty, ops, builder, loc);
   }
 
   // Update accumulator
-  Value scaledAcc = scaleAccumulator(outputSlice, scaledOldSum, inverseNewSum,
-                                     loc, builder, ops);
+  Value scaledAcc =
+      scaleAccumulator(outputSlice, scaleFactor, loc, builder, ops);
 
   // Compute matmul(softmax, v)
   auto matmulOp = builder.create<linalg::MatmulOp>(
-      loc, scaledAcc.getType(), ValueRange{softmax, valueSlice}, scaledAcc);
+      loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
+      scaledAcc);
   ops.push_back(matmulOp);
   Value result = matmulOp.getResult(0);
   return std::make_tuple(result, newMax, newSum);
@@ -417,6 +434,10 @@
 
   OpBuilder::InsertionGuard yieldGuard(rewriter);
   rewriter.setInsertionPointAfter(loopNest.loops.back());
+
+  loopNest.results[0] = applyFinalScaling(
+      loopNest.results[0], loopNest.results[2], loc, rewriter, ops);
+
   if (elementType.isF16()) {
     loopNest.results[0] =
         truncateToF16(loopNest.results[0], outputSlice, ops, rewriter, loc);
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 3561d8f..322f8ab 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -7,6 +7,8 @@
   return %1 : tensor<1x1024x64xf32>
 }
 
+// TILING-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILING-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 // TILING-LABEL: @attention
 // TILING-SAME: (%[[QUERY:.+]]: tensor<1x1024x64xf32>, %[[KEY:.+]]: tensor<1x1024x64xf32>, %[[VALUE:.+]]: tensor<1x1024x64xf32>)
 // TILING:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
@@ -34,8 +36,17 @@
 // TILING-SAME:                                           -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // TILING:          scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // TILING:        }
-// TILING:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1,
-// TILING-SAME:     1, 1] : tensor<1024x64xf32> into tensor<1x1024x64xf32>
+// TILING:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME:     "parallel"]} ins(%[[D6]]#2 : tensor<1024xf32>) outs(%[[D6]]#0 : tensor<1024x64xf32>)
+// TILING-SAME:     {
+// TILING:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILING:          %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILING:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILING:          linalg.yield %[[D9]] : f32
+// TILING:        } -> tensor<1024x64xf32>
+// TILING:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME:    tensor<1024x64xf32> into tensor<1x1024x64xf32>
 // TILING:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
 // TILING:      }
 
@@ -64,69 +75,66 @@
 // CHECK-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
 // CHECK:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // CHECK-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
-// CHECK:          %[[D7:.+]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK:          %[[D8:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D7]] : tensor<1024x1024xf32>) ->
+// CHECK:          %[[D8:.+]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
-// CHECK:          %[[D9:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
-// CHECK-SAME:       tensor<1024x64xf32>, tensor<1024x64xf32>) outs(%[[D8]] : tensor<1024x1024xf32>) ->
+// CHECK:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
+// CHECK-SAME:       tensor<1024x64xf32>, tensor<1024x64xf32>) outs(%[[D9]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
-// CHECK:          %[[D10:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D9]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:       "reduction"]} ins(%[[D10]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D18]] : f32
 // CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D10]] : tensor<1024xf32>) outs(%[[D9]] : tensor<1024x1024xf32>) {
+// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
 // CHECK:            linalg.yield %[[D19]] : f32
 // CHECK:          } -> tensor<1024x1024xf32>
-// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]], iterator_types =
-// CHECK-SAME:       ["parallel"]} ins(%[[ARG5]], %[[D10]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[ARG6]] :
-// CHECK-SAME:       tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[IN_3:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.subf %[[IN]], %[[IN_3]] : f32
-// CHECK:            %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK:            %[[D20:.+]] = arith.mulf %[[D19]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D20]] : f32
+// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D18:.+]] = arith.subf %[[OUT]], %[[IN]] : f32
+// CHECK:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
+// CHECK:            linalg.yield %[[D19]] : f32
 // CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D11]] : tensor<1024x1024xf32>) outs(%[[D12]] : tensor<1024xf32>) {
+// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D18]] : f32
+// CHECK:          } -> tensor<1024xf32>
+// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:       "reduction"]} ins(%[[D12]] : tensor<1024x1024xf32>) outs(%[[D14]] : tensor<1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D18]] : f32
 // CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       outs(%[[D13]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[OUT:.+]]: f32):
-// CHECK-DAG:        %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:            %[[D18]] = arith.divf %[[CST_3]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x1024xf32>) {
+// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[OUT]], %[[IN]] : f32
+// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<1024x1024xf32>
-// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP]]], iterator_types =
-// CHECK-SAME:       ["parallel", "parallel"]} ins(%[[D12]], %[[D14]] : tensor<1024xf32>, tensor<1024xf32>)
-// CHECK-SAME:       outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[IN_3:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[IN_3]] : f32
-// CHECK:            %[[D19]] = arith.mulf %[[D18]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D19]] : f32
 // CHECK:          } -> tensor<1024x64xf32>
-// CHECK:          %[[D17:.+]] = linalg.matmul ins(%[[D15]], %[[EXTRACTED_SLICE_1]] : tensor<1024x1024xf32>,
+// CHECK:          %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_1]] : tensor<1024x1024xf32>,
 // CHECK-SAME:       tensor<1024x64xf32>) outs(%[[D16]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK:          scf.yield %[[D17]], %[[D10]], %[[D13]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// CHECK:          scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // CHECK:        }
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1,
-// CHECK-SAME:     1, 1] : tensor<1024x64xf32> into tensor<1x1024x64xf32>
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// CHECK-SAME:     {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:          %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D9]] : f32
+// CHECK:        } -> tensor<1024x64xf32>
+// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME:     tensor<1024x64xf32> into tensor<1x1024x64xf32>
 // CHECK:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
-// CHECK:      }
 
 // -----
 
@@ -167,8 +175,16 @@
 // TILING-SAME:                      -> tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
 // TILING:          scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
 // TILING:        }
-// TILING:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, %[[DIM]],
-// TILING-SAME:     %[[DIM_0]]] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// TILING:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
+// TILING:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG:      %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
+// TILING:          %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]] : f32
+// TILING:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILING:          linalg.yield %[[D9]] : f32
+// TILING:        } -> tensor<?x?xf32>
+// TILING:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
+// TILING-SAME:     [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
 // TILING:        return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
 // TILING:      }
 
@@ -212,55 +228,52 @@
 // CHECK:            %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D18]] : f32
 // CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
 // CHECK-SAME:       "parallel"]} ins(%[[D10]] : tensor<?xf32>) outs(%[[D9]] : tensor<?x?xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
 // CHECK:            linalg.yield %[[D19]] : f32
 // CHECK:          } -> tensor<?x?xf32>
-// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]], iterator_types =
-// CHECK-SAME:       ["parallel"]} ins(%[[ARG8]], %[[D10]] : tensor<?xf32>, tensor<?xf32>) outs(%[[ARG9]] :
-// CHECK-SAME:       tensor<?xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[IN_5:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.subf %[[IN]], %[[IN_5]] : f32
+// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D10]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK:            %[[D20:.+]] = arith.mulf %[[D19]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D20]] : f32
+// CHECK:            linalg.yield %[[D19]] : f32
 // CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D11]] : tensor<?x?xf32>) outs(%[[D12]] : tensor<?xf32>) {
+// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG9]] : tensor<?xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D18]] : f32
+// CHECK:          } -> tensor<?xf32>
+// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:       "reduction"]} ins(%[[D12]] : tensor<?x?xf32>) outs(%[[D14]] : tensor<?xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D18]] : f32
 // CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       outs(%[[D13]] : tensor<?xf32>) {
-// CHECK:          ^bb0(%[[OUT:.+]]: f32):
-// CHECK-DAG:        %[[CST_5:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:            %[[D18]] = arith.divf %[[CST_5]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D18]] : f32
-// CHECK:          } -> tensor<?xf32>
-// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D14]] : tensor<?xf32>) outs(%[[D11]] : tensor<?x?xf32>) {
+// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG7]] : tensor<?x?xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[OUT]], %[[IN]] : f32
+// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D18]] : f32
 // CHECK:          } -> tensor<?x?xf32>
-// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP]]], iterator_types =
-// CHECK-SAME:       ["parallel", "parallel"]} ins(%[[D12]], %[[D14]] : tensor<?xf32>, tensor<?xf32>) outs(%[[ARG7]] :
-// CHECK-SAME:       tensor<?x?xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[IN_5:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D18]] = arith.mulf %[[IN]], %[[IN_5]] : f32
-// CHECK:            %[[D19]] = arith.mulf %[[D18]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D19]] : f32
-// CHECK:          } -> tensor<?x?xf32>
-// CHECK:          %[[D17:.+]] = linalg.matmul ins(%[[D15]], %[[EXTRACTED_SLICE_3]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK:          %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_3]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME:       outs(%[[D16]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:          scf.yield %[[D17]], %[[D10]], %[[D13]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
+// CHECK:          scf.yield %[[D17]], %[[D10]], %[[D15]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
 // CHECK:        }
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, %[[DIM]],
-// CHECK-SAME:     %[[DIM_0]]] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG:      %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:          %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]] : f32
+// CHECK:          %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D9]] : f32
+// CHECK:        } -> tensor<?x?xf32>
+// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
+// CHECK-SAME:     [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
 // CHECK:        return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
 // CHECK:      }
 
@@ -302,14 +315,22 @@
 // TILING-SAME:                                           -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // TILING:          scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // TILING:        }
-// TILING:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
-// TILING-SAME:     "parallel"]} ins(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] :
-// TILING-SAME:     tensor<1024x64xf16>) {
+// TILING:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILING-SAME:     {
+// TILING:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILING:          %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILING:          %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILING:          linalg.yield %[[D10]] : f32
+// TILING:        } -> tensor<1024x64xf32>
+// TILING:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME:     "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
 // TILING:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// TILING:          %[[D8:.+]] = arith.truncf %[[IN]] : f32 to f16
-// TILING:          linalg.yield %[[D8]] : f16
+// TILING:          %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILING:          linalg.yield %[[D9]] : f16
 // TILING:        } -> tensor<1024x64xf16>
-// TILING:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // TILING-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
 // TILING:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
 // TILING:      }
@@ -341,80 +362,77 @@
 // CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
 // CHECK:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
-// CHECK:          %[[D8:.+]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x1024xf32>) ->
+// CHECK:          %[[D9:.+]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
-// CHECK:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
-// CHECK-SAME:       tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D9]] : tensor<1024x1024xf32>) ->
+// CHECK:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// CHECK-SAME:       tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D10]] : tensor<1024x1024xf32>) ->
 // CHECK-SAME:       tensor<1024x1024xf32>
-// CHECK:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D10]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:       "reduction"]} ins(%[[D11]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D21]] : f32
 // CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x1024xf32>) {
+// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D22:.+]] = math.exp2 %[[D21]] : f32
 // CHECK:            linalg.yield %[[D22]] : f32
 // CHECK:          } -> tensor<1024x1024xf32>
-// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]], iterator_types =
-// CHECK-SAME:       ["parallel"]} ins(%[[ARG5]], %[[D11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[ARG6]] :
-// CHECK-SAME:       tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.subf %[[IN]], %[[IN_4]] : f32
+// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
 // CHECK:            %[[D22]] = math.exp2 %[[D21]] : f32
-// CHECK:            %[[D23:.+]] = arith.mulf %[[D22]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D23]] : f32
+// CHECK:            linalg.yield %[[D22]] : f32
 // CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:       "reduction"]} ins(%[[D12]] : tensor<1024x1024xf32>) outs(%[[D13]] : tensor<1024xf32>) {
+// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D21]] : f32
+// CHECK:          } -> tensor<1024xf32>
+// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:       "reduction"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D15]] : tensor<1024xf32>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:            %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
 // CHECK:            linalg.yield %[[D21]] : f32
 // CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME:       outs(%[[D14]] : tensor<1024xf32>) {
-// CHECK:          ^bb0(%[[OUT:.+]]: f32):
-// CHECK-DAG:        %[[CST_4:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:            %[[D21]] = arith.divf %[[CST_4]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D21]] : f32
-// CHECK:          } -> tensor<1024xf32>
-// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D15]] : tensor<1024xf32>) outs(%[[D12]] : tensor<1024x1024xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.mulf %[[OUT]], %[[IN]] : f32
-// CHECK:            linalg.yield %[[D21]] : f32
-// CHECK:          } -> tensor<1024x1024xf32>
 // CHECK:          %[[D17:.+]] = tensor.empty() : tensor<1024x1024xf16>
 // CHECK:          %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:       "parallel"]} ins(%[[D16]] : tensor<1024x1024xf32>) outs(%[[D17]] : tensor<1024x1024xf16>) {
+// CHECK-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D17]] : tensor<1024x1024xf16>) {
 // CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
 // CHECK:            %[[D21]] = arith.truncf %[[IN]] : f32 to f16
 // CHECK:            linalg.yield %[[D21]] : f16
 // CHECK:          } -> tensor<1024x1024xf16>
-// CHECK:          %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP]]], iterator_types =
-// CHECK-SAME:       ["parallel", "parallel"]} ins(%[[D13]], %[[D15]] : tensor<1024xf32>, tensor<1024xf32>)
-// CHECK-SAME:       outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK:            %[[D21]] = arith.mulf %[[IN]], %[[IN_4]] : f32
-// CHECK:            %[[D22]] = arith.mulf %[[D21]], %[[OUT]] : f32
-// CHECK:            linalg.yield %[[D22]] : f32
+// CHECK:          %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D21]] : f32
 // CHECK:          } -> tensor<1024x64xf32>
 // CHECK:          %[[D20:.+]] = linalg.matmul ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x1024xf16>,
 // CHECK-SAME:       tensor<1024x64xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK:          scf.yield %[[D20]], %[[D11]], %[[D14]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// CHECK:          scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
 // CHECK:        }
-// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] :
-// CHECK-SAME:     tensor<1024x64xf16>) {
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// CHECK-SAME:     {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D10]] : f32
+// CHECK:        } -> tensor<1024x64xf32>
+// CHECK:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// CHECK:          %[[D8]] = arith.truncf %[[IN]] : f32 to f16
-// CHECK:          linalg.yield %[[D8]] : f16
+// CHECK:          %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// CHECK:          linalg.yield %[[D9]] : f16
 // CHECK:        } -> tensor<1024x64xf16>
-// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
 // CHECK-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
 // CHECK:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
-// CHECK:      }
\ No newline at end of file
+// CHECK:      }
diff --git a/tests/transform_dialect/cpu/attention_codegen_spec.mlir b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
index 3601d92..73ffab9 100644
--- a/tests/transform_dialect/cpu/attention_codegen_spec.mlir
+++ b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
@@ -20,11 +20,11 @@
     // Tile and decompose attention
     // ==========================================
     %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    %acc_fill, %max_fill, %sum_fill, %inner_loop, %blocked_attention = transform.tile_attention %attention4 :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %scale_acc, %second_matmul
+    %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %blocked_attention = transform.tile_attention %attention4 :
+      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %scale_acc, %second_matmul
         = transform.decompose_tiled_attention %blocked_attention :
-      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 
     // Vectorize function
     // ==========================================