Split transform.tile_and_decompose_attention (#15516)
78e9dbcc2b6b51f329238b4b25648527c64e60a7 split the
TileAndDecomposeAttention pass implementation into separate tiling and
decomposition steps. This patch does the same splitting at the transform
dialect level, replacing transform.tile_and_decompose_attention with
transform.tile_attention and transform.decompose_tiled_attention. There
are two reasons for this:
1. It is easier to keep track of all the results from each transform
operation.
2. After the previous splitting patch, we were doing some hacks, moving
fill operations around, to get the same sequence of transform results as
before.
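
For illustration, the new two-step usage looks like the following sketch
(mirroring the updated GPU transform spec below; %attn is an illustrative
handle name, the %last_truncate/%truncate handles are produced only for
f16 inputs, and the last result of transform.tile_attention is a handle to
the tiled attention op that feeds the decomposition step):

    %attn = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
    // Tile: yields the fill ops, the inner loop, and the tiled attention op.
    %acc_fill, %max_fill, %sum_fill, %inner_loop, %last_truncate, %blocked_attention = transform.tile_attention %attn :
      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    // Decompose the tiled op into its flash-attention body.
    %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %truncate, %scale_acc, %second_matmul
      = transform.decompose_tiled_attention %blocked_attention :
      (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)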
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
index ae62eb4..03e1cd6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_transform_spec.mlir
@@ -31,10 +31,11 @@
// Tile and decompose attention
// ==========================================
%attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %acc_fill, %max_fill, %sum_fill, %inner_loop,
- %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %truncate, %scale_acc, %second_matmul, %last_truncate
- = transform.tile_and_decompose_attention %attention4 :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %acc_fill, %max_fill, %sum_fill, %inner_loop, %last_truncate, %blocked_attention = transform.tile_attention %attention4 :
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %truncate, %scale_acc, %second_matmul
+ = transform.decompose_tiled_attention %blocked_attention :
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
// Promote key and value operands
// ==========================================
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index e485a06..340684f 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -172,9 +172,17 @@
std::unique_ptr<Pass> createDecomposeSoftmaxPass();
// Transform dialect version of tile and decompose attention wrapper.
-SmallVector<Operation *>
-tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
- RewriterBase &rewriter, bool onlyTile = false);
+void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
+ SmallVectorImpl<Operation *> &ops,
+ RewriterBase &rewriter, bool onlyTile = false);
+
+IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
+ SmallVectorImpl<Operation *> &ops,
+ RewriterBase &rewriter);
+
+void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
+ SmallVectorImpl<Operation *> &ops,
+ RewriterBase &rewriter);
// Creates a pass to convert the attention op into a sequence of
// linalg ops.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
index 2c62d6e..ad171c4 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -105,14 +105,46 @@
}];
}
-def TileAndDecomposeAttentionOp : Op<Transform_Dialect, "tile_and_decompose_attention",
+def TileAttentionOp : Op<Transform_Dialect, "tile_attention",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Target iree_linalg_ext.attention ops and tile and decompose them.
+ Target iree_linalg_ext.attention ops and tile them.
+ This transform consumes the target handle and produces a result handle.
+ }];
+
+ let arguments = (
+ ins TransformHandleTypeInterface:$target
+ );
+ let results = (outs Variadic<TransformHandleTypeInterface>:$result);
+
+ let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+def DecomposeTiledAttentionOp : Op<Transform_Dialect, "decompose_tiled_attention",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Target iree_linalg_ext.attention ops and decompose them.
This transform consumes the target handle and produces a result handle.
}];
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 4054f93..430320a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -323,9 +323,9 @@
/// Tile iree_linalg_ext.attention.
/// TODO: Adopt getTiledImplementation with this.
-static SmallVector<Operation *>
-tileAttention(IREE::LinalgExt::AttentionOp attnOp, RewriterBase &rewriter) {
- SmallVector<Operation *> ops;
+IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
+ SmallVectorImpl<Operation *> &ops,
+ RewriterBase &rewriter) {
Location loc = attnOp.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(attnOp);
@@ -427,14 +427,15 @@
rewriter.replaceOp(attnOp, loopNest.results[0]);
ops.push_back(tiledAttentionOp);
- return ops;
+
+ return tiledAttentionOp;
}
/// Decompose tiled iree_linalg_ext.attention op.
/// TODO: Adopt decomposeOperation with this.
-static void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
- SmallVector<Operation *> &ops,
- RewriterBase &rewriter) {
+void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
+ SmallVectorImpl<Operation *> &ops,
+ RewriterBase &rewriter) {
Location loc = tiledAttnOp.getLoc();
Value keySlice = tiledAttnOp.getKey();
Value valueSlice = tiledAttnOp.getValue();
@@ -463,24 +464,14 @@
/// Utility function which tiles and then decomposes attention op via
/// FlashAttention algorithm.
-SmallVector<Operation *>
-tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
- RewriterBase &rewriter, bool onlyTile) {
- SmallVector<Operation *> ops = tileAttention(attnOp, rewriter);
+void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
+ SmallVectorImpl<Operation *> &ops,
+ RewriterBase &rewriter, bool onlyTile) {
+ IREE::LinalgExt::AttentionOp tiledAttentionOp =
+ tileAttention(attnOp, ops, rewriter);
if (onlyTile)
- return ops;
- auto tiledAttnOp = cast<IREE::LinalgExt::AttentionOp>(ops[ops.size() - 1]);
- ops.pop_back();
- Operation *truncateToF16 = NULL;
- Type elementType = tiledAttnOp.getQueryType().getElementType();
- if (elementType.isF16()) {
- truncateToF16 = ops[ops.size() - 1];
- ops.pop_back();
- }
- decomposeTiledAttention(tiledAttnOp, ops, rewriter);
- if (truncateToF16)
- ops.push_back(truncateToF16);
- return ops;
+ return;
+ decomposeTiledAttention(tiledAttentionOp, ops, rewriter);
}
namespace {
@@ -515,7 +506,8 @@
LogicalResult reifyAttentionTransform(func::FuncOp funcOp, bool onlyTile) {
IRRewriter rewriter(funcOp.getContext());
funcOp.walk([&](IREE::LinalgExt::AttentionOp attnOp) {
- tileAndDecomposeAttention(attnOp, rewriter, onlyTile);
+ SmallVector<Operation *> ops;
+ tileAndDecomposeAttention(attnOp, ops, rewriter, onlyTile);
return WalkResult::advance();
});
return success();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index 95ef281..c72c1a7 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -161,12 +161,23 @@
// TileAndDecomposeAttention
//===---------------------------------------------------------------------===//
-DiagnosedSilenceableFailure LinalgExt::TileAndDecomposeAttentionOp::applyToOne(
+DiagnosedSilenceableFailure LinalgExt::TileAttentionOp::applyToOne(
transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- SmallVector<Operation *> ops =
- LinalgExt::tileAndDecomposeAttention(attentionOp, rewriter);
+ SmallVector<Operation *> ops;
+ LinalgExt::tileAttention(attentionOp, ops, rewriter);
+ for (auto op : ops)
+ results.push_back(op);
+ return DiagnosedSilenceableFailure::success();
+}
+
+DiagnosedSilenceableFailure LinalgExt::DecomposeTiledAttentionOp::applyToOne(
+ transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ SmallVector<Operation *> ops;
+ LinalgExt::decomposeTiledAttention(attentionOp, ops, rewriter);
for (auto op : ops)
results.push_back(op);
return DiagnosedSilenceableFailure::success();
diff --git a/tests/transform_dialect/cpu/attention_codegen_spec.mlir b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
index f31f50e..3601d92 100644
--- a/tests/transform_dialect/cpu/attention_codegen_spec.mlir
+++ b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
@@ -19,10 +19,12 @@
// Tile and decompose attention
// ==========================================
- %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- %acc_fill, %max_fill, %sum_fill, %inner_loop, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum,
- %reciprocal_sum, %softmax, %scale_acc, %second_matmul = transform.tile_and_decompose_attention %attention2 :
- (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,!transform.any_op, !transform.any_op, !transform.any_op)
+ %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ %acc_fill, %max_fill, %sum_fill, %inner_loop, %blocked_attention = transform.tile_attention %attention4 :
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ %fill_op, %first_matmul, %reduce_max, %partial_softmax, %update, %reduce_sum, %reciprocal_sum, %softmax, %scale_acc, %second_matmul
+ = transform.decompose_tiled_attention %blocked_attention :
+ (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
// Vectorize function
// ==========================================