[LinalgExt] Add optional transpose_v attribute to attention (#16254)

In the attention operator, the second contraction is a matrix
multiplication. Since many hardware vendors natively support matrix
multiplication with b transposed, a common benchmark is to assume that
the second operand of the second contraction is transposed, resulting
in a matrix multiplication with b transposed. This PR adds the
optional transpose_v attribute for that case and updates the tile and
decompose pass to handle it.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index 390a69b..cf8d671 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -552,7 +552,8 @@
   }];
 
   let arguments = (ins Variadic<AnyShaped>:$inputs,
-                       Variadic<AnyShaped>:$outputs
+                       Variadic<AnyShaped>:$outputs,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_v
   );
 
   let builders = [
@@ -611,7 +612,7 @@
     std::optional<ShapedType> getSumType() {
       if (!getSum().has_value())
         return std::nullopt;
-      return (*getSum()).getType().cast<ShapedType>();      
+      return (*getSum()).getType().cast<ShapedType>();
     }
     int64_t getQueryRank() {
       return getQueryType().getRank();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 542f74a..dbcdfcd 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2478,8 +2478,13 @@
     return failure();
   ArrayRef<int64_t> queryShape = queryType.getShape();
   ArrayRef<int64_t> keyShape = keyType.getShape();
-  ArrayRef<int64_t> valueShape = valueType.getShape();
   ArrayRef<int64_t> outputShape = outputType.getShape();
+  SmallVector<int64_t> valueShape(valueType.getShape());
+  bool transposeV = getTransposeV();
+  if (transposeV) {
+    size_t lastIdx = valueShape.size() - 1;
+    std::swap(valueShape[lastIdx - 1], valueShape[lastIdx]);
+  }
   if (failed(verifyCompatibleShape(keyShape, valueShape)))
     return op->emitOpError("incompatible value shape");
   if (failed(verifyCompatibleShape(queryShape, outputShape)))
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index fb87d30..539f20b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -195,7 +195,8 @@
 static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
                           ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
                           OpFoldResult headDimension, Type elementType,
-                          Location loc, OpBuilder &builder) {
+                          Location loc, OpBuilder &builder,
+                          bool swapLastTwoDims = false) {
   auto one = builder.getIndexAttr(1);
   auto zero = builder.getIndexAttr(0);
   SmallVector<OpFoldResult> strides(keyShape.size(), one);
@@ -206,6 +207,11 @@
   if (!ivs.empty())
     offsets[1] = ivs[0];
   SmallVector<int64_t> tensorShape{keyShape[1], keyShape[2]};
+  if (swapLastTwoDims) {
+    std::swap(sizes[1], sizes[2]);
+    std::swap(offsets[1], offsets[2]);
+    std::swap(tensorShape[0], tensorShape[1]);
+  }
   auto tensorType = RankedTensorType::get(tensorShape, elementType);
   Value keySlice = builder.create<tensor::ExtractSliceOp>(
       loc, tensorType, key, offsets, sizes, strides);
@@ -251,7 +257,7 @@
                     OpFoldResult sequenceTileLength,
                     OpFoldResult keyValueTileLength, OpFoldResult headDimension,
                     Type elementType, SmallVectorImpl<Operation *> &ops,
-                    Location loc, OpBuilder &builder) {
+                    bool transposeV, Location loc, OpBuilder &builder) {
 
   Type f32Type = builder.getF32Type();
   // Compute matmul(q, transpose(k))
@@ -283,11 +289,18 @@
       scaleAccumulator(outputSlice, scaleFactor, loc, builder, ops);
 
   // Compute matmul(softmax, v)
-  auto matmulOp = builder.create<linalg::MatmulOp>(
-      loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
-      scaledAcc);
+  Operation *matmulOp;
+  if (transposeV) {
+    matmulOp = builder.create<linalg::MatmulTransposeBOp>(
+        loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
+        scaledAcc);
+  } else {
+    matmulOp = builder.create<linalg::MatmulOp>(
+        loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
+        scaledAcc);
+  }
   ops.push_back(matmulOp);
-  Value result = matmulOp.getResult(0);
+  Value result = matmulOp->getResult(0);
   return std::make_tuple(result, newMax, newSum);
 }
 
@@ -416,8 +429,9 @@
   // Extract slices
   Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
                                 headDimension, elementType, loc, rewriter);
