[LinalgExt] Expose attention tile size parameter (#16030)

This patch exposes the tile size for the innermost loop in flash
attention.
Currently, the step size is taken from the workgroup tile size; this
PR allows specifying the tile size explicitly. A value of 0 defaults
to the workgroup tile size.
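
For C++ callers, a minimal sketch of driving the new parameter through
the updated entry point. The wrapper name `tileAttentionWithStep32` is
illustrative, and the namespace qualification is assumed from the
surrounding headers; `attnOp` and `rewriter` would come from the
enclosing pass or transform:

    #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"

    using namespace mlir;
    using namespace mlir::iree_compiler;  // assumed enclosing namespace

    // Sketch: tile the innermost attention loop with a step of 32 rather
    // than the workgroup tile size. Passing std::nullopt (the default)
    // preserves the existing behavior.
    static void tileAttentionWithStep32(IREE::LinalgExt::AttentionOp attnOp,
                                        RewriterBase &rewriter) {
      SmallVector<Operation *> ops;  // receives the generated loop and ops
      IREE::LinalgExt::tileAndDecomposeAttention(attnOp, ops, rewriter,
                                                 /*onlyTile=*/false,
                                                 /*tileSize=*/uint64_t{32});
    }

At the pass level, the same behavior is exercised in the test below via
-iree-linalg-ext-tile-and-decompose-attention="tileSize=32".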
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index 7257cb1..1a008d7 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -116,17 +116,21 @@
 std::unique_ptr<Pass> createConvertConv2DToWinogradPass();
 
 // Transform dialect version of tile and decompose attention wrapper.
+// The optional tile size specifies the step for the innermost for loop.
 void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
                                SmallVectorImpl<Operation *> &ops,
-                               RewriterBase &rewriter, bool onlyTile = false);
+                               RewriterBase &rewriter, bool onlyTile = false,
+                               std::optional<uint64_t> tileSize = std::nullopt);
 
-IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
-                                           SmallVectorImpl<Operation *> &ops,
-                                           RewriterBase &rewriter);
+IREE::LinalgExt::AttentionOp
+tileAttention(IREE::LinalgExt::AttentionOp attnOp,
+              SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
+              std::optional<uint64_t> tileSize = std::nullopt);
 
 void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
                              SmallVectorImpl<Operation *> &ops,
-                             RewriterBase &rewriter);
+                             RewriterBase &rewriter,
+                             std::optional<uint64_t> tileSize = std::nullopt);
 
 // Creates a pass to convert the attention op into a sequence of
 // linalg ops.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
index 55ee53f..f212f3c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
@@ -80,6 +80,8 @@
   let options = [
     Option<"onlyTile", "onlyTile", "bool", /*default=*/"false",
            "Choose whether to only tile or go through till decomposition">,
+    Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
+           "Tile size for sequential for loop in attention">,
   ];
 }
 
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
index ad171c4..c6fdce9 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -117,7 +117,8 @@
   }];
 
   let arguments = (
-      ins TransformHandleTypeInterface:$target
+      ins TransformHandleTypeInterface:$target,
+          OptionalAttr<I64Attr>:$tile_size
   );
   let results = (outs Variadic<TransformHandleTypeInterface>:$result);
 
@@ -153,7 +154,8 @@
   }];
 
   let arguments = (
-      ins TransformHandleTypeInterface:$target
+      ins TransformHandleTypeInterface:$target,
+          OptionalAttr<I64Attr>:$tile_size
   );
   let results = (outs Variadic<TransformHandleTypeInterface>:$result);
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 778aeb0..93c7919 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -192,30 +192,24 @@
   return matmulOp.getResult(0);
 }
 
-static std::tuple<Value, Value, Value>
-extractSlices(Value key, Value value, Value query, ArrayRef<int64_t> queryShape,
-              ArrayRef<Value> ivs, OpFoldResult sequenceTileLength,
-              OpFoldResult headDimension, Type elementType, Location loc,
-              OpBuilder &builder) {
+static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
+                          ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
+                          OpFoldResult headDimension, Type elementType,
+                          Location loc, OpBuilder &builder) {
   auto one = builder.getIndexAttr(1);
   auto zero = builder.getIndexAttr(0);
-  SmallVector<OpFoldResult> strides(queryShape.size(), one);
-  SmallVector<OpFoldResult> sizes(queryShape.size(), one);
-  SmallVector<OpFoldResult> offsets(queryShape.size(), zero);
-  sizes[1] = sequenceTileLength;
+  SmallVector<OpFoldResult> strides(keyShape.size(), one);
+  SmallVector<OpFoldResult> sizes(keyShape.size(), one);
+  SmallVector<OpFoldResult> offsets(keyShape.size(), zero);
+  sizes[1] = keyValueTileLength;
   sizes[2] = headDimension;
-  offsets[1] = ivs[0];
-  SmallVector<int64_t> tensorShape{queryShape[1], queryShape[2]};
+  if (!ivs.empty())
+    offsets[1] = ivs[0];
+  SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
   auto tensorType = RankedTensorType::get(tensorShape, elementType);
   Value keySlice = builder.create<tensor::ExtractSliceOp>(
       loc, tensorType, key, offsets, sizes, strides);
-  Value valueSlice = builder.create<tensor::ExtractSliceOp>(
-      loc, tensorType, value, offsets, sizes, strides);
-
-  offsets = SmallVector<OpFoldResult>(queryShape.size(), zero);
-  Value querySlice = builder.create<tensor::ExtractSliceOp>(
-      loc, tensorType, query, offsets, sizes, strides);
-  return std::make_tuple(keySlice, valueSlice, querySlice);
+  return keySlice;
 }
 
 static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
@@ -254,7 +248,8 @@
 static std::tuple<Value, Value, Value>
 createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
                     Value outputSlice, Value maxSlice, Value sumSlice,
-                    OpFoldResult sequenceTileLength, OpFoldResult headDimension,
+                    OpFoldResult sequenceTileLength,
+                    OpFoldResult keyValueTileLength, OpFoldResult headDimension,
                     Type elementType, SmallVectorImpl<Operation *> &ops,
                     Location loc, OpBuilder &builder) {
 
@@ -262,7 +257,7 @@
   // Compute matmul(q, transpose(k))
   Value zero =
       builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(f32Type));
-  SmallVector<OpFoldResult> resultShape{sequenceTileLength, sequenceTileLength};
+  SmallVector<OpFoldResult> resultShape{sequenceTileLength, keyValueTileLength};
   Value emptySquare =
       builder.create<tensor::EmptyOp>(loc, resultShape, f32Type);
   Value qkTranspose = computeQKTranspose(querySlice, keySlice, emptySquare,
@@ -342,7 +337,8 @@
 /// TODO: Adopt getTiledImplementation with this.
 IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
                                            SmallVectorImpl<Operation *> &ops,
-                                           RewriterBase &rewriter) {
+                                           RewriterBase &rewriter,
+                                           std::optional<uint64_t> tileSize) {
   Location loc = attnOp.getLoc();
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(attnOp);
@@ -355,6 +351,14 @@
       tensor::getMixedSizes(rewriter, loc, query);
   OpFoldResult headDimension = queryDimValues[2];
   OpFoldResult sequenceTileLength = queryDimValues[1];
+  OpFoldResult keyValueTileLength = sequenceTileLength;
+  SmallVector<int64_t> keyShape{queryShape};
+  if (tileSize) {
+    keyValueTileLength = rewriter.getIndexAttr(tileSize.value());
+    for (auto it : llvm::enumerate(attnOp.getKeyType().getShape())) {
+      keyShape[it.index()] = it.index() == 1 ? tileSize.value() : it.value();
+    }
+  }
 
   Value key = attnOp.getKey();
   Value value = attnOp.getValue();
@@ -397,7 +401,7 @@
   Value zeroValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   scf::LoopNest loopNest = createLoopNest(
       ivs, zeroValue,
-      getValueOrCreateConstantIndexOp(rewriter, loc, sequenceTileLength),
+      getValueOrCreateConstantIndexOp(rewriter, loc, keyValueTileLength),
       getValueOrCreateConstantIndexOp(rewriter, loc, sequenceLength),
       ValueRange({accumulatorF32, negativeMax, zeroSum}), loc, rewriter);
   ops.push_back(loopNest.loops.back());
@@ -410,9 +414,12 @@
   rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
 
   // Extract slices
-  auto [keySlice, valueSlice, querySlice] =
-      extractSlices(key, value, query, queryShape, ivs, sequenceTileLength,
-                    headDimension, elementType, loc, rewriter);
+  Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
+                                headDimension, elementType, loc, rewriter);
+  Value valueSlice = extractSlice(value, keyShape, ivs, keyValueTileLength,
+                                  headDimension, elementType, loc, rewriter);
+  Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
+                                  headDimension, elementType, loc, rewriter);
 
   auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
       attnOp.getLoc(),
@@ -456,7 +463,8 @@
 /// TODO: Adopt decomposeOperation with this.
 void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
                              SmallVectorImpl<Operation *> &ops,
-                             RewriterBase &rewriter) {
+                             RewriterBase &rewriter,
+                             std::optional<uint64_t> tileSize) {
   Location loc = tiledAttnOp.getLoc();
   Value keySlice = tiledAttnOp.getKey();
   Value valueSlice = tiledAttnOp.getValue();
@@ -474,11 +482,14 @@
       tensor::getMixedSizes(rewriter, loc, querySlice);
   OpFoldResult headDimension = queryDimValues[1];
   OpFoldResult sequenceTileLength = queryDimValues[0];
+  OpFoldResult keyValueTileLength =
+      tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
 
   Type elementType = tiledAttnOp.getQueryType().getElementType();
-  auto [result, newMax, newSum] = createAttentionBody(
-      keySlice, valueSlice, querySlice, tiledResult, max, sum,
-      sequenceTileLength, headDimension, elementType, ops, loc, rewriter);
+  auto [result, newMax, newSum] =
+      createAttentionBody(keySlice, valueSlice, querySlice, tiledResult, max,
+                          sum, sequenceTileLength, keyValueTileLength,
+                          headDimension, elementType, ops, loc, rewriter);
 
   rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
 }
@@ -487,12 +498,13 @@
 /// FlashAttention algorithm.
 void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
                                SmallVectorImpl<Operation *> &ops,
-                               RewriterBase &rewriter, bool onlyTile) {
+                               RewriterBase &rewriter, bool onlyTile,
+                               std::optional<uint64_t> tileSize) {
   IREE::LinalgExt::AttentionOp tiledAttentionOp =
-      tileAttention(attnOp, ops, rewriter);
+      tileAttention(attnOp, ops, rewriter, tileSize);
   if (onlyTile)
     return;
-  decomposeTiledAttention(tiledAttentionOp, ops, rewriter);
+  decomposeTiledAttention(tiledAttentionOp, ops, rewriter, tileSize);
 }
 
 namespace {
@@ -524,11 +536,12 @@
 ///    j. Compute matmul(s, v) and add new_accumulator
 ///
 ///
-LogicalResult reifyAttentionTransform(func::FuncOp funcOp, bool onlyTile) {
+LogicalResult reifyAttentionTransform(func::FuncOp funcOp, bool onlyTile,
+                                      std::optional<uint64_t> tileSize) {
   IRRewriter rewriter(funcOp.getContext());
   funcOp.walk([&](IREE::LinalgExt::AttentionOp attnOp) {
     SmallVector<Operation *> ops;
-    tileAndDecomposeAttention(attnOp, ops, rewriter, onlyTile);
+    tileAndDecomposeAttention(attnOp, ops, rewriter, onlyTile, tileSize);
     return WalkResult::advance();
   });
   return success();
@@ -545,9 +558,13 @@
         linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
   }
   TileAndDecomposeAttentionPass() = default;
-  TileAndDecomposeAttentionPass(bool onlyTile) { this->onlyTile = onlyTile; }
+  TileAndDecomposeAttentionPass(bool onlyTile, uint64_t tileSize) {
+    this->onlyTile = onlyTile;
+    this->tileSize = tileSize;
+  }
   TileAndDecomposeAttentionPass(const TileAndDecomposeAttentionPass &pass) {
     onlyTile = pass.onlyTile;
+    tileSize = pass.tileSize;
   }
   void runOnOperation() override;
 };
@@ -556,7 +573,11 @@
 void TileAndDecomposeAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
-  if (failed(reifyAttentionTransform(getOperation(), onlyTile)))
+  std::optional<uint64_t> optionalTileSize{std::nullopt};
+  if (tileSize.hasValue())
+    optionalTileSize = tileSize.getValue();
+  if (failed(
+          reifyAttentionTransform(getOperation(), onlyTile, optionalTileSize)))
     return signalPassFailure();
 }
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index c72c1a7..fa18c40 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -166,7 +166,7 @@
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   SmallVector<Operation *> ops;
-  LinalgExt::tileAttention(attentionOp, ops, rewriter);
+  LinalgExt::tileAttention(attentionOp, ops, rewriter, getTileSize());
   for (auto op : ops)
     results.push_back(op);
   return DiagnosedSilenceableFailure::success();
@@ -177,7 +177,7 @@
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   SmallVector<Operation *> ops;
-  LinalgExt::decomposeTiledAttention(attentionOp, ops, rewriter);
+  LinalgExt::decomposeTiledAttention(attentionOp, ops, rewriter, getTileSize());
   for (auto op : ops)
     results.push_back(op);
   return DiagnosedSilenceableFailure::success();
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 4691721..a169a8b 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -1,5 +1,6 @@
 // RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-tile-and-decompose-attention -cse %s | FileCheck %s
 // RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-tile-and-decompose-attention=onlyTile -cse %s | FileCheck %s --check-prefix=TILING
+// RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-tile-and-decompose-attention="tileSize=32" -cse %s | FileCheck %s --check-prefix=TILESIZE
 
 func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
   %0 = tensor.empty() : tensor<1x1024x64xf32>
@@ -7,6 +8,93 @@
   return %1 : tensor<1x1024x64xf32>
 }
 
+// TILESIZE-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME:   tensor<1x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
+// TILESIZE:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
+// TILESIZE:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILESIZE-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILESIZE-SAME:     tensor<1024x64xf32>
+// TILESIZE-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILESIZE:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG:    %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
+// TILESIZE:        %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// TILESIZE-SAME:     iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILESIZE:          %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf32> to tensor<32x64xf32>
+// TILESIZE:          %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf32> to tensor<32x64xf32>
+// TILESIZE:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf32> to tensor<1024x64xf32>
+// TILESIZE:          %[[D8:.+]] = tensor.empty() : tensor<1024x32xf32>
+// TILESIZE:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME:       tensor<1024x32xf32>
+// TILESIZE:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
+// TILESIZE-SAME:       tensor<1024x64xf32>, tensor<32x64xf32>) outs(%[[D9]] : tensor<1024x32xf32>) -> tensor<1024x32xf32>
+// TILESIZE:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D10]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x32xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
+// TILESIZE:            linalg.yield %[[D19]] : f32
+// TILESIZE:          } -> tensor<1024x32xf32>
+// TILESIZE:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D19]] = math.exp2 %[[D18]] : f32
+// TILESIZE:            linalg.yield %[[D19]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D12]] : tensor<1024x32xf32>) outs(%[[D14]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<1024x64xf32>
+// TILESIZE:          %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_1]] : tensor<1024x32xf32>,
+// TILESIZE-SAME:       tensor<32x64xf32>) outs(%[[D16]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// TILESIZE:          scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILESIZE:        }
+// TILESIZE:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILESIZE-SAME:     {
+// TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE:          %[[D8]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE:          %[[D9]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILESIZE:          linalg.yield %[[D9]] : f32
+// TILESIZE:        } -> tensor<1024x64xf32>
+// TILESIZE:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:     tensor<1024x64xf32> into tensor<1x1024x64xf32>
+// TILESIZE:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
+// TILESIZE:      }
+
 // TILING-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // TILING-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 // TILING-LABEL: @attention
@@ -144,6 +232,96 @@
   return %1 : tensor<?x?x?xf32>
 }
 
+// TILESIZE-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME:   tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: index,
+// TILESIZE-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index) -> tensor<?x?x?xf32> {
+// TILESIZE:        %[[D0:.+]] = tensor.empty(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?xf32>
+// TILESIZE-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG:    %[[C1:.+]] = arith.constant 1 : index
+// TILESIZE:        %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// TILESIZE-DAG:    %[[C2:.+]] = arith.constant 2 : index
+// TILESIZE:        %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+// TILESIZE:        %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
+// TILESIZE:        %[[D1:.+]] = tensor.empty(%[[DIM]], %[[DIM_0]]) : tensor<?x?xf32>
+// TILESIZE-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TILESIZE-DAG:    %[[CST_2:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE:        %[[D3:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// TILESIZE:        %[[D4:.+]] = linalg.fill ins(%[[CST_2]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
+// TILESIZE:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
+// TILESIZE-DAG:    %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE:        %[[D6:.+]]:3 = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[DIM_1]] step %[[C32]]
+// TILESIZE-SAME:     iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG8:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME:     %[[ARG9:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>) {
+// TILESIZE:          %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1, 1,
+// TILESIZE-SAME:       1] : tensor<?x?x?xf32> to tensor<32x?xf32>
+// TILESIZE:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1,
+// TILESIZE-SAME:       1, 1] : tensor<?x?x?xf32> to tensor<32x?xf32>
+// TILESIZE:          %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1,
+// TILESIZE-SAME:       1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+// TILESIZE:          %[[DIM_5:.+]] = tensor.dim %[[EXTRACTED_SLICE_4]], %[[C0]] : tensor<?x?xf32>
+// TILESIZE:          %[[D8:.+]] = tensor.empty(%[[DIM_5]]) : tensor<?x32xf32>
+// TILESIZE:          %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<?x32xf32>) -> tensor<?x32xf32>
+// TILESIZE:          %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_4]], %[[EXTRACTED_SLICE]] :
+// TILESIZE-SAME:       tensor<?x?xf32>, tensor<32x?xf32>) outs(%[[D9]] : tensor<?x32xf32>) -> tensor<?x32xf32>
+// TILESIZE:          %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D10]] : tensor<?x32xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<?xf32>
+// TILESIZE:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D11]] : tensor<?xf32>) outs(%[[D10]] : tensor<?x32xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D19:.+]] = math.exp2 %[[D18]] : f32
+// TILESIZE:            linalg.yield %[[D19]] : f32
+// TILESIZE:          } -> tensor<?x32xf32>
+// TILESIZE:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D11]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D19]] = math.exp2 %[[D18]] : f32
+// TILESIZE:            linalg.yield %[[D19]] : f32
+// TILESIZE:          } -> tensor<?xf32>
+// TILESIZE:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG9]] : tensor<?xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<?xf32>
+// TILESIZE:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D12]] : tensor<?x32xf32>) outs(%[[D14]] : tensor<?xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<?xf32>
+// TILESIZE:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG7]] : tensor<?x?xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D18]] : f32
+// TILESIZE:          } -> tensor<?x?xf32>
+// TILESIZE:          %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_3]] : tensor<?x32xf32>,
+// TILESIZE-SAME:       tensor<32x?xf32>) outs(%[[D16]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TILESIZE:          scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
+// TILESIZE:        }
+// TILESIZE:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
+// TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG:      %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE:          %[[D8]] = arith.divf %[[CST_3]], %[[IN]] : f32
+// TILESIZE:          %[[D9]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILESIZE:          linalg.yield %[[D9]] : f32
+// TILESIZE:        } -> tensor<?x?xf32>
+// TILESIZE:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
+// TILESIZE-SAME:     [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// TILESIZE:        return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
+// TILESIZE:      }
+
 // TILING:      @attention(
 // TILING-SAME:  %[[QUERY:.+]]: tensor<?x?x?xf32>, %[[KEY:.+]]: tensor<?x?x?xf32>, %[[VALUE:.+]]: tensor<?x?x?xf32>,
 // TILING-SAME:  %[[ARG3:[a-zA-Z0-9_]+]]: index, %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index)
@@ -285,6 +463,109 @@
   return %1 : tensor<1x1024x64xf16>
 }
 
+// TILESIZE-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME:   tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
+// TILESIZE:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// TILESIZE:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILESIZE:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:     tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILESIZE-SAME:     tensor<1024x64xf32>
+// TILESIZE-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILESIZE:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG:    %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
+// TILESIZE:        %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// TILESIZE-SAME:     iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILESIZE:          %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<32x64xf16>
+// TILESIZE:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<32x64xf16>
+// TILESIZE:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE:          %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
+// TILESIZE:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME:       tensor<1024x32xf32>
+// TILESIZE:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// TILESIZE-SAME:       tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME:       tensor<1024x32xf32>
+// TILESIZE:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D11]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x32xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D22:.+]] = math.exp2 %[[D21]] : f32
+// TILESIZE:            linalg.yield %[[D22]] : f32
+// TILESIZE:          } -> tensor<1024x32xf32>
+// TILESIZE:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D22]] = math.exp2 %[[D21]] : f32
+// TILESIZE:            linalg.yield %[[D22]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D15]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D17:.+]] = tensor.empty() : tensor<1024x32xf16>
+// TILESIZE:          %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D17]] : tensor<1024x32xf16>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE:            %[[D21]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE:            linalg.yield %[[D21]] : f16
+// TILESIZE:          } -> tensor<1024x32xf16>
+// TILESIZE:          %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024x64xf32>
+// TILESIZE:          %[[D20:.+]] = linalg.matmul ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x32xf16>,
+// TILESIZE-SAME:       tensor<32x64xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// TILESIZE:          scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILESIZE:        }
+// TILESIZE:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILESIZE-SAME:     {
+// TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILESIZE:          linalg.yield %[[D10]] : f32
+// TILESIZE:        } -> tensor<1024x64xf32>
+// TILESIZE:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:     "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE:          %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE:          linalg.yield %[[D9]] : f16
+// TILESIZE:        } -> tensor<1024x64xf16>
+// TILESIZE:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// TILESIZE:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
+// TILESIZE:      }
+
 // TILING-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // TILING:      @attention(
 // TILING-SAME:  %[[QUERY:.+]]: tensor<1x1024x64xf16>, %[[KEY:.+]]: tensor<1x1024x64xf16>, %[[VALUE:.+]]: tensor<1x1024x64xf16>)