Add FlashAttention v2 (#15527)
This patch adds the algorithmic modifications introduced in FlashAttention-2
(FA2) to the tile-and-decompose attention pass. The inner K/V loop no longer
normalizes the partial softmax or rescales the accumulator by the reciprocal
of the running sum on every iteration; it only applies the max-correction
scale factor, and the division by the running sum is deferred to a single
final scaling applied once after the loop (see the recurrence sketched below).
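
For reference, a minimal sketch of the blocked recurrence the decomposed
attention now follows, with S_j the scores of the j-th K/V tile and the
rowmax/rowsum reductions taken over that tile. The symbols m, l, O, alpha,
P follow the FA2 paper's notation and are not identifiers in this patch;
any logit scaling is omitted; the parenthetical labels map each step to the
handles exposed by transform.tile_attention / transform.decompose_tiled_attention:

  S_j     = Q K_j^T                               (first_matmul)
  m_j     = max(m_{j-1}, rowmax(S_j))             (reduce_max)
  alpha_j = exp2(m_{j-1} - m_j)                   (scale_factor)
  P_j     = exp2(S_j - m_j)                       (partial_softmax)
  l_j     = alpha_j * l_{j-1} + rowsum(P_j)       (update, reduce_sum)
  O_j     = alpha_j * O_{j-1} + P_j V_j           (scale_acc, second_matmul)
  O       = (1 / l_N) * O_N                       (final_scaling, once after the loop)

Before this change, each iteration instead normalized P_j by 1/l_j and
rescaled the accumulator by alpha_j * l_{j-1} / l_j, which is why the
reciprocal_sum and softmax handles disappear from the transform specs.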
---------
Co-authored-by: Harsh Menon <harsh@nod-labs.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index 367809b..2005453 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -40,7 +40,7 @@
// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<-1.000000e+30> : vector<32xf32>
// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[CST_2:.+]] = arith.constant dense<0.000000e+00> : vector<32x128xf32>
-// CHECK-DAG: %[[CST_3:.+]] = arith.constant dense<1.000000e+00> : vector<32xf32>
+// CHECK-DAG: %[[CST_3:.+]] = arith.constant dense<1.000000e+00> : vector<64x32xf32>
// CHECK-DAG: %[[CST_4:.+]] = arith.constant 0.000000e+00 : f16
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
@@ -139,13 +139,8 @@
// CHECK: %[[D23:.+]] = arith.mulf %[[D22]], %[[ARG2]] : vector<32xf32>
// CHECK: %[[D24:.+]] = vector.multi_reduction <add>, %[[D20]], %[[D23]] [1] : vector<32x128xf32> to
// CHECK-SAME: vector<32xf32>
-// CHECK: %[[D25:.+]] = arith.divf %[[CST_3]], %[[D24]] : vector<32xf32>
-// CHECK: %[[D26:.+]] = vector.broadcast %[[D25]] : vector<32xf32> to vector<128x32xf32>
-// CHECK: %[[D27:.+]] = vector.transpose %[[D26]], [1, 0] : vector<128x32xf32> to vector<32x128xf32>
-// CHECK: %[[D28:.+]] = arith.mulf %[[D20]], %[[D27]] : vector<32x128xf32>
-// CHECK: %[[D29:.+]] = arith.truncf %[[D28]] : vector<32x128xf32> to vector<32x128xf16>
-// CHECK: %[[D30:.+]] = arith.mulf %[[D23]], %[[D25]] : vector<32xf32>
-// CHECK: %[[D31:.+]] = vector.broadcast %[[D30]] : vector<32xf32> to vector<64x32xf32>
+// CHECK: %[[D29:.+]] = arith.truncf %[[D20]] : vector<32x128xf32> to vector<32x128xf16>
+// CHECK: %[[D31:.+]] = vector.broadcast %[[D22]] : vector<32xf32> to vector<64x32xf32>
// CHECK: %[[D33:.+]] = vector.transpose %[[D31]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
// CHECK: %[[D34:.+]] = arith.mulf %[[D33]], %[[ARG3]] : vector<32x64xf32>
// CHECK: %[[D35:.+]] = vector.transfer_read %[[ALLOC_11]][%[[C0]], %[[C0]]], %[[CST_4]] {in_bounds = [true,
@@ -159,7 +154,11 @@
// CHECK: gpu.barrier
// CHECK: scf.yield %[[D16]], %[[D24]], %[[D39]] : vector<32xf32>, vector<32xf32>, vector<32x64xf32>
// CHECK: }
-// CHECK: %[[D12:.+]] = arith.truncf %[[D11]]#[[D2:.+]] : vector<32x64xf32> to vector<32x64xf16>
+// CHECK: %[[DSCALE1:.+]] = vector.broadcast %[[D11]]#1 : vector<32xf32> to vector<64x32xf32>
+// CHECK: %[[DSCALE2:.+]] = arith.divf %[[CST_3]], %[[DSCALE1]] : vector<64x32xf32>
+// CHECK: %[[DSCALE3:.+]] = vector.transpose %[[DSCALE2]], [1, 0] : vector<64x32xf32> to vector<32x64xf32>
+// CHECK: %[[DSCALE4:.+]] = arith.mulf %[[DSCALE3]], %[[D11]]#2 : vector<32x64xf32>
+// CHECK: %[[D12:.+]] = arith.truncf %[[DSCALE4]] : vector<32x64xf32> to vector<32x64xf16>
// CHECK: vector.transfer_write %[[D12]], %[[ALLOC_7]][%[[C0]], %[[D8]], %[[C0]]] {in_bounds = [true, true]} :
// CHECK-SAME: vector<32x64xf16>, memref<1x128x64xf16, #[[GPU]].address_space<workgroup>>
// CHECK: gpu.barrier
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
index 03e1cd6..e33e07e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
@@ -1,4 +1,3 @@
-
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.consumed}) {
// Get attention op
@@ -31,11 +30,11 @@
// Tile and decompose attention
// ==========================================
%attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %acc_fill, %max_fill, %sum_fill, %inner_loop, %last_truncate, %blocked_attention = transform.tile_attention %attention4 :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %truncate, %scale_acc, %second_matmul
+ %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.tile_attention %attention4 :
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul
= transform.decompose_tiled_attention %blocked_attention :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
// Promote key and value operands
// ==========================================
@@ -47,32 +46,54 @@
// Tile and fuse attention ops
// ==========================================
%tiled_matmul, %forall = transform.structured.tile_using_forall %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
%f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f2, %loop2 = transform.structured.fuse_into_containing_op %softmax into %loop1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.apply_cse %func : !transform.any_op
- %f3, %loop3 = transform.structured.fuse_into_containing_op %reciprocal_sum into %loop2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f4, %loop4 = transform.structured.fuse_into_containing_op %reduce_sum into %loop3 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
transform.iree.apply_cse %func : !transform.any_op
- %f5, %loop5 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.iree.apply_cse %func : !transform.any_op
+
+ %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.iree.apply_cse %func : !transform.any_op
%f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%f8, %loop8 = transform.structured.fuse_into_containing_op %promoted_first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
transform.iree.apply_cse %func : !transform.any_op
- // Distribute fills and last truncate
+ %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+ transform.iree.apply_cse %func : !transform.any_op
+
+ // Distribute fills
// ==========================================
- %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill, %last_truncate : !transform.any_op
+ %fills = transform.merge_handles %acc_fill, %max_fill, %sum_fill : !transform.any_op
%tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ // Distribute last_truncate and fuse final_scaling into it
+ // ==========================================
+ %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[32] (mapping = [#gpu.warp<linear_dim_0>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+ transform.iree.apply_cse %func : !transform.any_op
+
// Vectorize function
// ==========================================
transform.apply_patterns to %func {
@@ -126,4 +147,4 @@
transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> ()
transform.yield
}
-} // module
+} //// module
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 430320a..778aeb0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -74,22 +74,38 @@
return genericOp.getResult(0);
}
-static Value updateAndScale(Value oldMax, Value newMax, Value oldSum,
- Location loc, OpBuilder &builder,
+/// Return the scale factor for the new softmax maximum and add the generic to
+/// the provided list of operations.
+static Value computeScaleFactor(Value oldMax, Value newMax, Location loc,
+ OpBuilder &builder,
+ SmallVectorImpl<Operation *> &ops) {
+ SmallVector<utils::IteratorType> iteratorTypes(1,
+ utils::IteratorType::parallel);
+ auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
+ SmallVector<AffineMap> indexingMaps(2, identityMap);
+ auto genericOp = builder.create<linalg::GenericOp>(
+ loc, oldMax.getType(), newMax, oldMax, indexingMaps, iteratorTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value diff = b.create<arith::SubFOp>(loc, args[1], args[0]);
+ Value weight = b.create<math::Exp2Op>(loc, diff);
+ b.create<linalg::YieldOp>(loc, weight);
+ });
+ ops.push_back(genericOp);
+ return genericOp.getResult(0);
+}
+
+static Value updateAndScale(Value scaleFactor, Value oldSum, Location loc,
+ OpBuilder &builder,
SmallVectorImpl<Operation *> &ops) {
SmallVector<utils::IteratorType> iteratorTypes(1,
utils::IteratorType::parallel);
auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
- SmallVector<AffineMap> indexingMaps(3, identityMap);
- SmallVector<Type> resultTypes{oldSum.getType()};
+ SmallVector<AffineMap> indexingMaps(2, identityMap);
auto genericOp = builder.create<linalg::GenericOp>(
- loc, resultTypes, ValueRange{oldMax, newMax}, ValueRange{oldSum},
- indexingMaps, iteratorTypes,
+ loc, oldSum.getType(), scaleFactor, oldSum, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
- Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
- Value weight = b.create<math::Exp2Op>(loc, diff);
- Value scaledOldSum = b.create<arith::MulFOp>(loc, weight, args[2]);
- b.create<linalg::YieldOp>(loc, ValueRange{scaledOldSum});
+ Value scaledOldSum = b.create<arith::MulFOp>(loc, args[0], args[1]);
+ b.create<linalg::YieldOp>(loc, scaledOldSum);
});
ops.push_back(genericOp);
return genericOp.getResult(0);
@@ -117,28 +133,33 @@
return genericOp.getResult(0);
}
-static Value computeReciprocal(Value x, Location loc, OpBuilder &builder,
+static Value applyFinalScaling(Value result, Value newSum, Location loc,
+ OpBuilder &builder,
SmallVectorImpl<Operation *> &ops) {
AffineMap identityMap =
- AffineMap::getMultiDimIdentityMap(1, builder.getContext());
- SmallVector<AffineMap> indexingMaps{identityMap};
- SmallVector<utils::IteratorType> iteratorTypes(1,
+ AffineMap::getMultiDimIdentityMap(2, builder.getContext());
+ AffineExpr d0, d1;
+ bindDims(builder.getContext(), d0, d1);
+ // (d0, d1) -> (d0)
+ auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
+ SmallVector<AffineMap> indexingMaps = {rowMap, identityMap};
+ SmallVector<utils::IteratorType> iteratorTypes(2,
utils::IteratorType::parallel);
auto genericOp = builder.create<linalg::GenericOp>(
- loc, x.getType(), ValueRange{}, x, indexingMaps, iteratorTypes,
+ loc, result.getType(), newSum, result, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value one = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(args[0].getType(), 1.0));
- Value result = b.create<arith::DivFOp>(loc, one, args[0]);
+ Value reciprocal = b.create<arith::DivFOp>(loc, one, args[0]);
+ Value result = b.create<arith::MulFOp>(loc, reciprocal, args[1]);
b.create<linalg::YieldOp>(loc, result);
});
ops.push_back(genericOp);
return genericOp.getResult(0);
}
-static Value scaleAccumulator(Value accumulator, Value scaledOldSum,
- Value inverseNewSum, Location loc,
- OpBuilder &builder,
+static Value scaleAccumulator(Value accumulator, Value scaleFactor,
+ Location loc, OpBuilder &builder,
SmallVectorImpl<Operation *> &ops) {
AffineMap identityMap =
AffineMap::getMultiDimIdentityMap(2, builder.getContext());
@@ -146,15 +167,13 @@
bindDims(builder.getContext(), d0, d1);
// (d0, d1) -> (d0)
auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
- SmallVector<AffineMap> indexingMaps{rowMap, rowMap, identityMap};
+ SmallVector<AffineMap> indexingMaps{rowMap, identityMap};
SmallVector<utils::IteratorType> iteratorTypes(2,
utils::IteratorType::parallel);
auto genericOp = builder.create<linalg::GenericOp>(
- loc, accumulator.getType(), ValueRange{scaledOldSum, inverseNewSum},
- accumulator, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value ratio = b.create<arith::MulFOp>(loc, args[0], args[1]);
- Value result = b.create<arith::MulFOp>(loc, ratio, args[2]);
+ loc, accumulator.getType(), scaleFactor, accumulator, indexingMaps,
+ iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value result = b.create<arith::MulFOp>(loc, args[0], args[1]);
b.create<linalg::YieldOp>(loc, result);
});
ops.push_back(genericOp);
@@ -254,26 +273,24 @@
qkTranspose, maxSlice, loc, builder, ops);
Value partialSoftmax =
computePartialSoftmax(qkTranspose, newMax, loc, builder, ops);
- Value scaledOldSum =
- updateAndScale(maxSlice, newMax, sumSlice, loc, builder, ops);
+ Value scaleFactor = computeScaleFactor(maxSlice, newMax, loc, builder, ops);
+ Value scaledOldSum = updateAndScale(scaleFactor, sumSlice, loc, builder, ops);
Value newSum = computeRowwiseReduction<arith::AddFOp>(
partialSoftmax, scaledOldSum, loc, builder, ops);
- Value inverseNewSum = computeReciprocal(newSum, loc, builder, ops);
- Value softmax =
- scalePartialSoftmax(partialSoftmax, inverseNewSum, loc, builder, ops);
if (elementType.isF16()) {
Value empty =
builder.create<tensor::EmptyOp>(loc, resultShape, builder.getF16Type());
- softmax = truncateToF16(softmax, empty, ops, builder, loc);
+ partialSoftmax = truncateToF16(partialSoftmax, empty, ops, builder, loc);
}
// Update accumulator
- Value scaledAcc = scaleAccumulator(outputSlice, scaledOldSum, inverseNewSum,
- loc, builder, ops);
+ Value scaledAcc =
+ scaleAccumulator(outputSlice, scaleFactor, loc, builder, ops);
// Compute matmul(softmax, v)
auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, scaledAcc.getType(), ValueRange{softmax, valueSlice}, scaledAcc);
+ loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
+ scaledAcc);
ops.push_back(matmulOp);
Value result = matmulOp.getResult(0);
return std::make_tuple(result, newMax, newSum);
@@ -417,6 +434,10 @@
OpBuilder::InsertionGuard yieldGuard(rewriter);
rewriter.setInsertionPointAfter(loopNest.loops.back());
+
+ loopNest.results[0] = applyFinalScaling(
+ loopNest.results[0], loopNest.results[2], loc, rewriter, ops);
+
if (elementType.isF16()) {
loopNest.results[0] =
truncateToF16(loopNest.results[0], outputSlice, ops, rewriter, loc);
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
index 3561d8f..322f8ab 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile_and_decompose_attention.mlir
@@ -7,6 +7,8 @@
return %1 : tensor<1x1024x64xf32>
}
+// TILING-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// TILING-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// TILING-LABEL: @attention
// TILING-SAME: (%[[QUERY:.+]]: tensor<1x1024x64xf32>, %[[KEY:.+]]: tensor<1x1024x64xf32>, %[[VALUE:.+]]: tensor<1x1024x64xf32>)
// TILING: %[[D0:.+]] = tensor.empty() : tensor<1x1024x64xf32>
@@ -34,8 +36,17 @@
// TILING-SAME: -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
// TILING: scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
// TILING: }
-// TILING: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1,
-// TILING-SAME: 1, 1] : tensor<1024x64xf32> into tensor<1x1024x64xf32>
+// TILING: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME: "parallel"]} ins(%[[D6]]#2 : tensor<1024xf32>) outs(%[[D6]]#0 : tensor<1024x64xf32>)
+// TILING-SAME: {
+// TILING: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILING: %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILING: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILING: linalg.yield %[[D9]] : f32
+// TILING: } -> tensor<1024x64xf32>
+// TILING: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING-SAME: tensor<1024x64xf32> into tensor<1x1024x64xf32>
// TILING: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
// TILING: }
@@ -64,69 +75,66 @@
// CHECK-SAME: tensor<1x1024x64xf32> to tensor<1024x64xf32>
// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
// CHECK-SAME: tensor<1x1024x64xf32> to tensor<1024x64xf32>
-// CHECK: %[[D7:.+]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK: %[[D8:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D7]] : tensor<1024x1024xf32>) ->
+// CHECK: %[[D8:.+]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK: %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x1024xf32>) ->
// CHECK-SAME: tensor<1024x1024xf32>
-// CHECK: %[[D9:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
-// CHECK-SAME: tensor<1024x64xf32>, tensor<1024x64xf32>) outs(%[[D8]] : tensor<1024x1024xf32>) ->
+// CHECK: %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_2]], %[[EXTRACTED_SLICE]] :
+// CHECK-SAME: tensor<1024x64xf32>, tensor<1024x64xf32>) outs(%[[D9]] : tensor<1024x1024xf32>) ->
// CHECK-SAME: tensor<1024x1024xf32>
-// CHECK: %[[D10:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D9]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME: "reduction"]} ins(%[[D10]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D18]] : f32
// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D10]] : tensor<1024xf32>) outs(%[[D9]] : tensor<1024x1024xf32>) {
+// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D19:.+]] = math.exp2 %[[D18]] : f32
// CHECK: linalg.yield %[[D19]] : f32
// CHECK: } -> tensor<1024x1024xf32>
-// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]], iterator_types =
-// CHECK-SAME: ["parallel"]} ins(%[[ARG5]], %[[D10]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[ARG6]] :
-// CHECK-SAME: tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_3:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.subf %[[IN]], %[[IN_3]] : f32
-// CHECK: %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK: %[[D20:.+]] = arith.mulf %[[D19]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D20]] : f32
+// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D11]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D18:.+]] = arith.subf %[[OUT]], %[[IN]] : f32
+// CHECK: %[[D19:.+]] = math.exp2 %[[D18]] : f32
+// CHECK: linalg.yield %[[D19]] : f32
// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D11]] : tensor<1024x1024xf32>) outs(%[[D12]] : tensor<1024xf32>) {
+// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D18]] : f32
+// CHECK: } -> tensor<1024xf32>
+// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME: "reduction"]} ins(%[[D12]] : tensor<1024x1024xf32>) outs(%[[D14]] : tensor<1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D18]] : f32
// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[D13]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D18]] = arith.divf %[[CST_3]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x1024xf32>) {
+// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D13]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[OUT]], %[[IN]] : f32
+// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<1024x1024xf32>
-// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel"]} ins(%[[D12]], %[[D14]] : tensor<1024xf32>, tensor<1024xf32>)
-// CHECK-SAME: outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_3:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[IN_3]] : f32
-// CHECK: %[[D19]] = arith.mulf %[[D18]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D19]] : f32
// CHECK: } -> tensor<1024x64xf32>
-// CHECK: %[[D17:.+]] = linalg.matmul ins(%[[D15]], %[[EXTRACTED_SLICE_1]] : tensor<1024x1024xf32>,
+// CHECK: %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_1]] : tensor<1024x1024xf32>,
// CHECK-SAME: tensor<1024x64xf32>) outs(%[[D16]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK: scf.yield %[[D17]], %[[D10]], %[[D13]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// CHECK: scf.yield %[[D17]], %[[D11]], %[[D15]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
// CHECK: }
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1,
-// CHECK-SAME: 1, 1] : tensor<1024x64xf32> into tensor<1x1024x64xf32>
+// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// CHECK-SAME: {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D8:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D9]] : f32
+// CHECK: } -> tensor<1024x64xf32>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK-SAME: tensor<1024x64xf32> into tensor<1x1024x64xf32>
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf32>
-// CHECK: }
// -----
@@ -167,8 +175,16 @@
// TILING-SAME: -> tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
// TILING: scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
// TILING: }
-// TILING: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, %[[DIM]],
-// TILING-SAME: %[[DIM_0]]] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// TILING: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
+// TILING: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG: %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
+// TILING: %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]] : f32
+// TILING: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// TILING: linalg.yield %[[D9]] : f32
+// TILING: } -> tensor<?x?xf32>
+// TILING: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
+// TILING-SAME: [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
// TILING: return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
// TILING: }
@@ -212,55 +228,52 @@
// CHECK: %[[D18:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D18]] : f32
// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel"]} ins(%[[D10]] : tensor<?xf32>) outs(%[[D9]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D19:.+]] = math.exp2 %[[D18]] : f32
// CHECK: linalg.yield %[[D19]] : f32
// CHECK: } -> tensor<?x?xf32>
-// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]], iterator_types =
-// CHECK-SAME: ["parallel"]} ins(%[[ARG8]], %[[D10]] : tensor<?xf32>, tensor<?xf32>) outs(%[[ARG9]] :
-// CHECK-SAME: tensor<?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_5:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.subf %[[IN]], %[[IN_5]] : f32
+// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D10]] : tensor<?xf32>) outs(%[[ARG8]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D18]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D19]] = math.exp2 %[[D18]] : f32
-// CHECK: %[[D20:.+]] = arith.mulf %[[D19]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D20]] : f32
+// CHECK: linalg.yield %[[D19]] : f32
// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D11]] : tensor<?x?xf32>) outs(%[[D12]] : tensor<?xf32>) {
+// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG9]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D18]] : f32
+// CHECK: } -> tensor<?xf32>
+// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME: "reduction"]} ins(%[[D12]] : tensor<?x?xf32>) outs(%[[D14]] : tensor<?xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D18]] = arith.addf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D18]] : f32
// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[D13]] : tensor<?xf32>) {
-// CHECK: ^bb0(%[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_5:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D18]] = arith.divf %[[CST_5]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D18]] : f32
-// CHECK: } -> tensor<?xf32>
-// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D14]] : tensor<?xf32>) outs(%[[D11]] : tensor<?x?xf32>) {
+// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D13]] : tensor<?xf32>) outs(%[[ARG7]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[OUT]], %[[IN]] : f32
+// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D18]] : f32
// CHECK: } -> tensor<?x?xf32>
-// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel"]} ins(%[[D12]], %[[D14]] : tensor<?xf32>, tensor<?xf32>) outs(%[[ARG7]] :
-// CHECK-SAME: tensor<?x?xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_5:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D18]] = arith.mulf %[[IN]], %[[IN_5]] : f32
-// CHECK: %[[D19]] = arith.mulf %[[D18]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D19]] : f32
-// CHECK: } -> tensor<?x?xf32>
-// CHECK: %[[D17:.+]] = linalg.matmul ins(%[[D15]], %[[EXTRACTED_SLICE_3]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK: %[[D17:.+]] = linalg.matmul ins(%[[D12]], %[[EXTRACTED_SLICE_3]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[D16]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: scf.yield %[[D17]], %[[D10]], %[[D13]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
+// CHECK: scf.yield %[[D17]], %[[D10]], %[[D15]] : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>
// CHECK: }
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D6]]#[[D0:.+]] into %[[D0]][0, 0, 0] [1, %[[DIM]],
-// CHECK-SAME: %[[DIM_0]]] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<?xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<?x?xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG: %[[CST_3:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D8:.+]] = arith.divf %[[CST_3]], %[[IN]] : f32
+// CHECK: %[[D9:.+]] = arith.mulf %[[D8]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D9]] : f32
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]]
+// CHECK-SAME: [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
// CHECK: }
@@ -302,14 +315,22 @@
// TILING-SAME: -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
// TILING: scf.yield %[[TILED_ATTENTION]]#0, %[[TILED_ATTENTION]]#1, %[[TILED_ATTENTION]]#2 : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
// TILING: }
-// TILING: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
-// TILING-SAME: "parallel"]} ins(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] :
-// TILING-SAME: tensor<1024x64xf16>) {
+// TILING: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// TILING-SAME: {
+// TILING: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// TILING-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// TILING: %[[D9:.+]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// TILING: %[[D10:.+]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// TILING: linalg.yield %[[D10]] : f32
+// TILING: } -> tensor<1024x64xf32>
+// TILING: %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// TILING-SAME: "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
// TILING: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// TILING: %[[D8:.+]] = arith.truncf %[[IN]] : f32 to f16
-// TILING: linalg.yield %[[D8]] : f16
+// TILING: %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// TILING: linalg.yield %[[D9]] : f16
// TILING: } -> tensor<1024x64xf16>
-// TILING: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// TILING: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
// TILING-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
// TILING: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
// TILING: }
@@ -341,80 +362,77 @@
// CHECK-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
// CHECK-SAME: tensor<1x1024x64xf16> to tensor<1024x64xf16>
-// CHECK: %[[D8:.+]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK: %[[D9:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D8]] : tensor<1024x1024xf32>) ->
+// CHECK: %[[D9:.+]] = tensor.empty() : tensor<1024x1024xf32>
+// CHECK: %[[D10:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D9]] : tensor<1024x1024xf32>) ->
// CHECK-SAME: tensor<1024x1024xf32>
-// CHECK: %[[D10:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
-// CHECK-SAME: tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D9]] : tensor<1024x1024xf32>) ->
+// CHECK: %[[D11:.+]] = linalg.matmul_transpose_b ins(%[[EXTRACTED_SLICE_3]], %[[EXTRACTED_SLICE_1]] :
+// CHECK-SAME: tensor<1024x64xf16>, tensor<1024x64xf16>) outs(%[[D10]] : tensor<1024x1024xf32>) ->
// CHECK-SAME: tensor<1024x1024xf32>
-// CHECK: %[[D11:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D10]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME: "reduction"]} ins(%[[D11]] : tensor<1024x1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D21:.+]] = arith.maximumf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D21]] : f32
// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D12:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D11]] : tensor<1024xf32>) outs(%[[D10]] : tensor<1024x1024xf32>) {
+// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D12]] : tensor<1024xf32>) outs(%[[D11]] : tensor<1024x1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D22:.+]] = math.exp2 %[[D21]] : f32
// CHECK: linalg.yield %[[D22]] : f32
// CHECK: } -> tensor<1024x1024xf32>
-// CHECK: %[[D13:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]], iterator_types =
-// CHECK-SAME: ["parallel"]} ins(%[[ARG5]], %[[D11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[ARG6]] :
-// CHECK-SAME: tensor<1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.subf %[[IN]], %[[IN_4]] : f32
+// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D12]] : tensor<1024xf32>) outs(%[[ARG5]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.subf %[[OUT]], %[[IN]] : f32
// CHECK: %[[D22]] = math.exp2 %[[D21]] : f32
-// CHECK: %[[D23:.+]] = arith.mulf %[[D22]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D23]] : f32
+// CHECK: linalg.yield %[[D22]] : f32
// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D14:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "reduction"]} ins(%[[D12]] : tensor<1024x1024xf32>) outs(%[[D13]] : tensor<1024xf32>) {
+// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP2]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG6]] : tensor<1024xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D21]] : f32
+// CHECK: } -> tensor<1024xf32>
+// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME: "reduction"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D15]] : tensor<1024xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D21]] = arith.addf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D21]] : f32
// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D15:.+]] = linalg.generic {indexing_maps = [#[[MAP2]]], iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[D14]] : tensor<1024xf32>) {
-// CHECK: ^bb0(%[[OUT:.+]]: f32):
-// CHECK-DAG: %[[CST_4:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D21]] = arith.divf %[[CST_4]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D21]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: %[[D16:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D15]] : tensor<1024xf32>) outs(%[[D12]] : tensor<1024x1024xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.mulf %[[OUT]], %[[IN]] : f32
-// CHECK: linalg.yield %[[D21]] : f32
-// CHECK: } -> tensor<1024x1024xf32>
// CHECK: %[[D17:.+]] = tensor.empty() : tensor<1024x1024xf16>
// CHECK: %[[D18:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D16]] : tensor<1024x1024xf32>) outs(%[[D17]] : tensor<1024x1024xf16>) {
+// CHECK-SAME: "parallel"]} ins(%[[D13]] : tensor<1024x1024xf32>) outs(%[[D17]] : tensor<1024x1024xf16>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
// CHECK: %[[D21]] = arith.truncf %[[IN]] : f32 to f16
// CHECK: linalg.yield %[[D21]] : f16
// CHECK: } -> tensor<1024x1024xf16>
-// CHECK: %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel"]} ins(%[[D13]], %[[D15]] : tensor<1024xf32>, tensor<1024xf32>)
-// CHECK-SAME: outs(%[[ARG4]] : tensor<1024x64xf32>) {
-// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32):
-// CHECK: %[[D21]] = arith.mulf %[[IN]], %[[IN_4]] : f32
-// CHECK: %[[D22]] = arith.mulf %[[D21]], %[[OUT]] : f32
-// CHECK: linalg.yield %[[D22]] : f32
+// CHECK: %[[D19:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D14]] : tensor<1024xf32>) outs(%[[ARG4]] : tensor<1024x64xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[D21]] = arith.mulf %[[IN]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D21]] : f32
// CHECK: } -> tensor<1024x64xf32>
// CHECK: %[[D20:.+]] = linalg.matmul ins(%[[D18]], %[[EXTRACTED_SLICE_2]] : tensor<1024x1024xf16>,
// CHECK-SAME: tensor<1024x64xf16>) outs(%[[D19]] : tensor<1024x64xf32>) -> tensor<1024x64xf32>
-// CHECK: scf.yield %[[D20]], %[[D11]], %[[D14]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
+// CHECK: scf.yield %[[D20]], %[[D12]], %[[D16]] : tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32>
// CHECK: }
-// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] :
-// CHECK-SAME: tensor<1024x64xf16>) {
+// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D6]]#[[D2:.+]] : tensor<1024xf32>) outs(%[[D6]]#[[D0:.+]] : tensor<1024x64xf32>)
+// CHECK-SAME: {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D9]] = arith.divf %[[CST_1]], %[[IN]] : f32
+// CHECK: %[[D10]] = arith.mulf %[[D9]], %[[OUT]] : f32
+// CHECK: linalg.yield %[[D10]] : f32
+// CHECK: } -> tensor<1024x64xf32>
+// CHECK: %[[D8:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel",
+// CHECK-SAME: "parallel"]} ins(%[[D7]] : tensor<1024x64xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<1024x64xf16>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f16):
-// CHECK: %[[D8]] = arith.truncf %[[IN]] : f32 to f16
-// CHECK: linalg.yield %[[D8]] : f16
+// CHECK: %[[D9]] = arith.truncf %[[IN]] : f32 to f16
+// CHECK: linalg.yield %[[D9]] : f16
// CHECK: } -> tensor<1024x64xf16>
-// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D7]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D8]] into %[[D0]][0, 0, 0] [1, 1024, 64] [1, 1, 1] :
// CHECK-SAME: tensor<1024x64xf16> into tensor<1x1024x64xf16>
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1024x64xf16>
-// CHECK: }
\ No newline at end of file
+// CHECK: }
diff --git a/tests/transform_dialect/cpu/attention_codegen_spec.mlir b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
index 3601d92..73ffab9 100644
--- a/tests/transform_dialect/cpu/attention_codegen_spec.mlir
+++ b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
@@ -20,11 +20,11 @@
// Tile and decompose attention
// ==========================================
%attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %acc_fill, %max_fill, %sum_fill, %inner_loop, %blocked_attention = transform.tile_attention %attention4 :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %scale_acc, %second_matmul
+ %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %blocked_attention = transform.tile_attention %attention4 :
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %scale_acc, %second_matmul
= transform.decompose_tiled_attention %blocked_attention :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
// Vectorize function
// ==========================================