-  Value valueSlice = extractSlice(value, keyShape, ivs, keyValueTileLength,
-                                  headDimension, elementType, loc, rewriter);
+  Value valueSlice =
+      extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension,
+                   elementType, loc, rewriter, attnOp.getTransposeV());
   Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
                                   headDimension, elementType, loc, rewriter);
 
@@ -427,6 +441,9 @@
       SmallVector<Value>{querySlice, keySlice, valueSlice},
       SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});
 
+  if (attnOp.getTransposeV())
+    tiledAttentionOp.setTransposeVAttr(attnOp.getTransposeVAttr());
+
   Value tiledResult = tiledAttentionOp.getResult(0);
   Value newMax = tiledAttentionOp.getResult(1);
   Value newSum = tiledAttentionOp.getResult(2);
@@ -486,10 +503,10 @@
       tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
 
   Type elementType = tiledAttnOp.getQueryType().getElementType();
-  auto [result, newMax, newSum] =
-      createAttentionBody(keySlice, valueSlice, querySlice, tiledResult, max,
-                          sum, sequenceTileLength, keyValueTileLength,
-                          headDimension, elementType, ops, loc, rewriter);
+  auto [result, newMax, newSum] = createAttentionBody(
+      keySlice, valueSlice, querySlice, tiledResult, max, sum,
+      sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
+      tiledAttnOp.getTransposeV(), loc, rewriter);
 
   rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
 }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
index fcd63e1..e4e54ce 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
@@ -1112,3 +1112,35 @@
 // CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
 // CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
 // CHECK:      }
