[LinalgExt] Masked Attention Implementation (#18525)
Enables float/boolean masks as parameters and creates linalg.generic ops
to apply the masking. This image (https://imgur.com/a/1MePgcy) elaborates on
the main files changed and how they enable masked attention (a sketch of the
resulting op form follows the list):
- Blue boxes represent changed .cpp and .td files to
enable/pass/decompose the mask
- Yellow boxes represent the different op classes
- Red boxes represent test mlir files pertaining to certain .cpp/.td
implementations or ops
For quick reference, AggregatedOpInterfaceImpl.cpp contains the bulk of
the actual mask decomposition (QK += mask); a concrete sketch follows.
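Concretely, for a float mask the decomposition pre-scales the mask by
log2(e) (to compensate for the base-2 online softmax) and adds it to the QK
accumulator; a boolean mask is instead mapped to 0/-infinity before the add.
A rough sketch of the linalg.generic that applyMask emits for the float case
(element types and shapes are hypothetical):

  %masked = linalg.generic
      {indexing_maps = [affine_map<(b, m, k2) -> (b, m, k2)>,
                        affine_map<(b, m, k2) -> (b, m, k2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%mask : tensor<192x1024x1024xf16>)
      outs(%qk : tensor<192x1024x1024xf32>) {
    ^bb0(%m: f16, %s: f32):
      // Convert the mask element to the accumulator type.
      %mf = arith.extf %m : f16 to f32
      // Scale by log2(e) so the exp2-based online softmax matches exp.
      %log2e = arith.constant 1.4426950408889634 : f32
      %ms = arith.mulf %mf, %log2e : f32
      // Additive masking: S += M.
      %add = arith.addf %s, %ms : f32
      linalg.yield %add : f32
  } -> tensor<192x1024x1024xf32>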
For clarification, TileAttention.cpp only holds the
convertToOnlineAttentionOp and getTileAttentionIndexingMaps functions;
TilingInterfaceImpl.cpp contains the main tiling capabilities in the
form of AttentionOp::getTiledImplementation and
OnlineAttentionOp::getTiledImplementation.
This is an updated version of https://github.com/iree-org/iree/pull/18461,
created to include the scale affine map and enable fused attention
(incorporating https://github.com/IanWood1/iree/tree/raikonen/sdpa_mask).
- To that end, many test modifications simply add the scale affine map
(with little functional change).
- For tiling and decomposition, most functional tests are included in
"tiling.mlir" and "decompose_online_attention.mlir". By contrast,
"tile_attention.mlir" and "decompose_attention.mlir" are old paths
intended to be retired and deprecated soon, so no major tests were
added there.
Test directory for numerical verification:
https://github.com/rohan-tan-bhowmik/iree-masked-attention-test
---------
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index d775c1c..b27193d 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -73,21 +73,22 @@
};
} // namespace
-static SmallVector<AffineMap>
-getStandardAttentionIndexingMaps(MLIRContext *ctx) {
+static SmallVector<AffineMap> getStandardAttentionIndexingMaps(MLIRContext *ctx,
+ bool hasMask) {
AffineExpr m, n, k1, k2;
bindDims(ctx, m, n, k1, k2);
- AffineMap qMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
- AffineMap kMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
- AffineMap vMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
- AffineMap rMap =
- AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
-
- return {qMap, kMap, vMap, rMap};
+ auto qMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
+ auto kMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
+ auto vMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
+ auto sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, ctx);
+ auto rMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
+ if (hasMask) {
+ // Add mask map only if it exists
+ auto mMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k2}, ctx);
+ return {qMap, kMap, vMap, sMap, mMap, rMap};
+ }
+ return {qMap, kMap, vMap, sMap, rMap};
}
struct AttentionOpConversion
@@ -100,6 +101,7 @@
Value query = op.getQuery();
Value key = op.getKey();
Value value = op.getValue();
+ std::optional<Value> optionalMask = op.getAttnMask();
ShapedType outputType = op.getOutputType();
@@ -147,10 +149,14 @@
loc, targetType, rewriter.getFloatAttr(targetType, dk));
// Add batches to standard attention indexing maps.
- SmallVector<AffineMap> indexingMaps = getStandardAttentionIndexingMaps(ctx);
+ SmallVector<AffineMap> indexingMaps =
+ getStandardAttentionIndexingMaps(ctx, optionalMask.has_value());
+
int64_t numBatches = op.getQueryType().getRank() - 2;
for (AffineMap &map : indexingMaps) {
map = map.shiftDims(numBatches);
+ if (map.getNumResults() == 0)
+ continue;
for (int batch : llvm::seq<int>(numBatches)) {
map = map.insertResult(rewriter.getAffineDimExpr(batch), batch);
}
@@ -158,7 +164,7 @@
auto attention = rewriter.create<IREE::LinalgExt::AttentionOp>(
loc, result.getType(), query, key, value, scale, result,
- rewriter.getAffineMapArrayAttr(indexingMaps));
+ rewriter.getAffineMapArrayAttr(indexingMaps), optionalMask);
rewriter.replaceOp(op, attention.getResult(0));
return success();
diff --git a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
index 06e85e7..0de566e 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
@@ -8,14 +8,15 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
-// CHECK: %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
+// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x3x4xf32>
-// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x3x4xf32>
// -----
@@ -27,14 +28,15 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
-// CHECK: %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
+// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x8x4xf32>
-// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x8x4xf32>
// -----
@@ -46,14 +48,15 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>,
-// CHECK: %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
+// CHECK-SAME: %[[ARG3:.*]]: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
-// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<1x3x4xf32>
// -----
@@ -65,6 +68,7 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-LABEL: func.func @attention_dyn(
@@ -76,5 +80,5 @@
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4xf32>
-// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
// CHECK: return %[[ATTN]] : tensor<?x?x4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
index 7598316..4010ae8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
@@ -257,6 +257,7 @@
#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
@@ -272,7 +273,7 @@
%scale = arith.constant 1.0 : f16
%out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR],
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR],
lowering_config = #config }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
index e67e7af..6533f23 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir
@@ -1747,6 +1747,7 @@
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index 5ebd5c3..227ae1e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -369,6 +369,7 @@
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
@@ -407,6 +408,7 @@
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d1, d2, d3, d4) -> (d1, d2)>,
affine_map<(d1, d2, d3, d4) -> (d3, d2)>,
affine_map<(d1, d2, d3, d4) -> (d3, d4)>,
+ affine_map<(d1, d2, d3, d4) -> ()>,
affine_map<(d1, d2, d3, d4) -> (d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<1024x512xf16>, tensor<128x512xf16>, tensor<128x512xf16>, f16) outs(%7 : tensor<1024x512xf16>) -> tensor<1024x512xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : tensor<1024x512xf16> -> !flow.dispatch.tensor<writeonly:tensor<1024x512xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index eb8f4f1..6ebf8ab 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -574,6 +574,7 @@
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
lowering_config = #config}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
@@ -637,7 +638,7 @@
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
- %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
+ %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
@@ -692,7 +693,7 @@
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
- %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
+ %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index 649cbd9..3f2c090 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -22,6 +22,7 @@
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
index 13a84ec..109b107 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir
@@ -22,6 +22,7 @@
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index 596a925..9607c9e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -149,11 +149,7 @@
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == $_op);
- if(opOperand->getOperandNumber() >= $_op.getNumDpsInputs()){
- return $_op.getIndexingMapsForResults()[opOperand->getOperandNumber() - $_op.getNumDpsInputs()];
- }else {
- return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()];
- }
+ return getIndexingMapsArray()[opOperand->getOperandNumber()];
}]
>,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 1805597..e08838b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1200,6 +1200,15 @@
// AttentionOp
//===----------------------------------------------------------------------===//
+void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ TypeRange results, Value query, Value key, Value value,
+ Value scale, ValueRange outputs, ArrayAttr indexingMaps,
+ std::optional<Value> mask) {
+ Value maskIn = mask.value_or(Value());
+ build(odsBuilder, odsState, results, query, key, value, scale, maskIn,
+ outputs, indexingMaps);
+}
+
LogicalResult AttentionOp::verify() {
AttentionOp attnOp = *this;
@@ -1212,6 +1221,9 @@
// Check if indexing maps can represent attention.
SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
+ if (indexingMaps.size() != getOperation()->getNumOperands()) {
+ return attnOp->emitOpError("expected an indexing map for each operand");
+ }
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(indexingMaps);
if (failed(maybeOpInfo)) {
@@ -1246,8 +1258,8 @@
}
if (shape[pos] != valShape[i]) {
return attnOp->emitError("Shape Mismatch for ")
- << operandName << ". Expected: " << shape[pos]
- << " Got: " << valShape[i];
+ << operandName << " at position " << i
+ << ". Expected: " << shape[pos] << " Got: " << valShape[i];
}
}
return success();
@@ -1261,6 +1273,38 @@
return failure();
}
+  // Additionally check the mask shape if a mask is present.
+ if (auto maskMap = getMaskMap()) {
+ if (failed(checkShape("Mask", getMaskType()->getShape(), *maskMap)))
+ return failure();
+ }
+
+ int expectedSymbols = getQueryMap().getNumInputs();
+ auto checkDomain =
+ [&attnOp, &expectedSymbols](StringRef operandName,
+ AffineMap indexingMap) -> LogicalResult {
+ if (expectedSymbols != indexingMap.getNumInputs()) {
+ return attnOp->emitError("Mismatched map domain for ")
+ << operandName << ". Expected: " << expectedSymbols
+ << " Got: " << indexingMap.getNumInputs();
+ }
+ return success();
+ };
+
+ if (failed(checkDomain("Query", getQueryMap())) ||
+ failed(checkDomain("Key", getKeyMap())) ||
+ failed(checkDomain("Value", getValueMap())) ||
+ failed(checkDomain("Scale", getScaleMap())) ||
+ failed(checkDomain("Output", getOutputMap()))) {
+ return failure();
+ }
+
+  // Additionally check the mask map domain if a mask is present.
+ if (auto maskMap = getMaskMap()) {
+ if (failed(checkDomain("Mask", *maskMap)))
+ return failure();
+ }
+
if (isTiled) {
// Tiled/Flash attention.
Type maxElementType = getMaxType()->getElementType();
@@ -1324,20 +1368,29 @@
SmallVector<AffineMap> AttentionOp::getIndexingMapsForOperands() {
auto maps = getIndexingMapsArray();
- return SmallVector<AffineMap>(maps.begin(),
- maps.begin() + getNumDpsInputs() - 1);
+ maps.resize(getNumDpsInputs());
+ return maps;
}
SmallVector<AffineMap> AttentionOp::getIndexingMapsForResults() {
auto maps = getIndexingMapsArray();
- return SmallVector<AffineMap>(maps.begin() + getNumDpsInputs() - 1,
- maps.end());
+ return SmallVector<AffineMap>(maps.begin() + getNumDpsInputs(), maps.end());
}
//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//
+void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ TypeRange results, Value query, Value key,
+ Value value, Value scale, Value output, Value max,
+ Value sum, ArrayAttr indexingMaps,
+ std::optional<Value> mask) {
+ Value maskIn = mask.value_or(Value());
+ build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output,
+ max, sum, indexingMaps);
+}
+
LogicalResult OnlineAttentionOp::verify() {
OnlineAttentionOp attnOp = *this;
@@ -1389,11 +1442,46 @@
return failure();
}
+  // Additionally check the mask shape if a mask is present.
+ if (auto maskMap = getMaskMap()) {
+ if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap)))
+ return failure();
+ }
+
+ int expectedSymbols = getQueryMap().getNumInputs();
+ auto checkDomain =
+ [&attnOp, &expectedSymbols](StringRef operandName,
+ AffineMap indexingMap) -> LogicalResult {
+ if (expectedSymbols != indexingMap.getNumInputs()) {
+ return attnOp->emitError("Mismatched map domain for ")
+ << operandName << ". Expected: " << expectedSymbols
+ << " Got: " << indexingMap.getNumInputs();
+ }
+ return success();
+ };
+
+ if (failed(checkDomain("Query", getQueryMap())) ||
+ failed(checkDomain("Key", getKeyMap())) ||
+ failed(checkDomain("Value", getValueMap())) ||
+ failed(checkDomain("Scale", getScaleMap())) ||
+ failed(checkDomain("Output", getOutputMap())) ||
+ failed(checkDomain("Max", getMaxMap())) ||
+ failed(checkDomain("Sum", getSumMap()))) {
+ return failure();
+ }
+
+  // Additionally check the mask map domain if a mask is present.
+ if (auto maskMap = getMaskMap()) {
+ if (failed(checkDomain("Mask", *maskMap)))
+ return failure();
+ }
+
return success();
}
MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
- return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
+ return MutableOperandRange(*this, /*numInputs=*/getMask() ? 5 : 4,
+ /*numInits=*/3);
}
LogicalResult OnlineAttentionOp::reifyResultShapes(
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index c292452..c642fd0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -461,7 +461,7 @@
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation",
- "generateResultTileValue"]>]> {
+ "generateResultTileValue"]>, AttrSizedOperandSegments]> {
let summary = "Attention operator";
let description = [{
Computes the scaled dot product attention function:
@@ -471,6 +471,10 @@
Here Q, K, V are given tensors and scale is a scalar value specifying
the scale to use.
+ If an additional mask argument M is included, the result of the first matmul is modified according to:
+
+ Q @ K.T += M
+
For self-attention, all inputs and the result have the same shape BxNxd
where B is the batch dimension, N is the sequence length and d is head
dimension. Typically N >>> d. Usually, this operator also performs
@@ -495,6 +499,7 @@
AnyShaped:$key,
AnyShaped:$value,
AnyFloat:$scale,
+ Optional<AnyShaped>:$mask,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps
);
@@ -505,11 +510,22 @@
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict
- `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
+ `ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )? `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($results)^)?
}];
+ let builders = [
+ OpBuilder<(ins "TypeRange":$results,
+ "Value":$query,
+ "Value":$key,
+ "Value":$value,
+ "Value":$scale,
+ "ValueRange":$outputs,
+ "ArrayAttr":$indexing_maps,
+ CArg<"std::optional<Value>", "std::nullopt">:$mask)>
+ ];
+
let extraClassDeclaration = [{
// Method to implement for specifying output range for
// DestinationStyleOpInterface
@@ -530,9 +546,18 @@
AffineMap getValueMap() {
return cast<AffineMap>(getIndexingMapsArray()[2]);
}
- AffineMap getOutputMap() {
+ AffineMap getScaleMap() {
return cast<AffineMap>(getIndexingMapsArray()[3]);
}
+ std::optional<AffineMap> getMaskMap() {
+ if (getMask()) {
+ return cast<AffineMap>(getIndexingMapsArray()[4]);
+ }
+ return std::nullopt;
+ }
+ AffineMap getOutputMap() {
+ return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]);
+ }
int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
@@ -545,6 +570,11 @@
ShapedType getValueType() {
return cast<ShapedType>(getValue().getType());
}
+ std::optional<ShapedType> getMaskType() {
+ std::optional<Value> mask = getMask();
+ if (!mask) return std::nullopt;
+ return cast<ShapedType>(mask->getType());
+ }
FloatType getScaleType() {
return cast<FloatType>(cast<ShapedType>(getScale().getType()));
}
@@ -559,12 +589,12 @@
std::optional<AffineMap> getMaxMap() {
if (getNumResults() < 2)
return std::nullopt;
- return cast<AffineMap>(getIndexingMapsArray()[4]);
+ return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 1]);
}
std::optional<AffineMap> getSumMap() {
if (getNumResults() < 3)
return std::nullopt;
- return cast<AffineMap>(getIndexingMapsArray()[5]);
+ return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 2]);
}
Value getOutput() {
return getDpsInitOperand(0)->get();
@@ -599,6 +629,12 @@
int64_t getValueRank() {
return getValueType().getRank();
}
+ std::optional<int64_t> getMaskRank() {
+ std::optional<ShapedType> maskType = getMaskType();
+ if (!maskType)
+ return std::nullopt;
+ return maskType->getRank();
+ }
int64_t getOutputRank() {
return getOutputType().getRank();
}
@@ -650,8 +686,12 @@
online_attention(Q, K, V, scale, running_max, running_sum)
= online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V
+ If an additional mask argument M is included, the result of the first matmul is modified according to:
+
+ Q @ K.T += M
+
The advantage of this online_normalizer is that it can be tiled along
- it's reduction dimension, making the online_attention operator:
+ its reduction dimension, making the online_attention operator:
- Tilable along softmax reduction dimension
- Associative along softmax reduction dimension
- Commutative along softmax associative dimension
@@ -666,6 +706,7 @@
AnyShaped:$key,
AnyShaped:$value,
AnyFloat:$scale,
+ Optional<AnyShaped>:$mask,
AnyShaped:$output,
AnyShaped:$max,
AnyShaped:$sum,
@@ -677,11 +718,25 @@
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict
- `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
+ `ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )?`)`
`outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
(`->` type($results)^)?
}];
+ let builders = [
+ OpBuilder<(ins "TypeRange":$results,
+ "Value":$query,
+ "Value":$key,
+ "Value":$value,
+ "Value":$scale,
+ "Value":$output,
+ "Value":$max,
+ "Value":$sum,
+ "ArrayAttr":$indexing_maps,
+ CArg<"std::optional<Value>", "std::nullopt">:$mask)>
+ ];
+
+
let extraClassDeclaration = [{
// Method to implement for specifying output range for
// DestinationStyleOpInterface
@@ -698,14 +753,23 @@
AffineMap getValueMap() {
return getIndexingMapsArray()[2];
}
- AffineMap getOutputMap() {
+ AffineMap getScaleMap() {
return getIndexingMapsArray()[3];
}
+ std::optional<AffineMap> getMaskMap() {
+ if (getMask()) {
+ return cast<AffineMap>(getIndexingMapsArray()[4]);
+ }
+ return std::nullopt;
+ }
+ AffineMap getOutputMap() {
+ return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]);
+ }
AffineMap getMaxMap() {
- return getIndexingMapsArray()[4];
+ return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 1]);
}
AffineMap getSumMap() {
- return getIndexingMapsArray()[5];
+ return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 2]);
}
int64_t getIterationDomainRank() {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 2f8d8ef..b3d5152 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -712,6 +712,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
ins(%query, %key, %value, %scale : tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, f32) outs(%0 : tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32>
return %1 : tensor<6x12x20x8xf32>
@@ -728,6 +729,9 @@
%1:3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> ()>,
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
affine_map<(d0, d1, d2, d3) -> (d0)>]}
ins(%query, %key, %value, %scale : tensor<20xf32>, tensor<20x8xf32>, tensor<20x8xf32>, f32) outs(%result, %max, %sum : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>
return %1#0, %1#1, %1#2 : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>
@@ -738,10 +742,11 @@
func.func @illegal_attention_inputs(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: f32) -> tensor<192x1024x64xf32> {
%0 = tensor.empty() : tensor<192x1024x64xf32>
%scale = arith.constant 1.0 : f32
- // expected-error @+5 {{custom op 'iree_linalg_ext.attention' invalid kind of type specified}}
+ // expected-error @+6 {{custom op 'iree_linalg_ext.attention' invalid kind of type specified}}
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
return %1 : tensor<192x1024x64xf32>
@@ -749,12 +754,28 @@
// -----
-func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
+func.func @attention_missing_affine_map(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
%0 = tensor.empty() : tensor<192x1024x64xf32>
%scale = arith.constant 1.0 : f32
- // expected-error @below {{'iree_linalg_ext.attention' op failed to verify op's indexing maps}}
+ // expected-error @below {{'iree_linalg_ext.attention' op expected an indexing map for each operand}}
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+ ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+ return %1 : tensor<192x1024x64xf32>
+}
+
+// -----
+
+func.func @attention_affine_map_domain_mismatch(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
+ %0 = tensor.empty() : tensor<192x1024x64xf32>
+ %scale = arith.constant 1.0 : f32
+ // expected-error @below {{Mismatched map domain for Scale. Expected: 5 Got: 4}}
+ %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
return %1 : tensor<192x1024x64xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 77ba7aa..0f117cc 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -1087,6 +1087,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
return %1 : tensor<192x1024x64xf32>
@@ -1095,6 +1096,7 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
// CHECK: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1103,7 +1105,7 @@
// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] :
// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
@@ -1118,6 +1120,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
return %1 : tensor<192x1024x64xf32>
@@ -1125,6 +1128,7 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
// CHECK: func.func @cross_attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1133,7 +1137,7 @@
// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] :
// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
@@ -1150,6 +1154,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
return %1 : tensor<192x1024x64xf32>
@@ -1157,6 +1162,7 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
// CHECK: func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1165,7 +1171,7 @@
// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] :
// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
@@ -1179,6 +1185,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
@@ -1186,6 +1193,7 @@
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
// CHECK: func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1193,7 +1201,7 @@
// CHECK-SAME: {
// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]}
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
// CHECK-SAME: tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%[[ARG3]] :
// CHECK-SAME: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index e310589..2f2e046 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -100,10 +100,10 @@
// used by attention's exp2 who's value is always > 0.
Value mx = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mxDbl));
- Value clamped = b.create<arith::MinimumFOp>(loc, mx, args[0]);
+ Value clamp = b.create<arith::MinimumFOp>(loc, mx, args[0]);
// Convert scale to the same datatype as input.
- Value trunc = convertScalarToDtype(b, loc, clamped, dstTy,
+ Value trunc = convertScalarToDtype(b, loc, clamp, dstTy,
/*isUnsignedCast=*/false);
b.create<linalg::YieldOp>(loc, trunc);
});
@@ -175,6 +175,53 @@
return genericOp.getResult(0);
}
+static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
+ AffineMap maskMap, Value qk, Value mask) {
+
+ SmallVector<AffineMap> compressedMaps =
+ compressUnusedDims(SmallVector<AffineMap>{qkMap, maskMap});
+ qkMap = compressedMaps[0];
+ maskMap = compressedMaps[1];
+
+ SmallVector<utils::IteratorType> iteratorTypes(qkMap.getNumDims(),
+ utils::IteratorType::parallel);
+
+ Value zero = builder.create<arith::ConstantOp>(
+ loc, builder.getFloatAttr(getElementTypeOrSelf(qk.getType()), 0.0));
+ Value negInf = builder.create<arith::ConstantOp>(
+ loc, builder.getFloatAttr(getElementTypeOrSelf(qk.getType()),
+ -std::numeric_limits<double>::infinity()));
+ auto genericOp = builder.create<linalg::GenericOp>(
+ loc, qk.getType(), SmallVector<Value>{mask}, qk,
+ SmallVector<AffineMap>{maskMap, qkMap}, iteratorTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value qkVal = args[1];
+ Value maskVal = args[0];
+
+ // TODO: Replace bool mask condition once treated as i1 (instead of i8)
+ if (maskVal.getType().isInteger()) {
+ maskVal =
+ b.create<arith::TruncIOp>(loc, builder.getI1Type(), maskVal);
+ maskVal = b.create<arith::SelectOp>(loc, maskVal, zero, negInf);
+ } else {
+ maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(),
+ /*isUnsignedCast=*/false);
+ // Scaling to compensate for base-2 softmax
+ Value log2e = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(qkVal.getType(), M_LOG2E));
+ maskVal = b.create<arith::MulFOp>(loc, maskVal, log2e);
+ }
+      // Finally, set the returned value to the qk element plus the mask
+      // element (or 0/-infinity for a bool mask). We opt for an AddFOp
+      // (instead of a SelectOp) to stay consistent with the additive
+      // definition of attention masking.
+ Value add = b.create<arith::AddFOp>(loc, qkVal, maskVal);
+ b.create<linalg::YieldOp>(loc, add);
+ });
+
+ return genericOp.getResult(0);
+}
+
// Compute output = exp2(output - input)
static Value computeSubAndExp2(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap outputMap,
@@ -240,6 +287,7 @@
Value query = getQuery();
Value key = getKey();
Value value = getValue();
+ std::optional<Value> mask = getMask();
Value oldAcc = getOutput();
Value oldMax = getMax();
Value oldSum = getSum();
@@ -265,6 +313,9 @@
auto qETy = getElementTypeOrSelf(query.getType());
auto vETy = getElementTypeOrSelf(value.getType());
+ AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(),
+ /*symbolCount=*/0, getContext());
+
// In the original algorithm, the scaling is done after the softmax:
// softmax(Q @ K.T * scale) @ V
//
@@ -275,8 +326,6 @@
// significantly affect numerics.
if (qETy.getIntOrFloatBitWidth() > 8) {
AffineMap qMap = getQueryMap();
- AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
- /*symbolCount=*/0, getContext());
query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
query, scale);
}
@@ -325,6 +374,11 @@
offset);
}
+ // S += mask
+ if (mask != nullptr) {
+ s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value());
+ }
+
// TODO: This decomposition should be in a seperate op called
// "online softmax".
// ---- Online Softmax ----
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index ef5c21c..a01e08a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -8,6 +8,7 @@
// The content of this file is adapted from linalg's ElemenwiseOpFusion.cpp and
// modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`.
+#include <optional>
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -106,7 +107,7 @@
SmallVector<AffineExpr> newExprs;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned pos = cast<AffineDimExpr>(expr).getPosition();
- SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
+ auto expandedExprs = llvm::to_vector_of<AffineExpr, 6>(
llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
return builder.getAffineDimExpr(static_cast<unsigned>(v));
}));
@@ -187,8 +188,7 @@
: collapsingReshapeOp.getReassociationMaps(),
expandedType.getShape(), collapsedType.getShape(), rewriter)))
return std::nullopt;
-
- SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
+ auto expandedOpIndexingMaps = llvm::to_vector_of<AffineMap, 6>(
llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) {
return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
}));
@@ -254,12 +254,18 @@
}
}
+ Value maskOperand;
+ if (expandedOpOperands.size() > 4) {
+ maskOperand = expandedOpOperands[4];
+ }
+
// Create a new `AttentionOp` that has the computed operands/indexing maps.
TypeRange resultTypes = ValueRange(outputs).getTypes();
auto fusedOp = rewriter.create<AttentionOp>(
attentionOp.getLoc(), resultTypes, expandedOpOperands[0],
expandedOpOperands[1], expandedOpOperands[2], expandedOpOperands[3],
- outputs, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps));
+ outputs, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps),
+ maskOperand);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index d68a2eb..8089cd0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -161,6 +161,7 @@
AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
AffineMap vMap =
AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
+ AffineMap sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {}, ctx);
AffineMap rMap =
AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
AffineMap maxMap =
@@ -174,7 +175,7 @@
vMap = AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), vDims, ctx);
}
- SmallVector<AffineMap> attentionMaps = {qMap, kMap, vMap,
+ SmallVector<AffineMap> attentionMaps = {qMap, kMap, vMap, sMap,
rMap, maxMap, sumMap};
// Add batches to standard attention indexing maps.
int64_t numBatches = tiledInputRank - 2;
@@ -417,10 +418,14 @@
SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
indexingMaps.push_back(maxMap);
indexingMaps.push_back(sumMap);
+
+ Value mask = attnOp.getMask() ? attnOp.getMask() : Value();
+
OnlineAttentionOp onlineAttn = rewriter.create<OnlineAttentionOp>(
loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
- accFill, maxFill, sumFill, rewriter.getAffineMapArrayAttr(indexingMaps));
+ mask, accFill, maxFill, sumFill,
+ rewriter.getAffineMapArrayAttr(indexingMaps));
onlineAttn->setDiscardableAttrs(attnOp->getDiscardableAttrDictionary());
ops.push_back(onlineAttn);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index afa3578..016f9d0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -1887,6 +1887,16 @@
// Scale
tiledOperands.emplace_back(scale);
+ // Mask
+ Value attnMask = getMask();
+ if (attnMask) {
+ SmallVector<Range> maskSlice =
+ getPermutedSlice(*getMaskMap(), offsets, sizes);
+ Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
+ tiledOperands.emplace_back(maskSliceOp->getResult(0));
+ slices.push_back(maskSliceOp);
+ }
+
// Output
{
Operation *outputSliceOp = getSlice(builder, loc, getOutput(), outputSlice);
@@ -1909,7 +1919,7 @@
slices.push_back(maxSliceOp);
}
- std::optional<Value> sum = getMax();
+ std::optional<Value> sum = getSum();
if (sum) {
SmallVector<Range> sumSlice =
getPermutedSlice(*getSumMap(), offsets, sizes);
@@ -1923,12 +1933,13 @@
SmallVector<Type> resultTypes;
if (hasPureTensorSemantics()) {
- resultTypes.push_back(tiledOperands[4].getType());
+ int64_t baseIdx = attnMask ? 5 : 4;
+ resultTypes.push_back(tiledOperands[baseIdx].getType());
if (max) {
- resultTypes.push_back(tiledOperands[5].getType());
+ resultTypes.push_back(tiledOperands[baseIdx + 1].getType());
}
if (sum) {
- resultTypes.push_back(tiledOperands[6].getType());
+ resultTypes.push_back(tiledOperands[baseIdx + 2].getType());
}
}
@@ -2024,6 +2035,11 @@
SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes);
SmallVector<Range> valueSlice =
getPermutedSlice(getValueMap(), offsets, sizes);
+ std::optional<SmallVector<Range>> maskSlice;
+ if (auto maskMap = getMaskMap()) {
+ maskSlice = getPermutedSlice(*maskMap, offsets, sizes);
+ }
+
SmallVector<Range> outputSlice =
getPermutedSlice(getOutputMap(), offsets, sizes);
SmallVector<Range> maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes);
@@ -2065,6 +2081,16 @@
tiledOperands.emplace_back(scale);
+ // Mask
+ Value attnMask = getMask();
+ if (attnMask) {
+ SmallVector<Range> maskSlice =
+ getPermutedSlice(*getMaskMap(), offsets, sizes);
+ Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
+ tiledOperands.emplace_back(maskSliceOp->getResult(0));
+ slices.push_back(maskSliceOp);
+ }
+
/// Output
{
Operation *outputSliceOp = getSlice(builder, loc, getOutput(), outputSlice);
@@ -2096,9 +2122,9 @@
}
SmallVector<Type> resultTypes;
- resultTypes.push_back(tiledOperands[4].getType());
- resultTypes.push_back(tiledOperands[5].getType());
- resultTypes.push_back(tiledOperands[6].getType());
+ resultTypes.push_back(tiledOperands[tiledOperands.size() - 3].getType());
+ resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType());
+ resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
index 7eb4c0a..202ab11 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir
@@ -3,14 +3,15 @@
#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
-#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
func.func @attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16>, %v: tensor<2x10x4096x128xf16>)
-> tensor<2x10x4096x128xf16> {
%scale = arith.constant 0.125 : f16
%acc = tensor.empty() : tensor<2x10x4096x128xf16>
%out = iree_linalg_ext.attention
- {indexing_maps = [#map, #map1, #map2, #map3]}
+ {indexing_maps = [#map, #map1, #map2, #map3, #map4]}
ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16)
outs(%acc : tensor<2x10x4096x128xf16>) -> tensor<2x10x4096x128xf16>
func.return %out : tensor<2x10x4096x128xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
index 0344eb7..19d6a6b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
@@ -6,6 +6,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
return %1 : tensor<1x1024x64xf32>
@@ -108,6 +109,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
@@ -213,6 +215,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
return %1 : tensor<1x1024x64xf16>
@@ -333,6 +336,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
return %1 : tensor<1x1024x64xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index df46f9f..e0aa548 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -3,6 +3,7 @@
#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
@@ -16,7 +17,7 @@
%scale = arith.constant 1.0 : f16
%out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
@@ -82,6 +83,7 @@
#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
@@ -95,7 +97,7 @@
%scale = arith.constant 1.0 : f32
%out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
@@ -165,3 +167,67 @@
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
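+// #mapM gives the mask the (batch, m, k2) shape of the score matrix Q @ K^T.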
+
+func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+ %key: tensor<192x1024x64xf8E4M3FNUZ>,
+ %value: tensor<192x1024x64xf8E4M3FNUZ>,
+ %mask: tensor<192x1024x1024xf8E4M3FNUZ>,
+ %output: tensor<192x1024x64xf32>,
+ %max: tensor<192x1024xf32>,
+ %sum: tensor<192x1024xf32>)
+ -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
+ %scale = arith.constant 1.0 : f16
+
+ %out:3 = iree_linalg_ext.online_attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+ ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16, tensor<192x1024x1024xf8E4M3FNUZ>)
+ outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+ -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+ return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
+}
+// CHECK-LABEL: @attention_f8_masked
+// S = Q @ K^T
+// CHECK: linalg.generic
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// S = S * scale
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// S = S + mask
+// CHECK: arith.addf
+// newMax = max(oldMax, rowMax(S))
+// CHECK: linalg.generic
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+// P = exp2(S - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// norm = exp2(oldMax - newMax)
+// CHECK: linalg.generic
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// normSum = norm * oldSum
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// newSum = normSum + rowSum(P)
+// CHECK: linalg.generic
+// CHECK: arith.addf
+// CHECK: linalg.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
index be9c9da..51d3c51 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir
@@ -8,6 +8,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32>
return %1 : tensor<1x1024x64xf32>
@@ -61,6 +62,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
@@ -117,6 +119,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
return %1 : tensor<1x1024x64xf16>
@@ -151,6 +154,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16>
return %1 : tensor<1x1024x64xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 147d592..9bfa8b4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -628,6 +628,7 @@
transform.yield
}
}
+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
// CHECK: func.func @topk_tile_tensor
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
@@ -1537,6 +1538,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
return %1 : tensor<192x1024x64xf32>
@@ -1553,6 +1555,7 @@
// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
// CHECK: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1580,7 +1583,7 @@
// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], %[[ARG5]], 0] [%[[D2]],
// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
// CHECK: %[[D5:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_O]]]}
+// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_O]]]}
// CHECK-SAME: ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]],
// CHECK-SAME: %[[EXTRACTED_SLICE_1]], %[[C1_F32]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32)
// CHECK-SAME: outs(%[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32>
@@ -1595,11 +1598,152 @@
// -----
+func.func @attention_float_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>, %mask: tensor<192x1024x1024xf32>) -> tensor<192x1024x64xf32> {
+ %0 = tensor.empty() : tensor<192x1024x64xf32>
+ %scale = arith.constant 1.0 : f32
+ %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+ ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+ return %1 : tensor<192x1024x64xf32>
+}
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
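+ // Tile the batch (192) and m (1024) dims by 10 and 30; k1, k2, and n stay untiled.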
+ %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 30] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)>
+// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[MAP_M:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+// CHECK: func.func @attention_float_mask(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME: tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<192x1024x1024xf32>) -> tensor<192x1024x64xf32>
+// CHECK-SAME: {
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C192:.+]] = arith.constant 192 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK: %[[D1:.+]] = scf.for %[[ARG4:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C192]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ARG5:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<192x1024x64xf32>) {
+// CHECK: %[[D2:.+]] = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C30]]
+// CHECK-SAME: iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG5]]) -> (tensor<192x1024x64xf32>) {
+// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG4]])
+// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG3]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME: %[[D4]], 1024] [1, 1, 1] : tensor<192x1024x1024xf32> to tensor<?x?x1024xf32>
+// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG7]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK: %[[D5:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_M]], #[[MAP_O]]]}
+// CHECK-SAME: ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]],
+// CHECK-SAME: %[[EXTRACTED_SLICE_1]], %[[C1_F32]], %[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32, tensor<?x?x1024xf32>)
+// CHECK-SAME: outs(%[[EXTRACTED_SLICE_3]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG7]][%[[ARG4]], %[[ARG6]], 0]
+// CHECK-SAME: [%[[D3]], %[[D4]], 64] [1, 1, 1] : tensor<?x?x64xf32> into tensor<192x1024x64xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<192x1024x64xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D2]] : tensor<192x1024x64xf32>
+// CHECK: }
+// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
+// CHECK: }
+
+// -----
+
+func.func @attention_bool_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>, %mask: tensor<192x1024x1024xi1>) -> tensor<192x1024x64xf32> {
+ %0 = tensor.empty() : tensor<192x1024x64xf32>
+ %scale = arith.constant 1.0 : f32
+ %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
+ ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xi1>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+ return %1 : tensor<192x1024x64xf32>
+}
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 30] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)>
+// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[MAP_M:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+// CHECK: func.func @attention_bool_mask(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
+// CHECK-SAME: tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<192x1024x1024xi1>) -> tensor<192x1024x64xf32>
+// CHECK-SAME: {
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C192:.+]] = arith.constant 192 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK: %[[D1:.+]] = scf.for %[[ARG4:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C192]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ARG5:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<192x1024x64xf32>) {
+// CHECK: %[[D2:.+]] = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C30]]
+// CHECK-SAME: iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG5]]) -> (tensor<192x1024x64xf32>) {
+// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG4]])
+// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1,
+// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32>
+// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG3]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME: %[[D4]], 1024] [1, 1, 1] : tensor<192x1024x1024xi1> to tensor<?x?x1024xi1>
+// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG7]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]],
+// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32>
+// CHECK: %[[D5:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_M]], #[[MAP_O]]]}
+// CHECK-SAME: ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]],
+// CHECK-SAME: %[[EXTRACTED_SLICE_1]], %[[C1_F32]], %[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32, tensor<?x?x1024xi1>)
+// CHECK-SAME: outs(%[[EXTRACTED_SLICE_3]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG7]][%[[ARG4]], %[[ARG6]], 0]
+// CHECK-SAME: [%[[D3]], %[[D4]], 64] [1, 1, 1] : tensor<?x?x64xf32> into tensor<192x1024x64xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<192x1024x64xf32>
+// CHECK: }
+// CHECK: scf.yield %[[D2]] : tensor<192x1024x64xf32>
+// CHECK: }
+// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
+// CHECK: }
+
+// -----
+
func.func @attention_memref(%query: memref<192x1024x64xf32>, %key: memref<192x1024x64xf32>, %value: memref<192x1024x64xf32>, %output: memref<192x1024x64xf32>) {
%scale = arith.constant 1.0 : f32
iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : memref<192x1024x64xf32>, memref<192x1024x64xf32>, memref<192x1024x64xf32>, f32) outs(%output : memref<192x1024x64xf32>)
return
@@ -1616,6 +1760,7 @@
// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
// CHECK: func.func @attention_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
@@ -1640,7 +1785,7 @@
// CHECK: %[[SUBVIEW_2:.+]] = memref.subview %[[ARG3]][%[[ARG4]], %[[ARG5]], 0] [%[[D0]], %[[D1]], 64] [1, 1,
// CHECK-SAME: 1] : memref<192x1024x64xf32> to memref<?x?x64xf32, strided<[65536, 64, 1], offset: ?>>
// CHECK: iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_O]]]}
+// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_O]]]}
// CHECK-SAME: ins(%[[SUBVIEW]], %[[SUBVIEW_0]], %[[SUBVIEW_1]], %[[C1_F32]] : memref<?x?x64xf32,
// CHECK-SAME: strided<[65536, 64, 1], offset: ?>>, memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>,
// CHECK-SAME: memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>, f32) outs(%[[SUBVIEW_2]] :
@@ -1662,6 +1807,7 @@
indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
ins(%query, %key, %value, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16)
outs(%0 : tensor<2x10x4096x64xf16>) -> tensor<2x10x4096x64xf16>
@@ -1706,6 +1852,7 @@
#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
@@ -1723,7 +1870,7 @@
%sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
%out:3 = iree_linalg_ext.online_attention
- { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] }
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32)
outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
@@ -1737,8 +1884,9 @@
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
// CHECK-LABEL: @online_attention
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
// CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
@@ -1751,7 +1899,7 @@
// CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
// CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
// CHECK-DAG: iree_linalg_ext.online_attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP4]]]}
+// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP5]]]}
// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}} : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32)
// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
// CHECK: scf.forall.in_parallel
@@ -1763,3 +1911,150 @@
transform.yield
}
}
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_float_mask(%query: tensor<192x1024x64xf32>,
+ %key: tensor<192x1024x64xf32>,
+ %value: tensor<192x1024x64xf32>,
+ %mask: tensor<192x1024x1024xf32>)
+ -> tensor<192x1024x64xf32> {
+ %scale = arith.constant 1.0 : f32
+
+ %output_empty = tensor.empty() : tensor<192x1024x64xf32>
+ %row_red_empty = tensor.empty() : tensor<192x1024xf32>
+
+ %sum_ident = arith.constant 0.000000e+00 : f32
+ %max_ident = arith.constant -3.40282347E+38 : f32
+
+ %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+ %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+ %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+
+ // Online attention with a float mask applied to the scores.
+ %out:3 = iree_linalg_ext.online_attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+ ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xf32>)
+ outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+ -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+ return %out#0 : tensor<192x1024x64xf32>
+}
+
+// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-LABEL: @online_attention_float_mask
+// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
+// CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]])
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]])
+// CHECK-DAG: %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32>
+// CHECK-DAG: %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32>
+// CHECK-DAG: %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32>
+// CHECK-DAG: %[[MASK:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 1024] [1, 1, 1] : tensor<192x1024x1024xf32> to tensor<4x128x1024xf32>
+// CHECK-DAG: %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32>
+// CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG: iree_linalg_ext.online_attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP6]], #[[$MAP6]]]}
+// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}}, %[[MASK]] : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32, tensor<4x128x1024xf32>)
+// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
+// CHECK: scf.forall.in_parallel
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
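+ // Tile batch by 4, m by 128, and n by 32; the reduction dims k1/k2 are left untiled (size 0).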
+ %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_bool_mask(%query: tensor<192x1024x64xf32>,
+ %key: tensor<192x1024x64xf32>,
+ %value: tensor<192x1024x64xf32>,
+ %mask: tensor<192x1024x1024xi1>)
+ -> tensor<192x1024x64xf32> {
+ %scale = arith.constant 1.0 : f32
+
+ %output_empty = tensor.empty() : tensor<192x1024x64xf32>
+ %row_red_empty = tensor.empty() : tensor<192x1024xf32>
+
+ %sum_ident = arith.constant 0.000000e+00 : f32
+ %max_ident = arith.constant -3.40282347E+38 : f32
+
+ %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+ %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+ %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32>
+
+ // Online attention with a boolean mask applied to the scores.
+ %out:3 = iree_linalg_ext.online_attention
+ { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] }
+ ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xi1>)
+ outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
+ -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
+
+ return %out#0 : tensor<192x1024x64xf32>
+}
+
+// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-LABEL: @online_attention_bool_mask
+// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2)
+// CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]])
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]])
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]])
+// CHECK-DAG: %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32>
+// CHECK-DAG: %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32>
+// CHECK-DAG: %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32>
+// CHECK-DAG: %[[MASK:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 1024] [1, 1, 1] : tensor<192x1024x1024xi1> to tensor<4x128x1024xi1>
+// CHECK-DAG: %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32>
+// CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32>
+// CHECK-DAG: iree_linalg_ext.online_attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP6]], #[[$MAP6]]]}
+// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}}, %[[MASK]] : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32, tensor<4x128x1024xi1>)
+// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>)
+// CHECK: scf.forall.in_parallel
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
index 428f240..90797a0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/SetOperations.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir::iree_compiler::IREE::LinalgExt {
@@ -83,10 +84,6 @@
FailureOr<AttentionOpDetail>
AttentionOpDetail::get(ArrayRef<AffineMap> indexingMaps) {
- if (indexingMaps.size() != 4 && indexingMaps.size() != 6) {
- return failure();
- }
-
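+ // The number of indexing maps varies with the optional mask (and, for
+ // online attention, max/sum) operands, so no fixed count is enforced here.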
AttentionOpDetail opInfo;
opInfo.inferFromIndexingMaps(indexingMaps);
opInfo.maps = SmallVector<AffineMap>(indexingMaps);
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 6b02036..e0e0cef 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -3,11 +3,12 @@
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
util.func public @attention_static(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x4096x64xf16> {
%0 = tensor.empty() : tensor<20x4096x64xf16>
- %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+ %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
util.return %expanded : tensor<2x10x4096x64xf16>
}
@@ -36,11 +37,51 @@
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @attention_static_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x4096x64xf16> {
+ %0 = tensor.empty() : tensor<20x4096x64xf16>
+ %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+ %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
+ util.return %expanded : tensor<2x10x4096x64xf16>
+}
+
+// CHECK-LABEL: func public @attention_static_masked(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
+// CHECK-SAME: %[[ARG3:.+]]: f16
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>)
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x4096x64xf16>
+// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
+// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
+// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
+// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 1024]
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: util.return %[[ATTENTION]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
util.func public @attention_expand_all(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x2048x2x2x32xf16> {
%0 = tensor.empty() : tensor<20x4096x64xf16>
- %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+ %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
%expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16>
util.return %expanded : tensor<2x10x2048x2x2x32xf16>
}
@@ -66,11 +107,50 @@
// -----
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @attention_expand_all_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x2048x2x2x32xf16> {
+ %0 = tensor.empty() : tensor<20x4096x64xf16>
+ %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4: tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
+ %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16>
+ util.return %expanded : tensor<2x10x2048x2x2x32xf16>
+}
+
+// CHECK-LABEL: func public @attention_expand_all_masked(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
+// CHECK-SAME: %[[ARG3:.+]]: f16
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>)
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x2048x2x2x32xf16>
+// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 16] : tensor<20x4096x16xf16> into tensor<2x10x2048x2x16xf16>
+// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16>
+// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [2, 10, 1024, 2, 32] : tensor<20x1024x64xf16> into tensor<2x10x1024x2x32xf16>
+// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 1024] : tensor<20x4096x1024xf16> into tensor<2x10x2048x2x1024xf16>
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d4)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d5)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)>
+// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: util.return %[[ATTENTION]]
+
+// -----
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16) -> tensor<2x?x?x?xf16> {
%c0 = arith.constant 0 : index
@@ -83,7 +163,7 @@
%d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
%d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
%0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
- %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+ %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
%split = arith.divsi %d0, %c2 : index
%expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
: tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
@@ -130,7 +210,78 @@
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>) -> tensor<2x?x?x?xf16> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16>
+ %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16>
+ %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16>
+ %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16>
+ %0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16>
+ %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16, tensor<?x?x?xf16>) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+ %split = arith.divsi %d0, %c2 : index
+ %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4]
+ : tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
+ util.return %expanded : tensor<2x?x?x?xf16>
+}
+
+// CHECK-LABEL: func public @attention_dynamic_masked(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>
+// CHECK-SAME: %[[ARG3:.+]]: f16
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]]
+// CHECK-DAG: %[[VAL:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[VAL]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
+// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
+// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]]
+// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
+// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
+// CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]]
+// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
+// CHECK-DAG: %[[D10:.+]] = tensor.dim %[[ARG4]], %[[C0]]
+// CHECK-DAG: %[[D11:.+]] = tensor.dim %[[ARG4]], %[[C1]]
+// CHECK-DAG: %[[D12:.+]] = tensor.dim %[[ARG4]], %[[C2]]
+// CHECK-DAG: %[[SPLIT3:.+]] = arith.divui %[[D10]], %[[C2]]
+// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT3]], %[[D11]], %[[D12]]]
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: util.return %[[ATTENTION]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
util.func public @sink_through_attention(%0 : tensor<4x32x64x128xf16>, %1 : tensor<4x32x64x128xf16>, %2 : tensor<4x32x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
%13 = tensor.empty() : tensor<4x32x64x128xf16>
@@ -138,7 +289,7 @@
%collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
%collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
%17 = tensor.empty() : tensor<128x64x128xf16>
- %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+ %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
util.return %18 : tensor<128x64x128xf16>
}
@@ -162,13 +313,52 @@
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @sink_through_attention_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<4x32x64x128xf16>, %2 : tensor<4x32x64x128xf16>, %cst : f16, %3 : tensor<4x32x64x64xf16>) -> (tensor<128x64x128xf16>) {
+ %13 = tensor.empty() : tensor<4x32x64x128xf16>
+ %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+ %collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+ %collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+ %collapsed_15 = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<4x32x64x64xf16> into tensor<128x64x64xf16>
+ %17 = tensor.empty() : tensor<128x64x128xf16>
+ %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst, %collapsed_15 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+ util.return %18 : tensor<128x64x128xf16>
+}
+
+// CHECK-LABEL: util.func public @sink_through_attention_masked
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG3:.+]]: f16
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]] :
+// CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+// CHECK: util.return %[[RET]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
util.func public @sink_single_collapse(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16) -> (tensor<128x64x128xf16>) {
%13 = tensor.empty() : tensor<4x32x64x128xf16>
%collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
%17 = tensor.empty() : tensor<128x64x128xf16>
- %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%collapsed_12, %1, %2, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+ %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %1, %2, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
util.return %18 : tensor<128x64x128xf16>
}
@@ -188,3 +378,41 @@
// CHECK-SAME: ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]] :
// CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
// CHECK: util.return %[[RET]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16, %3 : tensor<128x64x64xf16>) -> (tensor<128x64x128xf16>) {
+ %13 = tensor.empty() : tensor<4x32x64x128xf16>
+ %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+ %17 = tensor.empty() : tensor<128x64x128xf16>
+ %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %1, %2, %cst, %3 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16>
+ util.return %18 : tensor<128x64x128xf16>
+}
+
+// CHECK-LABEL: util.func public @sink_single_collapse_masked
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG3:.+]]: f16
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[EXPANDED1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128]
+// CHECK-DAG: %[[EXPANDED2:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128]
+// CHECK-DAG: %[[EXPANDED3:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 64]
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME: ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]], %[[EXPANDED3]] :
+// CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16>
+// CHECK: util.return %[[RET]]
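Both sink tests above boil down to the same mechanical step: every collapsed tensor operand, the mask included, is re-expanded in front of the attention op, and only the result is collapsed back. Applied to the mask alone, the expansion looks like this (shapes taken from the test above; a sketch, not the pass output):

util.func public @expand_mask_sketch(%mask: tensor<128x64x64xf16>) -> tensor<4x32x64x64xf16> {
  // Re-expand the flattened batch*head dimension back into [4, 32].
  %0 = tensor.expand_shape %mask [[0, 1], [2], [3]] output_shape [4, 32, 64, 64] : tensor<128x64x64xf16> into tensor<4x32x64x64xf16>
  util.return %0 : tensor<4x32x64x64xf16>
}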
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
index 7f0ed69..edfbfcf 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir
@@ -388,7 +388,7 @@
} -> tensor<4x1x4x16x32x128xf16>
%3 = tensor.empty() : tensor<4x1x32x1x128xf16>
%4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
- %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+ %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
flow.return %5 : tensor<4x1x32x1x128xf16>
}
%collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
@@ -402,3 +402,98 @@
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK: ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
// CHECK: flow.return %[[ATTENTION]]
+
+// -----
+
+// Clone 'gather-like' operations into the dispatch region.
+util.func public @clone_gather_like(%arg0: tensor<4x1x4xi64>, %arg1: tensor<16384x16x32x128xf16>, %arg2: f16, %arg3: i64, %arg4: tensor<4x4x16x32x128xf16>, %arg5: tensor<4x1x32x1x128xf16>) -> tensor<4x32x1x128xf16> {
+ %0 = tensor.empty() : tensor<4x1x4x16x32x128xf16>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+ ^bb0(%in: i64, %out: f16):
+ %5 = arith.index_cast %in : i64 to index
+ %6 = linalg.index 3 : index
+ %7 = linalg.index 4 : index
+ %8 = linalg.index 5 : index
+ %extracted = tensor.extract %arg1[%5, %6, %7, %8] : tensor<16384x16x32x128xf16>
+ linalg.yield %extracted : f16
+ } -> tensor<4x1x4x16x32x128xf16>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+ ^bb0(%in: i64, %out: f16):
+ %5 = arith.addi %in, %arg3 : i64
+ %6 = arith.index_cast %5 : i64 to index
+ %7 = linalg.index 3 : index
+ %8 = linalg.index 4 : index
+ %9 = linalg.index 5 : index
+ %extracted = tensor.extract %arg1[%6, %7, %8, %9] : tensor<16384x16x32x128xf16>
+ linalg.yield %extracted : f16
+ } -> tensor<4x1x4x16x32x128xf16>
+ %3 = tensor.empty() : tensor<4x1x32x1x128xf16>
+ %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
+ %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+ flow.return %5 : tensor<4x1x32x1x128xf16>
+ }
+ %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
+ util.return %collapsed : tensor<4x32x1x128xf16>
+}
+
+// CHECK-LABEL: util.func public @clone_gather_like
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[GATHER0:.+]] = linalg.generic
+// CHECK: %[[GATHER1:.+]] = linalg.generic
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK: ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
+// CHECK: flow.return %[[ATTENTION]]
+
+// -----
+
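+// Same as above, but the attention op also consumes a mask operand; the
+// gather-like producers must still be cloned into the dispatch region.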
+util.func public @clone_gather_like_masked(%arg0: tensor<4x1x4xi64>, %arg1: tensor<16384x16x32x128xf16>, %arg2: f16, %arg3: i64, %arg4: tensor<4x4x16x32x128xf16>, %arg5: tensor<4x1x32x1x128xf16>, %arg6: tensor<4x1x32x128xf16>) -> tensor<4x32x1x128xf16> {
+ %0 = tensor.empty() : tensor<4x1x4x16x32x128xf16>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+ ^bb0(%in: i64, %out: f16):
+ %5 = arith.index_cast %in : i64 to index
+ %6 = linalg.index 3 : index
+ %7 = linalg.index 4 : index
+ %8 = linalg.index 5 : index
+ %extracted = tensor.extract %arg1[%5, %6, %7, %8] : tensor<16384x16x32x128xf16>
+ linalg.yield %extracted : f16
+ } -> tensor<4x1x4x16x32x128xf16>
+
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+ ^bb0(%in: i64, %out: f16):
+ %5 = arith.addi %in, %arg3 : i64
+ %6 = arith.index_cast %5 : i64 to index
+ %7 = linalg.index 3 : index
+ %8 = linalg.index 4 : index
+ %9 = linalg.index 5 : index
+ %extracted = tensor.extract %arg1[%6, %7, %8, %9] : tensor<16384x16x32x128xf16>
+ linalg.yield %extracted : f16
+ } -> tensor<4x1x4x16x32x128xf16>
+
+ %3 = tensor.empty() : tensor<4x1x32x1x128xf16>
+
+ %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
+ %5 = iree_linalg_ext.attention {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d4)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
+ ]
+ } ins(%arg5, %1, %2, %arg2, %arg6 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16, tensor<4x1x32x128xf16>) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+
+ flow.return %5 : tensor<4x1x32x1x128xf16>
+ }
+
+ %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
+ util.return %collapsed : tensor<4x32x1x128xf16>
+}
+
+// CHECK-LABEL: util.func public @clone_gather_like_masked
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[GATHER0:.+]] = linalg.generic
+// CHECK: %[[GATHER1:.+]] = linalg.generic
+// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK: ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
+// CHECK: flow.return %[[ATTENTION]]
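A producer qualifies as "gather-like" because its body reads another tensor at dynamically computed indices through tensor.extract, which is what forces it into the consumer's dispatch. The pattern distilled to two dimensions (names and shapes are illustrative, not part of this diff):

util.func public @gather_like_sketch(%indices: tensor<4xi64>, %table: tensor<16384x128xf16>) -> tensor<4x128xf16> {
  %empty = tensor.empty() : tensor<4x128xf16>
  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%indices : tensor<4xi64>) outs(%empty : tensor<4x128xf16>) {
  ^bb0(%in: i64, %out: f16):
    // Row comes from the index tensor; column from the iteration space.
    %row = arith.index_cast %in : i64 to index
    %col = linalg.index 1 : index
    %extracted = tensor.extract %table[%row, %col] : tensor<16384x128xf16>
    linalg.yield %extracted : f16
  } -> tensor<4x128xf16>
  util.return %0 : tensor<4x128xf16>
}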
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
index 993b690..1d10dfc 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
@@ -97,7 +97,7 @@
linalg.yield %5 : f16
} -> tensor<?x?x?xf16>
- %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+ %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
^bb0(%in: f16, %out: f16):
@@ -123,3 +123,49 @@
// CHECK: %[[GEN2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ATTN]]
// CHECK: flow.return %[[GEN2]]
+
+// -----
+
+util.func public @attention_dispatch_masked(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>, %arg5: tensor<?x?x?xf16>, %arg6: tensor<?x?x?xf16>, %arg7: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %5 = arith.mulf %in, %in : f16
+ linalg.yield %5 : f16
+ } -> tensor<?x?x?xf16>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %5 = arith.mulf %in, %in : f16
+ linalg.yield %5 : f16
+ } -> tensor<?x?x?xf16>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %5 = arith.mulf %in, %in : f16
+ linalg.yield %5 : f16
+ } -> tensor<?x?x?xf16>
+
+  %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3, %arg5 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16, tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+
+ %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ %5 = arith.mulf %in, %in : f16
+ linalg.yield %5 : f16
+ } -> tensor<?x?x?xf16>
+ util.return %4 : tensor<?x?x?xf16>
+}
+
+// CHECK-LABEL: util.func public @attention_dispatch_masked
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK-NEXT: %[[GEN0:.+]] = linalg.generic
+// CHECK: flow.return %[[GEN0]]
+// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
+// CHECK-NEXT: %[[GEN1:.+]] = linalg.generic
+// CHECK: flow.return %[[GEN1]]
+// CHECK: %[[DISPATCH2:.+]] = flow.dispatch.region
+// CHECK-NEXT: %[[GEN2:.+]] = linalg.generic
+// CHECK: flow.return %[[GEN2]]
+// CHECK: %[[RESULT:.+]] = flow.dispatch.region
+// CHECK: %[[ATTN:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: ins(%[[DISPATCH0]], %[[DISPATCH1]], %[[DISPATCH2]]
+// CHECK: %[[GEN2:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ATTN]]
+// CHECK: flow.return %[[GEN2]]
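The CHECKs above encode the intended end state: each squaring producer forms its own dispatch, while the trailing elementwise consumer fuses with the masked attention into a single region. Schematically, the fused dispatch has this shape (static illustrative shapes; a sketch, not the pass output):

util.func public @fused_dispatch_sketch(%q: tensor<2x16x8xf16>, %k: tensor<2x32x8xf16>, %v: tensor<2x32x8xf16>, %scale: f16, %mask: tensor<2x16x32xf16>) -> tensor<2x16x8xf16> {
  %init = tensor.empty() : tensor<2x16x8xf16>
  %r = flow.dispatch.region -> (tensor<2x16x8xf16>) {
    %attn = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%q, %k, %v, %scale, %mask : tensor<2x16x8xf16>, tensor<2x32x8xf16>, tensor<2x32x8xf16>, f16, tensor<2x16x32xf16>) outs(%init : tensor<2x16x8xf16>) -> tensor<2x16x8xf16>
    // Elementwise consumer fused into the same dispatch as the attention op.
    %sq = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%attn : tensor<2x16x8xf16>) outs(%init : tensor<2x16x8xf16>) {
    ^bb0(%in: f16, %out: f16):
      %m = arith.mulf %in, %in : f16
      linalg.yield %m : f16
    } -> tensor<2x16x8xf16>
    flow.return %sq : tensor<2x16x8xf16>
  }
  util.return %r : tensor<2x16x8xf16>
}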
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
index 079acf3..d0dc854 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
@@ -15,7 +15,7 @@
linalg.yield %in : f16
} -> tensor<4x32x64x128xf16>
%4 = tensor.empty() : tensor<4x32x64x128xf16>
- %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16) outs(%4 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
+ %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16) outs(%4 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
%6 = tensor.empty() : tensor<4x64x32x128xf16>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<4x32x64x128xf16>) outs(%6 : tensor<4x64x32x128xf16>) {
^bb0(%in: f16, %out: f16):
@@ -35,6 +35,55 @@
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d3)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d5)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
+
+// -----
+
+util.func public @transposed_attention_masked(%arg0: tensor<4x64x32x128xf16>, %arg1: tensor<4x64x32x128xf16>, %arg2: tensor<4x64x32x128xf16>, %arg3: f16, %arg4: tensor<4x64x32x64xf16>) -> tensor<4x64x4096xf16> {
+ %0 = tensor.empty() : tensor<4x32x64x128xf16>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ linalg.yield %in : f16
+ } -> tensor<4x32x64x128xf16>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ linalg.yield %in : f16
+ } -> tensor<4x32x64x128xf16>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ linalg.yield %in : f16
+ } -> tensor<4x32x64x128xf16>
+ %empty = tensor.empty() : tensor<4x32x64x64xf16>
+ %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<4x64x32x64xf16>) outs(%empty : tensor<4x32x64x64xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ linalg.yield %in : f16
+ } -> tensor<4x32x64x64xf16>
+ %5 = tensor.empty() : tensor<4x32x64x128xf16>
+ %6 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3, %4 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16, tensor<4x32x64x64xf16>) outs(%5 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
+ %7 = tensor.empty() : tensor<4x64x32x128xf16>
+ %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<4x32x64x128xf16>) outs(%7 : tensor<4x64x32x128xf16>) {
+ ^bb0(%in: f16, %out: f16):
+ linalg.yield %in : f16
+ } -> tensor<4x64x32x128xf16>
+ %collapsed = tensor.collapse_shape %8 [[0], [1], [2, 3]] : tensor<4x64x32x128xf16> into tensor<4x64x4096xf16>
+ util.return %collapsed : tensor<4x64x4096xf16>
+}
+
+// CHECK-LABEL: util.func public @transposed_attention_masked
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]: f16
+// CHECK-SAME: %[[ARG4:[A-Za-z0-9]+]]: tensor
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d3)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d5)>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d4)>
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
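Reduced to its essence, the masked fold test verifies that the producer transposes vanish and their permutations are composed into the operand indexing maps; note d1 and d2 swapped in the Q, K, V, and mask maps relative to the result. An equivalent folded form with named dimensions (a sketch under illustrative names, not part of this diff):

#sq = affine_map<(b0, b1, m, k1, k2, n) -> (b0, m, b1, k1)>
#sk = affine_map<(b0, b1, m, k1, k2, n) -> (b0, k2, b1, k1)>
#sv = affine_map<(b0, b1, m, k1, k2, n) -> (b0, k2, b1, n)>
#ss = affine_map<(b0, b1, m, k1, k2, n) -> ()>
#sm = affine_map<(b0, b1, m, k1, k2, n) -> (b0, m, b1, k2)>
#sr = affine_map<(b0, b1, m, k1, k2, n) -> (b0, b1, m, n)>

util.func public @folded_masked_attention_sketch(%q: tensor<4x64x32x128xf16>, %k: tensor<4x64x32x128xf16>, %v: tensor<4x64x32x128xf16>, %scale: f16, %mask: tensor<4x64x32x64xf16>) -> tensor<4x32x64x128xf16> {
  %init = tensor.empty() : tensor<4x32x64x128xf16>
  // The transposed layouts are absorbed into the maps; no data movement remains.
  %0 = iree_linalg_ext.attention {indexing_maps = [#sq, #sk, #sv, #ss, #sm, #sr]} ins(%q, %k, %v, %scale, %mask : tensor<4x64x32x128xf16>, tensor<4x64x32x128xf16>, tensor<4x64x32x128xf16>, f16, tensor<4x64x32x64xf16>) outs(%init : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
  util.return %0 : tensor<4x32x64x128xf16>
}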
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
index cbb6ebd..447ca7d 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
@@ -15,6 +15,7 @@
indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
@@ -39,26 +40,30 @@
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
-// CHECK-DAG: %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
-// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
-// CHECK-DAG: %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
-// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
-// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
-// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
+// CHECK-DAG: %[[D:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D]], %[[D0]], %[[D1]]) : tensor<?x?x?xf16>
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
-// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
-// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> ()>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
-// CHECK: util.return %[[ATTENTION]]
+// CHECK-DAG: %[[D_SPLIT:.+]] = arith.divsi %[[D]], %[[C2]]
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT]], %[[D0]], %[[D1]]]
+// CHECK-DAG: %[[OUTS:.+]] = tensor.empty(%[[D0]], %[[D_SPLIT]], %[[D1]]) : tensor<2x?x?x?xf16>
+// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]
+// CHECK-SAME: ins(%[[EXPANDED]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: linalg.yield
+// CHECK: util.return %[[TRANSPOSE]]
// -----
@@ -70,6 +75,7 @@
indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
outs(%empty: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
@@ -91,16 +97,23 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
// CHECK-SAME: %[[ARG3:.+]]: f16)
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
-// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16]
-// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16]
-// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<20x4096x64xf16>
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
-// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
-// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> ()>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] :
// CHECK-SAME: outs(%[[EMPTY]] :
-// CHECK: util.return %[[ATTENTION]]
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 64]
+// CHECK-DAG: %[[OUTS:.+]] = tensor.empty() : tensor<2x4096x10x64xf16>
+// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps =
+// CHECK-SAME: [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>]
+// CHECK-SAME: ins(%[[EXPANDED]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: linalg.yield
+// CHECK: util.return %[[TRANSPOSE]]
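The rewritten CHECKs reflect a strategy change: rather than reassociating the attention op itself into an expanded 4-D form, the pass now keeps attention in its collapsed shape and materializes the reshape and transpose on the result. For the static test above, the produced tail is roughly (sketch only):

util.func public @expand_transpose_tail_sketch(%attn: tensor<20x4096x64xf16>) -> tensor<2x4096x10x64xf16> {
  %expanded = tensor.expand_shape %attn [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
  %init = tensor.empty() : tensor<2x4096x10x64xf16>
  // Swap the head and sequence dimensions.
  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<2x10x4096x64xf16>) outs(%init : tensor<2x4096x10x64xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<2x4096x10x64xf16>
  util.return %0 : tensor<2x4096x10x64xf16>
}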
diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py
index d258dc1..feaad0e 100644
--- a/tests/e2e/attention/generate_e2e_attention_tests.py
+++ b/tests/e2e/attention/generate_e2e_attention_tests.py
@@ -217,6 +217,7 @@
f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n"
+ f" affine_map<(batch, m, n, k1, k2) -> ()>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}"
f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n"
f" outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n"
diff --git a/tests/e2e/linalg_ext_ops/attention.mlir b/tests/e2e/linalg_ext_ops/attention.mlir
index c418809..c2ca832 100644
--- a/tests/e2e/linalg_ext_ops/attention.mlir
+++ b/tests/e2e/linalg_ext_ops/attention.mlir
@@ -14,6 +14,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x3x4xf32>,
tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%init : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
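These e2e hunks only thread the new empty scale map through the existing unmasked tests; numerical coverage for masking itself lives in the external test repository linked above. For orientation, a masked counterpart of the test above would look like this (an illustrative sketch, not part of this diff):

util.func public @masked_attention_f32_sketch(%query: tensor<1x3x4xf32>, %key: tensor<1x3x4xf32>, %value: tensor<1x3x4xf32>, %scale: f32, %mask: tensor<1x3x3xf32>) -> tensor<1x3x4xf32> {
  %init = tensor.empty() : tensor<1x3x4xf32>
  // The mask spans (batch, M, K2) = 1x3x3 and is added to the QK scores.
  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                                                   affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                                                   affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                                                   affine_map<(d0, d1, d2, d3, d4) -> ()>,
                                                   affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
                                                   affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
       ins(%query, %key, %value, %scale, %mask : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32, tensor<1x3x3xf32>)
       outs(%init : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
  util.return %1 : tensor<1x3x4xf32>
}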
@@ -44,6 +45,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<1x4x4xf32>,
tensor<1x4x4xf32>, tensor<1x4x4xf32>, f32) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
@@ -90,6 +92,7 @@
%1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%query, %key, %value, %scale : tensor<3x3x4xf32>,
tensor<3x3x4xf32>, tensor<3x3x4xf32>, f32) outs(%init : tensor<3x3x4xf32>) -> tensor<3x3x4xf32>