[LinalgExt] Add optional transpose_v attribute to attention (#16254)
In the attention operator, the second contraction is a matrix
multiplication. Since many hardware vendors natively support matrix
multiplication with b transposed, a common benchmark is to assume that
the second operand of the second contraction is transposed, result in a
matrix multiplication with b transposed. This PR adds that optional
attribute and updates the tile and decompose pass to handle the case.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index 390a69b..cf8d671 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -552,7 +552,8 @@
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
- Variadic<AnyShaped>:$outputs
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_v
);
let builders = [
@@ -611,7 +612,7 @@
std::optional<ShapedType> getSumType() {
if (!getSum().has_value())
return std::nullopt;
- return (*getSum()).getType().cast<ShapedType>();
+ return (*getSum()).getType().cast<ShapedType>();
}
int64_t getQueryRank() {
return getQueryType().getRank();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 542f74a..dbcdfcd 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2478,8 +2478,13 @@
return failure();
ArrayRef<int64_t> queryShape = queryType.getShape();
ArrayRef<int64_t> keyShape = keyType.getShape();
- ArrayRef<int64_t> valueShape = valueType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
+ SmallVector<int64_t> valueShape(valueType.getShape());
+ bool transposeV = getTransposeV();
+ if (transposeV) {
+ size_t lastIdx = valueShape.size() - 1;
+ std::swap(valueShape[lastIdx - 1], valueShape[lastIdx]);
+ }
if (failed(verifyCompatibleShape(keyShape, valueShape)))
return op->emitOpError("incompatible value shape");
if (failed(verifyCompatibleShape(queryShape, outputShape)))
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index fb87d30..539f20b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -195,7 +195,8 @@
static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
OpFoldResult headDimension, Type elementType,
- Location loc, OpBuilder &builder) {
+ Location loc, OpBuilder &builder,
+ bool swapLastTwoDims = false) {
auto one = builder.getIndexAttr(1);
auto zero = builder.getIndexAttr(0);
SmallVector<OpFoldResult> strides(keyShape.size(), one);
@@ -206,6 +207,11 @@
if (!ivs.empty())
offsets[1] = ivs[0];
SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
+ if (swapLastTwoDims) {
+ std::swap(sizes[1], sizes[2]);
+ std::swap(offsets[1], offsets[2]);
+ std::swap(tensorShape[0], tensorShape[1]);
+ }
auto tensorType = RankedTensorType::get(tensorShape, elementType);
Value keySlice = builder.create<tensor::ExtractSliceOp>(
loc, tensorType, key, offsets, sizes, strides);
@@ -251,7 +257,7 @@
OpFoldResult sequenceTileLength,
OpFoldResult keyValueTileLength, OpFoldResult headDimension,
Type elementType, SmallVectorImpl<Operation *> &ops,
- Location loc, OpBuilder &builder) {
+ bool transposeV, Location loc, OpBuilder &builder) {
Type f32Type = builder.getF32Type();
// Compute matmul(q, transpose(k))
@@ -283,11 +289,18 @@
scaleAccumulator(outputSlice, scaleFactor, loc, builder, ops);
// Compute matmul(softmax, v)
- auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
- scaledAcc);
+ Operation *matmulOp;
+ if (transposeV) {
+ matmulOp = builder.create<linalg::MatmulTransposeBOp>(
+ loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
+ scaledAcc);
+ } else {
+ matmulOp = builder.create<linalg::MatmulOp>(
+ loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
+ scaledAcc);
+ }
ops.push_back(matmulOp);
- Value result = matmulOp.getResult(0);
+ Value result = matmulOp->getResult(0);
return std::make_tuple(result, newMax, newSum);
}
@@ -416,8 +429,9 @@
// Extract slices
Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
headDimension, elementType, loc, rewriter);
- Value valueSlice = extractSlice(value, keyShape, ivs, keyValueTileLength,
- headDimension, elementType, loc, rewriter);
+ Value valueSlice =
+ extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension,
+ elementType, loc, rewriter, attnOp.getTransposeV());
Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
headDimension, elementType, loc, rewriter);
@@ -427,6 +441,9 @@
SmallVector<Value>{querySlice, keySlice, valueSlice},
SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});
+ if (attnOp.getTransposeV())
+ tiledAttentionOp.setTransposeVAttr(attnOp.getTransposeVAttr());
+
Value tiledResult = tiledAttentionOp.getResult(0);
Value newMax = tiledAttentionOp.getResult(1);
Value newSum = tiledAttentionOp.getResult(2);
@@ -486,10 +503,10 @@
tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
Type elementType = tiledAttnOp.getQueryType().getElementType();
- auto [result, newMax, newSum] =
- createAttentionBody(keySlice, valueSlice, querySlice, tiledResult, max,
- sum, sequenceTileLength, keyValueTileLength,
- headDimension, elementType, ops, loc, rewriter);
+ auto [result, newMax, newSum] = createAttentionBody(
+ keySlice, valueSlice, querySlice, tiledResult, max, sum,
+ sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
+ tiledAttnOp.getTransposeV(), loc, rewriter);
rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
index fcd63e1..e4e54ce 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
@@ -1112,3 +1112,35 @@
// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
// CHECK: }
+
+// -----
+
+func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: tensor<192x2048x64xf32>, %value: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32> {
+ %0 = tensor.empty() : tensor<192x1024x64xf32>
+ %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+ return %1 : tensor<192x1024x64xf32>
+}
+// CHECK: func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME: tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32>
+// CHECK-SAME: {
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
+// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
+// CHECK: }
+
+// -----
+
+func.func @cross_attention_transposev_dyn(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK: func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME: tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-SAME: {
+// CHECK: %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
+// CHECK-SAME: tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG3]] :
+// CHECK-SAME: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: return %[[D1]] : tensor<?x?x?xf32>
+// CHECK: }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 1fd9a4d..3339eb5 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -717,3 +717,265 @@
// CHECK-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
// CHECK: }
+
+// -----
+
+func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+ %0 = tensor.empty() : tensor<1x1024x64xf16>
+ %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
+ return %1 : tensor<1x1024x64xf16>
+}
+
+// TILESIZE-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE: func.func @attention_transpose_v(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME: tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+// TILESIZE: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// TILESIZE: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILESIZE: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILESIZE-SAME: tensor<1024x64xf32>
+// TILESIZE-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILESIZE: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG: %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// TILESIZE: %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// TILESIZE-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILESIZE: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf16> to tensor<32x64xf16>
+// TILESIZE: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, 0, %[[ARG3]]] [1, 64, 32] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x64x1024xf16> to tensor<64x32xf16>
+// TILESIZE: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE: %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
+// TILESIZE: %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME: tensor<1024x32xf32>
+// TILESIZE: %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// TILESIZE-SAME: tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME: tensor<1024x32xf32>
+// TILESIZE: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D11]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x32xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D22:.+]] = math.exp2 %[[D21]] : f32
+// TILESIZE: linalg.yield %[[D22]] : f32
+// TILESIZE: } -> tensor<1024x32xf32>
+// TILESIZE: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE: %[[D22]] = math.exp2 %[[D21]] : f32
+// TILESIZE: linalg.yield %[[D22]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME: ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "reduction"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D15]] : tensor<1024xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024xf32>
+// TILESIZE: %[[D17:.+]] = tensor.empty() : tensor<1024x32xf16>
+// TILESIZE: %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D17]] : tensor<1024x32xf16>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE: %[[D21]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE: linalg.yield %[[D21]] : f16
+// TILESIZE: } -> tensor<1024x32xf16>
+// TILESIZE: %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D21]] : f32
+// TILESIZE: } -> tensor<1024x64xf32>
+// TILESIZE: %[[D20:.+]] = linalg.matmul_transpose_b ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x32xf16>,
+// TILESIZE-SAME: tensor<64x32xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// TILESIZE: scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILESIZE: }
+// TILESIZE: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILESIZE-SAME: {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE: %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE: %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILESIZE: linalg.yield %[[D10]] : f32
+// TILESIZE: } -> tensor<1024x64xf32>
+// TILESIZE: %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME: "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// TILESIZE: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE: %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE: linalg.yield %[[D9]] : f16
+// TILESIZE: } -> tensor<1024x64xf16>
+// TILESIZE: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// TILESIZE: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
+
+// TILING-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILING-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILING: func.func @attention_transpose_v(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILING-SAME: tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+// TILING: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// TILING: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILING: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILING-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILING: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILING-SAME: tensor<1024x64xf32>
+// TILING-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILING: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILING: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILING: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILING-DAG: %[[C0:.+]] = arith.constant 0 : index
+// TILING-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// TILING: %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C1024]]
+// TILING-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILING-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILING: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILING: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, 0, %[[ARG3]]] [1, 64, 1024] [1, 1, 1] :
+// TILING-SAME: tensor<1x64x1024xf16> to tensor<64x1024xf16>
+// TILING: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILING: %[[D9:.+]]:3 = iree_linalg_ext.attention {transpose_v = true} ins(%[[EXTRACTED_SLICE_3]],
+// TILING-SAME: %[[EXTRACTED_SLICE_1]], %[[EXTRACTED_SLICE_2]] : tensor<1024x64xf16>, tensor<1024x64xf16>,
+// TILING-SAME: tensor<64x1024xf16>) outs(%[[ARG4]], %[[ARG5]], %[[ARG6]] : tensor<1024x64xf32>, tensor<1024xf32>,
+// TILING-SAME: tensor<1024xf32>) -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILING: scf.yield %[[D9]]#[[D0:.+]], %[[D9]]#[[D1:.+]], %[[D9]]#[[D2:.+]] : tensor<1024x64xf32>,
+// TILING-SAME: tensor<1024xf32>, tensor<1024xf32>
+// TILING: }
+// TILING: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILING-SAME: "parallel"]} ins(%[[D6]]#[[D2]] : tensor<1024xf32>) outs(%[[D6]]#[[D0]] : tensor<1024x64xf32>) {
+// TILING: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILING: %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILING: %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILING: linalg.yield %[[D10]] : f32
+// TILING: } -> tensor<1024x64xf32>
+// TILING: %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel",
+// TILING-SAME: "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// TILING: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILING: %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILING: linalg.yield %[[D9]] : f16
+// TILING: } -> tensor<1024x64xf16>
+// TILING: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// TILING: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func.func @attention_transpose_v(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME: tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// CHECK: %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// CHECK-SAME: tensor<1024x64xf32>
+// CHECK-DAG: %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// CHECK: %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// CHECK: %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK: %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C1024]]
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// CHECK: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, 0, %[[ARG3]]] [1, 64, 1024] [1, 1, 1] :
+// CHECK-SAME: tensor<1x64x1024xf16> to tensor<64x1024xf16>
+// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK: %[[D9:.+]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK: %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x1024xf32>) ->
+// CHECK-SAME: tensor<1024x1024xf32>
+// CHECK: %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// CHECK-SAME: tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D10]] : tensor<1024x1024xf32>) ->
+// CHECK-SAME: tensor<1024x1024xf32>
+// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME: "reduction"]} ins(%[[D11]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D21]] : f32
+// CHECK: } -> tensor<1024xf32>
+// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// CHECK: %[[D22:.+]] = math.exp2 %[[D21]] : f32
+// CHECK: linalg.yield %[[D22]] : f32
+// CHECK: } -> tensor<1024x1024xf32>
+// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// CHECK: %[[D22]] = math.exp2 %[[D21]] : f32
+// CHECK: linalg.yield %[[D22]] : f32
+// CHECK: } -> tensor<1024xf32>
+// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D21]] : f32
+// CHECK: } -> tensor<1024xf32>
+// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME: "reduction"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D15]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D21]] : f32
+// CHECK: } -> tensor<1024xf32>
+// CHECK: %[[D17:.+]] = tensor.empty() : tensor<1024x1024xf16>
+// CHECK: %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D17]] : tensor<1024x1024xf16>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// CHECK: %[[D21]] = arith.truncf %[[IN]] : f32 to f16
+// CHECK: linalg.yield %[[D21]] : f16
+// CHECK: } -> tensor<1024x1024xf16>
+// CHECK: %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D21]] : f32
+// CHECK: } -> tensor<1024x64xf32>
+// CHECK: %[[D20:.+]] = linalg.matmul_transpose_b ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x1024xf16>,
+// CHECK-SAME: tensor<64x1024xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// CHECK: scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// CHECK: }
+// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// CHECK-SAME: {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK: %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D10]] : f32
+// CHECK: } -> tensor<1024x64xf32>
+// CHECK: %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// CHECK: %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// CHECK: linalg.yield %[[D9]] : f16
+// CHECK: } -> tensor<1024x64xf16>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>