+
+// -----
+
+func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: tensor<192x2048x64xf32>, %value: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32> {
+  %0 = tensor.empty() : tensor<192x1024x64xf32>
+  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+  return %1 : tensor<192x1024x64xf32>
+}
+// CHECK:      func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME:   tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32>
+// CHECK-SAME:   {
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
+// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>) outs(%[[D0]] :
+// CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
+// CHECK:      }
+
+// -----
+
+func.func @cross_attention_transposev_dyn(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK:      func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME:   tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-SAME:   {
+// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
+// CHECK-SAME:     tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG3]] :
+// CHECK-SAME:     tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:        return %[[D1]] : tensor<?x?x?xf32>
+// CHECK:      }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 1fd9a4d..3339eb5 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -717,3 +717,265 @@
 // CHECK-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
 // CHECK:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
 // CHECK:      }
+
+// -----
+
+func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+  %0 = tensor.empty() : tensor<1x1024x64xf16>
+  %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
+  return %1 : tensor<1x1024x64xf16>
+}
+
+// TILESIZE-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILESIZE-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILESIZE-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// TILESIZE:      func.func @attention_transpose_v(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILESIZE-SAME:   tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+// TILESIZE:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// TILESIZE:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILESIZE:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:     tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILESIZE:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILESIZE-SAME:     tensor<1024x64xf32>
+// TILESIZE-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILESIZE:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILESIZE:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILESIZE-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// TILESIZE-DAG:    %[[C32:.+]] = arith.constant 32 : index
+// TILESIZE-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
+// TILESIZE:        %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// TILESIZE-SAME:     iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILESIZE-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILESIZE:          %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 32, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<32x64xf16>
+// TILESIZE:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, 0, %[[ARG3]]] [1, 64, 32] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x64x1024xf16> to tensor<64x32xf16>
+// TILESIZE:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILESIZE:          %[[D9:.+]] = tensor.empty() : tensor<1024x32xf32>
+// TILESIZE:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME:       tensor<1024x32xf32>
+// TILESIZE:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// TILESIZE-SAME:       tensor<1024x64xf16>, tensor<32x64xf16>) outs(%[[D10]] : tensor<1024x32xf32>) ->
+// TILESIZE-SAME:       tensor<1024x32xf32>
+// TILESIZE:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D11]] : tensor<1024x32xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x32xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D22:.+]] = math.exp2 %[[D21]] : f32
+// TILESIZE:            linalg.yield %[[D22]] : f32
+// TILESIZE:          } -> tensor<1024x32xf32>
+// TILESIZE:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// TILESIZE:            %[[D22]] = math.exp2 %[[D21]] : f32
+// TILESIZE:            linalg.yield %[[D22]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// TILESIZE-SAME:       ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "reduction"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D15]] : tensor<1024xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024xf32>
+// TILESIZE:          %[[D17:.+]] = tensor.empty() : tensor<1024x32xf16>
+// TILESIZE:          %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024x32xf32>) outs(%[[D17]] : tensor<1024x32xf16>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE:            %[[D21]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE:            linalg.yield %[[D21]] : f16
+// TILESIZE:          } -> tensor<1024x32xf16>
+// TILESIZE:          %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:       "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// TILESIZE:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// TILESIZE:            linalg.yield %[[D21]] : f32
+// TILESIZE:          } -> tensor<1024x64xf32>
+// TILESIZE:          %[[D20:.+]] = linalg.matmul_transpose_b ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x32xf16>,
+// TILESIZE-SAME:       tensor<64x32xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// TILESIZE:          scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILESIZE:        }
+// TILESIZE:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILESIZE-SAME:     {
+// TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILESIZE-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILESIZE:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILESIZE:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILESIZE:          linalg.yield %[[D10]] : f32
+// TILESIZE:        } -> tensor<1024x64xf32>
+// TILESIZE:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILESIZE-SAME:     "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// TILESIZE:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILESIZE:          %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILESIZE:          linalg.yield %[[D9]] : f16
+// TILESIZE:        } -> tensor<1024x64xf16>
+// TILESIZE:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILESIZE-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// TILESIZE:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
+
+// TILING-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+// TILING-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILING:      func.func @attention_transpose_v(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// TILING-SAME:   tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+// TILING:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// TILING:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// TILING:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME:     tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILING-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// TILING:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// TILING-SAME:     tensor<1024x64xf32>
+// TILING-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// TILING:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// TILING:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILING:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// TILING-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// TILING-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
+// TILING:        %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C1024]]
+// TILING-SAME:     iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// TILING-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// TILING:          %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILING:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, 0, %[[ARG3]]] [1, 64, 1024] [1, 1, 1] :
+// TILING-SAME:       tensor<1x64x1024xf16> to tensor<64x1024xf16>
+// TILING:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// TILING:          %[[D9:.+]]:3 = iree_linalg_ext.attention {transpose_v = true} ins(%[[EXTRACTED_SLICE_3]],
+// TILING-SAME:       %[[EXTRACTED_SLICE_1]], %[[EXTRACTED_SLICE_2]] : tensor<1024x64xf16>, tensor<1024x64xf16>,
+// TILING-SAME:       tensor<64x1024xf16>) outs(%[[ARG4]], %[[ARG5]], %[[ARG6]] : tensor<1024x64xf32>, tensor<1024xf32>,
+// TILING-SAME:       tensor<1024xf32>) -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// TILING:          scf.yield %[[D9]]#[[D0:.+]], %[[D9]]#[[D1:.+]], %[[D9]]#[[D2:.+]] : tensor<1024x64xf32>,
+// TILING-SAME:       tensor<1024xf32>, tensor<1024xf32>
+// TILING:        }
+// TILING:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// TILING-SAME:     "parallel"]} ins(%[[D6]]#[[D2]] : tensor<1024xf32>) outs(%[[D6]]#[[D0]] : tensor<1024x64xf32>) {
+// TILING:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILING:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILING:          %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILING:          linalg.yield %[[D10]] : f32
+// TILING:        } -> tensor<1024x64xf32>
+// TILING:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel",
+// TILING-SAME:     "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// TILING:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// TILING:          %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILING:          linalg.yield %[[D9]] : f16
+// TILING:        } -> tensor<1024x64xf16>
+// TILING:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// TILING:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
+
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// CHECK:      func.func @attention_transpose_v(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1024x64xf16>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME:   tensor<1x1024x64xf16>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> {
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf16>
+// CHECK:        %[[D1:.+]] = tensor.empty() : tensor<1024x64xf32>
+// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME:     tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<1024x64xf32>) ->
+// CHECK-SAME:     tensor<1024x64xf32>
+// CHECK-DAG:    %[[CST_0:.+]] = arith.constant -1.000000e+30 : f32
+// CHECK:        %[[D3:.+]] = tensor.empty() : tensor<1024xf32>
+// CHECK:        %[[D4:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D3]] : tensor<1024xf32>) -> tensor<1024xf32>
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK:        %[[D6:.+]]:3 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C1024]]
+// CHECK-SAME:     iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D2]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D4]],
+// CHECK-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[D5]]) -> (tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// CHECK:          %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[ARG3]], 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK:          %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG2]][0, 0, %[[ARG3]]] [1, 64, 1024] [1, 1, 1] :
+// CHECK-SAME:       tensor<1x64x1024xf16> to tensor<64x1024xf16>
+// CHECK:          %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME:       tensor<1x1024x64xf16> to tensor<1024x64xf16>
+// CHECK:          %[[D9:.+]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK:          %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x1024xf32>) ->
+// CHECK-SAME:       tensor<1024x1024xf32>
+// CHECK:          %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// CHECK-SAME:       tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D10]] : tensor<1024x1024xf32>) ->
+// CHECK-SAME:       tensor<1024x1024xf32>
+// CHECK:          %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:       "reduction"]} ins(%[[D11]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D21]] : f32
+// CHECK:          } -> tensor<1024xf32>
+// CHECK:          %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// CHECK:            %[[D22:.+]] = math.exp2 %[[D21]] : f32
+// CHECK:            linalg.yield %[[D22]] : f32
+// CHECK:          } -> tensor<1024x1024xf32>
+// CHECK:          %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
+// CHECK:            %[[D22]] = math.exp2 %[[D21]] : f32
+// CHECK:            linalg.yield %[[D22]] : f32
+// CHECK:          } -> tensor<1024xf32>
+// CHECK:          %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME:       ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D21]] : f32
+// CHECK:          } -> tensor<1024xf32>
+// CHECK:          %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:       "reduction"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D15]] : tensor<1024xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D21]] : f32
+// CHECK:          } -> tensor<1024xf32>
+// CHECK:          %[[D17:.+]] = tensor.empty() : tensor<1024x1024xf16>
+// CHECK:          %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D17]] : tensor<1024x1024xf16>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// CHECK:            %[[D21]] = arith.truncf %[[IN]] : f32 to f16
+// CHECK:            linalg.yield %[[D21]] : f16
+// CHECK:          } -> tensor<1024x1024xf16>
+// CHECK:          %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:       "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// CHECK:          ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:            %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[D21]] : f32
+// CHECK:          } -> tensor<1024x64xf32>
+// CHECK:          %[[D20:.+]] = linalg.matmul_transpose_b ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x1024xf16>,
+// CHECK-SAME:       tensor<64x1024xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
+// CHECK:          scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// CHECK:        }
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// CHECK-SAME:     {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG:      %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:          %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK:          %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D10]] : f32
+// CHECK:        } -> tensor<1024x64xf32>
+// CHECK:        %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
+// CHECK:          %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// CHECK:          linalg.yield %[[D9]] : f16
+// CHECK:        } -> tensor<1024x64xf16>
+// CHECK:        %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME:     tensor<1024x64xf16> into tensor<1x1024x64xf16>
+// CHECK:        return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>