[LinalgExt] Expose attention tile size parameter (#16030)
This patch exposes the tile size for the innermost loop in flash
attention.
Currently, the step size is taken from the workgroup tile size. This
PR allows specifying the tile size explicitly; a value of 0 defaults to
the workgroup tile size.
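As a usage sketch (not part of the diff), the new parameter can be supplied
programmatically through the updated entry points in Passes.h; `attnOp` and
`rewriter` below are assumed to already exist in the caller:

```cpp
// Usage sketch only: tile attention with an explicit inner-loop tile
// size of 32, then decompose the tiled op. Assumes `attnOp` (an
// IREE::LinalgExt::AttentionOp) and `rewriter` (a RewriterBase) are
// provided by the surrounding pass or transform.
SmallVector<Operation *> ops;
IREE::LinalgExt::AttentionOp tiledOp =
    tileAttention(attnOp, ops, rewriter, /*tileSize=*/32);
decomposeTiledAttention(tiledOp, ops, rewriter, /*tileSize=*/32);
// Omitting tileSize (std::nullopt, the default) preserves the previous
// behavior: the loop steps by the workgroup tile size.
```

The same knob is reachable from the command line as
`-iree-linalg-ext-tile-and-decompose-attention="tileSize=32"`, which the new
lit test below exercises.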
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index 7257cb1..1a008d7 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -116,17 +116,21 @@
std::unique_ptr<Pass> createConvertConv2DToWinogradPass();
// Transform dialect version of tile and decompose attention wrapper.
+// The optional tile size specifies the step for the innermost for loop.
void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter, bool onlyTile = false);
+ RewriterBase &rewriter, bool onlyTile = false,
+ std::optional<uint64_t> tileSize = std::nullopt);
-IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
- SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter);
+IREE::LinalgExt::AttentionOp
+tileAttention(IREE::LinalgExt::AttentionOp attnOp,
+ SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
+ std::optional<uint64_t> tileSize = std::nullopt);
void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter);
+ RewriterBase &rewriter,
+ std::optional<uint64_t> tileSize = std::nullopt);
// Creates a pass to convert the attention op into a sequence of
// linalg ops.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
index 55ee53f..f212f3c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
@@ -80,6 +80,8 @@
let options = [
Option<"onlyTile", "onlyTile", "bool", /*default=*/"false",
"Choose whether to only tile or go through till decomposition">,
+ Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
+ "Tile size for sequential for loop in attention">,
];
}
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
index ad171c4..c6fdce9 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -117,7 +117,8 @@
}];
let arguments = (
- ins TransformHandleTypeInterface:$target
+ ins TransformHandleTypeInterface:$target,
+ OptionalAttr<I64Attr>:$tile_size
);
let results = (outs Variadic<TransformHandleTypeInterface>:$result);
@@ -153,7 +154,8 @@
}];
let arguments = (
- ins TransformHandleTypeInterface:$target
+ ins TransformHandleTypeInterface:$target,
+ OptionalAttr<I64Attr>:$tile_size
);
let results = (outs Variadic<TransformHandleTypeInterface>:$result);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 778aeb0..93c7919 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -192,30 +192,24 @@
return matmulOp.getResult(0);
}
-static std::tuple<Value, Value, Value>
-extractSlices(Value key, Value value, Value query, ArrayRef<int64_t> queryShape,
- ArrayRef<Value> ivs, OpFoldResult sequenceTileLength,
- OpFoldResult headDimension, Type elementType, Location loc,
- OpBuilder &builder) {
+static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
+ ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
+ OpFoldResult headDimension, Type elementType,
+ Location loc, OpBuilder &builder) {
auto one = builder.getIndexAttr(1);
auto zero = builder.getIndexAttr(0);
- SmallVector<OpFoldResult> strides(queryShape.size(), one);
- SmallVector<OpFoldResult> sizes(queryShape.size(), one);
- SmallVector<OpFoldResult> offsets(queryShape.size(), zero);
- sizes[1] = sequenceTileLength;
+ SmallVector<OpFoldResult> strides(keyShape.size(), one);
+ SmallVector<OpFoldResult> sizes(keyShape.size(), one);
+ SmallVector<OpFoldResult> offsets(keyShape.size(), zero);
+ sizes[1] = keyValueTileLength;
sizes[2] = headDimension;
- offsets[1] = ivs[0];
- SmallVector<int64_t> tensorShape{queryShape[1], queryShape[2]};
+ if (!ivs.empty())
+ offsets[1] = ivs[0];
+ SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
auto tensorType = RankedTensorType::get(tensorShape, elementType);
Value keySlice = builder.create<tensor::ExtractSliceOp>(
loc, tensorType, key, offsets, sizes, strides);
- Value valueSlice = builder.create<tensor::ExtractSliceOp>(
- loc, tensorType, value, offsets, sizes, strides);
-
- offsets = SmallVector<OpFoldResult>(queryShape.size(), zero);
- Value querySlice = builder.create<tensor::ExtractSliceOp>(
- loc, tensorType, query, offsets, sizes, strides);
- return std::make_tuple(keySlice, valueSlice, querySlice);
+ return keySlice;
}
static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
@@ -254,7 +248,8 @@
static std::tuple<Value, Value, Value>
createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
Value outputSlice, Value maxSlice, Value sumSlice,
- OpFoldResult sequenceTileLength, OpFoldResult headDimension,
+ OpFoldResult sequenceTileLength,
+ OpFoldResult keyValueTileLength, OpFoldResult headDimension,
Type elementType, SmallVectorImpl<Operation *> &ops,
Location loc, OpBuilder &builder) {
@@ -262,7 +257,7 @@
// Compute matmul(q, transpose(k))
Value zero =
builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(f32Type));
- SmallVector<OpFoldResult> resultShape{sequenceTileLength, sequenceTileLength};
+ SmallVector<OpFoldResult> resultShape{sequenceTileLength, keyValueTileLength};
Value emptySquare =
builder.create<tensor::EmptyOp>(loc, resultShape, f32Type);
Value qkTranspose = computeQKTranspose(querySlice, keySlice, emptySquare,
@@ -342,7 +337,8 @@
/// TODO: Adopt getTiledImplementation with this.
IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ std::optional<uint64_t> tileSize) {
Location loc = attnOp.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(attnOp);
@@ -355,6 +351,14 @@
tensor::getMixedSizes(rewriter, loc, query);
OpFoldResult headDimension = queryDimValues[2];
OpFoldResult sequenceTileLength = queryDimValues[1];
+ OpFoldResult keyValueTileLength = sequenceTileLength;
+ SmallVector<int64_t> keyShape{queryShape};
+ if (tileSize) {
+ keyValueTileLength = rewriter.getIndexAttr(tileSize.value());
+ for (auto it : llvm::enumerate(attnOp.getKeyType().getShape())) {
+ keyShape[it.index()] = it.index() == 1 ? tileSize.value() : it.value();
+ }
+ }
Value key = attnOp.getKey();
Value value = attnOp.getValue();
@@ -397,7 +401,7 @@
Value zeroValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
scf::LoopNest loopNest = createLoopNest(
ivs, zeroValue,
- getValueOrCreateConstantIndexOp(rewriter, loc, sequenceTileLength),
+ getValueOrCreateConstantIndexOp(rewriter, loc, keyValueTileLength),
getValueOrCreateConstantIndexOp(rewriter, loc, sequenceLength),
ValueRange({accumulatorF32, negativeMax, zeroSum}), loc, rewriter);
ops.push_back(loopNest.loops.back());
@@ -410,9 +414,12 @@
rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
// Extract slices
- auto [keySlice, valueSlice, querySlice] =
- extractSlices(key, value, query, queryShape, ivs, sequenceTileLength,
- headDimension, elementType, loc, rewriter);
+ Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
+ headDimension, elementType, loc, rewriter);
+ Value valueSlice = extractSlice(value, keyShape, ivs, keyValueTileLength,
+ headDimension, elementType, loc, rewriter);
+ Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
+ headDimension, elementType, loc, rewriter);
auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
attnOp.getLoc(),
@@ -456,7 +463,8 @@
/// TODO: Adopt decomposeOperation with this.
void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ std::optional<uint64_t> tileSize) {
Location loc = tiledAttnOp.getLoc();
Value keySlice = tiledAttnOp.getKey();
Value valueSlice = tiledAttnOp.getValue();
@@ -474,11 +482,14 @@
tensor::getMixedSizes(rewriter, loc, querySlice);
OpFoldResult headDimension = queryDimValues[1];
OpFoldResult sequenceTileLength = queryDimValues[0];
+ OpFoldResult keyValueTileLength =
+ tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
Type elementType = tiledAttnOp.getQueryType().getElementType();
- auto [result, newMax, newSum] = createAttentionBody(
- keySlice, valueSlice, querySlice, tiledResult, max, sum,
- sequenceTileLength, headDimension, elementType, ops, loc, rewriter);
+ auto [result, newMax, newSum] =
+ createAttentionBody(keySlice, valueSlice, querySlice, tiledResult, max,
+ sum, sequenceTileLength, keyValueTileLength,
+ headDimension, elementType, ops, loc, rewriter);
rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
}
@@ -487,12 +498,13 @@
/// FlashAttention algorithm.
void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
SmallVectorImpl<Operation *> &ops,
- RewriterBase &rewriter, bool onlyTile) {
+ RewriterBase &rewriter, bool onlyTile,
+ std::optional<uint64_t> tileSize) {
IREE::LinalgExt::AttentionOp tiledAttentionOp =
- tileAttention(attnOp, ops, rewriter);
+ tileAttention(attnOp, ops, rewriter, tileSize);
if (onlyTile)
return;
- decomposeTiledAttention(tiledAttentionOp, ops, rewriter);
+ decomposeTiledAttention(tiledAttentionOp, ops, rewriter, tileSize);
}
namespace {
@@ -524,11 +536,12 @@
/// j. Compute matmul(s, v) and add new_accumulator
///
///
-LogicalResult reifyAttentionTransform(func::FuncOp funcOp, bool onlyTile) {
+LogicalResult reifyAttentionTransform(func::FuncOp funcOp, bool onlyTile,
+ std::optional<uint64_t> tileSize) {
IRRewriter rewriter(funcOp.getContext());
funcOp.walk([&](IREE::LinalgExt::AttentionOp attnOp) {
SmallVector<Operation *> ops;
- tileAndDecomposeAttention(attnOp, ops, rewriter, onlyTile);
+ tileAndDecomposeAttention(attnOp, ops, rewriter, onlyTile, tileSize);
return WalkResult::advance();
});
return success();
@@ -545,9 +558,13 @@
linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
}
TileAndDecomposeAttentionPass() = default;
- TileAndDecomposeAttentionPass(bool onlyTile) { this->onlyTile = onlyTile; }
+ TileAndDecomposeAttentionPass(bool onlyTile, uint64_t tileSize) {
+ this->onlyTile = onlyTile;
+ this->tileSize = tileSize;
+ }
TileAndDecomposeAttentionPass(const TileAndDecomposeAttentionPass &pass) {
onlyTile = pass.onlyTile;
+ tileSize = pass.tileSize;
}
void runOnOperation() override;
};
@@ -556,7 +573,11 @@
void TileAndDecomposeAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
- if (failed(reifyAttentionTransform(getOperation(), onlyTile)))
+ std::optional<uint64_t> optionalTileSize{std::nullopt};
+ if (tileSize.hasValue())
+ optionalTileSize = tileSize.getValue();
+ if (failed(
+ reifyAttentionTransform(getOperation(), onlyTile, optionalTileSize)))
return signalPassFailure();
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index c72c1a7..fa18c40 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -166,7 +166,7 @@
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
SmallVector<Operation *> ops;
- LinalgExt::tileAttention(attentionOp, ops, rewriter);
+ LinalgExt::tileAttention(attentionOp, ops, rewriter, getTileSize());
for (auto op : ops)
results.push_back(op);
return DiagnosedSilenceableFailure::success();
@@ -177,7 +177,7 @@
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
SmallVector<Operation *> ops;
- LinalgExt::decomposeTiledAttention(attentionOp, ops, rewriter);
+ LinalgExt::decomposeTiledAttention(attentionOp, ops, rewriter, getTileSize());
for (auto op : ops)
results.push_back(op);
return DiagnosedSilenceableFailure::success();
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 4691721..a169a8b 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -1,5 +1,6 @@
// RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-tile-and-decompose-attention -cse %s | FileCheck %s
// RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-tile-and-decompose-attention=onlyTile -cse %s | FileCheck %s --check-prefix=TILING
+// RUN: iree-dialects-opt --split-input-file -iree-linalg-ext-tile-and-decompose-attention="tileSize=32" -cse %s | FileCheck %s --check-prefix=TILESIZE
func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
%0 = tensor.empty() : tensor<1x1024x64xf32>
@@ -7,6 +8,93 @@
return %1 : tensor<1x1024x64xf32>
}
+// TILESIZE-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME: tensor<1x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> {
+// TILESIZE: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
+// TILESIZE: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILESIZE-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILESIZE-SAME: tensor<1024x64xf32>
+// TILESIZE-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILESIZE: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG: %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// TILESIZE: %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// TILESIZE-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILESIZE: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf32> to tensor<32x64xf32>
+// TILESIZE: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf32> to tensor<32x64xf32>
+// TILESIZE: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf32> to tensor<1024x64xf32>
+// TILESIZE: %[[D8:.+]] = tensor.empty() : tensor<1024x32xf32>
+// TILESIZE: %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME: tensor<1024x32xf32>
+// TILESIZE: %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
+// TILESIZE-SAME: tensor<1024x64xf32>, tensor<32x64xf32>) outs(%[[D9]] : tensor<1024x32xf32>) -> tensor<1024x32xf32>
+// TILESIZE: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D10]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x32xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D19:.+]] = math.exp2 %[[D18]] : f32
+// TILESIZE: linalg.yield %[[D19]] : f32
+// TILESIZE: } -> tensor<1024x32xf32>
+// TILESIZE: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D19]] = math.exp2 %[[D18]] : f32
+// TILESIZE: linalg.yield %[[D19]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D12]] : tensor<1024x32xf32>) outs(%[[D14]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<1024x64xf32>
+// TILESIZE: %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_1]] : tensor<1024x32xf32>,
+// TILESIZE-SAME: tensor<32x64xf32>) outs(%[[D16]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// TILESIZE: scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILESIZE: }
+// TILESIZE: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILESIZE-SAME: {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE: %[[D8]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE: %[[D9]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D9]] : f32
+// TILESIZE: } -> tensor<1024x64xf32>
+// TILESIZE: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1024x64xf32> into tensor<1x1024x64xf32>
+// TILESIZE: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
+// TILESIZE: }
+
// TILING-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// TILING-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// TILING-LABEL: @attention
@@ -144,6 +232,96 @@
return %1 : tensor<?x?x?xf32>
}
+// TILESIZE-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME: tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: index,
+// TILESIZE-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index) -> tensor<?x?x?xf32> {
+// TILESIZE: %[[D0:.+]] = tensor.empty(%[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?xf32>
+// TILESIZE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG: %[[C1:.+]] = arith.constant 1 : index
+// TILESIZE: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// TILESIZE-DAG: %[[C2:.+]] = arith.constant 2 : index
+// TILESIZE: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+// TILESIZE: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
+// TILESIZE: %[[D1:.+]] = tensor.empty(%[[DIM]], %[[DIM_0]]) : tensor<?x?xf32>
+// TILESIZE-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TILESIZE-DAG: %[[CST_2:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE: %[[D3:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// TILESIZE: %[[D4:.+]] = linalg.fill ins(%[[CST_2]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
+// TILESIZE: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<?xf32>) -> tensor<?xf32>
+// TILESIZE-DAG: %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE: %[[D6:.+]]:3 = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[DIM_1]] step %[[C32]]
+// TILESIZE-SAME: iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG8:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME: %[[ARG9:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>) {
+// TILESIZE: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1, 1,
+// TILESIZE-SAME: 1] : tensor<?x?x?xf32> to tensor<32x?xf32>
+// TILESIZE: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG6]], 0] [1, 32, %[[DIM_0]]] [1,
+// TILESIZE-SAME: 1, 1] : tensor<?x?x?xf32> to tensor<32x?xf32>
+// TILESIZE: %[[EXTRACTED_SLICE_4:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1,
+// TILESIZE-SAME: 1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+// TILESIZE: %[[DIM_5:.+]] = tensor.dim %[[EXTRACTED_SLICE_4]], %[[C0]] : tensor<?x?xf32>
+// TILESIZE: %[[D8:.+]] = tensor.empty(%[[DIM_5]]) : tensor<?x32xf32>
+// TILESIZE: %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<?x32xf32>) -> tensor<?x32xf32>
+// TILESIZE: %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_4]], %[[EXTRACTED_SLICE]] :
+// TILESIZE-SAME: tensor<?x?xf32>, tensor<32x?xf32>) outs(%[[D9]] : tensor<?x32xf32>) -> tensor<?x32xf32>
+// TILESIZE: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D10]] : tensor<?x32xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<?xf32>
+// TILESIZE: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D11]] : tensor<?xf32>) outs(%[[D10]] : tensor<?x32xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D19:.+]] = math.exp2 %[[D18]] : f32
+// TILESIZE: linalg.yield %[[D19]] : f32
+// TILESIZE: } -> tensor<?x32xf32>
+// TILESIZE: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D11]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D19]] = math.exp2 %[[D18]] : f32
+// TILESIZE: linalg.yield %[[D19]] : f32
+// TILESIZE: } -> tensor<?xf32>
+// TILESIZE: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG9]] : tensor<?xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<?xf32>
+// TILESIZE: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D12]] : tensor<?x32xf32>) outs(%[[D14]] : tensor<?xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<?xf32>
+// TILESIZE: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG7]] : tensor<?x?xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D18]] : f32
+// TILESIZE: } -> tensor<?x?xf32>
+// TILESIZE: %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_3]] : tensor<?x32xf32>,
+// TILESIZE-SAME: tensor<32x?xf32>) outs(%[[D16]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TILESIZE: scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
+// TILESIZE: }
+// TILESIZE: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG: %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE: %[[D8]] = arith.divf %[[CST_3]], %[[IN]] : f32
+// TILESIZE: %[[D9]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D9]] : f32
+// TILESIZE: } -> tensor<?x?xf32>
+// TILESIZE: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
+// TILESIZE-SAME: [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// TILESIZE: return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
+// TILESIZE: }
+
// TILING: @attention(
// TILING-SAME: %[[QUERY:.+]]: tensor<?x?x?xf32>, %[[KEY:.+]]: tensor<?x?x?xf32>, %[[VALUE:.+]]: tensor<?x?x?xf32>,
// TILING-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index, %[[ARG4:[a-zA-Z0-9_]+]]: index, %[[ARG5:[a-zA-Z0-9_]+]]: index)
@@ -285,6 +463,109 @@
return %1 : tensor<1x1024x64xf16>
}
+// TILESIZE-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME: tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> {
+// TILESIZE: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// TILESIZE: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILESIZE: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILESIZE-SAME: tensor<1024x64xf32>
+// TILESIZE-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILESIZE: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG: %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// TILESIZE: %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// TILESIZE-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILESIZE: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf16> to tensor<32x64xf16>
+// TILESIZE: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf16> to tensor<32x64xf16>
+// TILESIZE: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE: %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
+// TILESIZE: %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME: tensor<1024x32xf32>
+// TILESIZE: %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// TILESIZE-SAME: tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME: tensor<1024x32xf32>
+// TILESIZE: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D11]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x32xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D22:.+]] = math.exp2 %[[D21]] : f32
+// TILESIZE: linalg.yield %[[D22]] : f32
+// TILESIZE: } -> tensor<1024x32xf32>
+// TILESIZE: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D22]] = math.exp2 %[[D21]] : f32
+// TILESIZE: linalg.yield %[[D22]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D15]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D17:.+]] = tensor.empty() : tensor<1024x32xf16>
+// TILESIZE: %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D17]] : tensor<1024x32xf16>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE: %[[D21]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE: linalg.yield %[[D21]] : f16
+// TILESIZE: } -> tensor<1024x32xf16>
+// TILESIZE: %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024x64xf32>
+// TILESIZE: %[[D20:.+]] = linalg.matmul ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x32xf16>,
+// TILESIZE-SAME: tensor<32x64xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// TILESIZE: scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILESIZE: }
+// TILESIZE: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILESIZE-SAME: {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE: %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE: %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D10]] : f32
+// TILESIZE: } -> tensor<1024x64xf32>
+// TILESIZE: %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE: %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE: linalg.yield %[[D9]] : f16
+// TILESIZE: } -> tensor<1024x64xf16>
+// TILESIZE: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// TILESIZE: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
+// TILESIZE: }
+
// TILING-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// TILING: @attention(
// TILING-SAME: %[[QUERY:.+]]: tensor<1x1024x64xf16>, %[[KEY:.+]]: tensor<1x1024x64xf16>, %[[VALUE:.+]]: tensor<1x1024x64xf16>)