[Torch][LinalgExt] Support GQA in torch.hop_flex_attention lowering (#24313)
This PR adds grouped-query attention support to the
torch.hop_flex_attention lowering in IREE’s Torch input pipeline.
When query, key, and value have different head counts, the lowering now
expands key and value heads to match the query head count before
emitting `iree_linalg_ext.online_attention`.
This is similar to how the torch-mlir => TMTensor lowering handles this
case.
---------
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: GPT-5 Codex <noreply@openai.com>
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index b74e08c..c1ba627 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -131,6 +131,9 @@
"@torch-mlir//:TorchMLIRTorchDialect": [
f"{torch_mlir_cmake_prefix}::torch-mlir::TorchDialectIR",
],
+ "@torch-mlir//:TorchMLIRTorchDialectUtils": [
+ f"{torch_mlir_cmake_prefix}::torch-mlir::TorchDialectUtils",
+ ],
"@torch-mlir//:TorchMLIRTorchOnnxToTorch": [
f"{torch_mlir_cmake_prefix}::torch-mlir::TorchOnnxToTorchPasses",
],
diff --git a/compiler/plugins/input/Torch/InputConversion/BUILD.bazel b/compiler/plugins/input/Torch/InputConversion/BUILD.bazel
index a471125..3485639 100644
--- a/compiler/plugins/input/Torch/InputConversion/BUILD.bazel
+++ b/compiler/plugins/input/Torch/InputConversion/BUILD.bazel
@@ -91,6 +91,7 @@
"@torch-mlir//:TorchMLIRTorchConversionPasses",
"@torch-mlir//:TorchMLIRTorchConversionToMLProgram",
"@torch-mlir//:TorchMLIRTorchDialect",
+ "@torch-mlir//:TorchMLIRTorchDialectUtils",
"@torch-mlir//:TorchMLIRTorchOnnxToTorch",
"@torch-mlir//:TorchMLIRTorchPasses",
"@torch-mlir//:TorchMLIRTorchToArith",
diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
index fc42cf9..301f62c 100644
--- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
@@ -77,6 +77,7 @@
iree::compiler::plugins::input::Torch::torch-mlir::TorchConversionDialectIR
iree::compiler::plugins::input::Torch::torch-mlir::TorchDialectIR
iree::compiler::plugins::input::Torch::torch-mlir::TorchDialectPasses
+ iree::compiler::plugins::input::Torch::torch-mlir::TorchDialectUtils
iree::compiler::plugins::input::Torch::torch-mlir::TorchOnnxToTorchPasses
PUBLIC
)
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp
index 1734fac..fe8209d 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp
@@ -21,6 +21,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
@@ -183,6 +184,83 @@
rewriter, loc, tensorType.toBuiltinTensor(), torchTensor);
}
+static Value repeatTensorElementsForDim(PatternRewriter &rewriter,
+ Operation *op, Type resType, Value self,
+ int64_t repeats, int64_t dim) {
+ Location loc = op->getLoc();
+ auto selfType = cast<torch::Torch::ValueTensorType>(self.getType());
+
+ int64_t inputRank = selfType.getSizes().size();
+ dim = torch::Torch::toPositiveDim(dim, inputRank);
+
+ Value dimValue = torch::Torch::ConstantIntOp::create(rewriter, loc, dim);
+ Value dimValuePlusOne =
+ torch::Torch::ConstantIntOp::create(rewriter, loc, dim + 1);
+
+ self = torch::Torch::unsqueezeTensor(rewriter, op, self, dimValuePlusOne)
+ .value();
+
+ SmallVector<int64_t> expandShape(selfType.getSizes());
+ expandShape.insert(expandShape.begin() + dim + 1, repeats);
+ SmallVector<int64_t> expandShapeForBroadcast(expandShape.size(), -1);
+ expandShapeForBroadcast[dim + 1] = repeats;
+ Value expandShapeList =
+ torch::Torch::toIntListConstruct(rewriter, loc, expandShapeForBroadcast);
+
+ Type expandType =
+ selfType.getWithSizesAndDtype(expandShape, selfType.getOptionalDtype());
+ Value expanded = torch::Torch::AtenBroadcastToOp::create(
+ rewriter, loc, expandType, self, expandShapeList);
+
+ return torch::Torch::PrimsCollapseOp::create(rewriter, loc, resType, expanded,
+ dimValue, dimValuePlusOne)
+ .getResult();
+}
+
+static LogicalResult
+preProcessGroupQueryAttentionInputs(torch::Torch::HigherOrderFlexAttentionOp op,
+ PatternRewriter &rewriter, Value query,
+ Value &key, Value &value) {
+ auto queryType = cast<torch::Torch::ValueTensorType>(query.getType());
+ auto keyType = cast<torch::Torch::ValueTensorType>(key.getType());
+ auto valueType = cast<torch::Torch::ValueTensorType>(value.getType());
+
+ int64_t rank = queryType.getSizes().size();
+ int64_t qNumHeads = queryType.getSizes()[rank - 3];
+ int64_t kNumHeads = keyType.getSizes()[rank - 3];
+ int64_t vNumHeads = valueType.getSizes()[rank - 3];
+
+ if (llvm::any_of(ArrayRef<int64_t>{qNumHeads, kNumHeads, vNumHeads},
+ [](int64_t d) { return d == torch::Torch::kUnknownSize; })) {
+ return rewriter.notifyMatchFailure(
+ op, "expected statically known attention head counts");
+ }
+
+ if (qNumHeads == kNumHeads && qNumHeads == vNumHeads) {
+ return success();
+ }
+
+ if (qNumHeads % kNumHeads != 0 || qNumHeads % vNumHeads != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "expected query heads to be a multiple of key and value heads");
+ }
+
+ auto repeatToQueryHeadCount = [&](Value input,
+ torch::Torch::ValueTensorType inputType,
+ int64_t inputNumHeads) -> Value {
+ SmallVector<int64_t> resultShape(inputType.getSizes());
+ resultShape[rank - 3] = qNumHeads;
+ Type resultType = inputType.getWithSizesAndDtype(
+ resultShape, inputType.getOptionalDtype());
+ return repeatTensorElementsForDim(rewriter, op, resultType, input,
+ qNumHeads / inputNumHeads, rank - 3);
+ };
+
+ key = repeatToQueryHeadCount(key, keyType, kNumHeads);
+ value = repeatToQueryHeadCount(value, valueType, vNumHeads);
+ return success();
+}
+
/// Inline a single-block torch function's body at the current insertion point.
/// Falls back to func.call for multi-block or external functions.
static SmallVector<Value> inlineTorchFunction(PatternRewriter &rewriter,
@@ -318,6 +396,7 @@
Value value = op.getValue();
Value scaleVal = op.getScale();
+ bool enableGqa = op.getEnableGqa().value_or(false);
auto scoreModSymbol = op.getScoreModFnAttr();
auto maskModSymbol = op.getMaskModFnAttr();
@@ -333,8 +412,14 @@
return rewriter.notifyMatchFailure(
op, "expected return_max_scores to be a constant bool");
}
+ if (enableGqa) {
+ if (failed(preProcessGroupQueryAttentionInputs(op, rewriter, query, key,
+ value))) {
+ return failure();
+ }
+ }
- // Extract shapes from Q, K, V.
+ // Extract shapes from preprocessed Q, K, V.
auto queryType = cast<torch::Torch::ValueTensorType>(query.getType());
auto valueType = cast<torch::Torch::ValueTensorType>(value.getType());
diff --git a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir
index bcaab19..f586088 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir
@@ -229,6 +229,52 @@
// -----
+// Test flex_attention where enable_gqa is absent and head counts already match.
+// CHECK-LABEL: func.func @flex_attn_gqa_absent_matching_heads
+func.func @flex_attn_gqa_absent_matching_heads(%arg0: !torch.vtensor<[2,8,8,16],f32>, %arg1: !torch.vtensor<[2,8,8,16],f32>, %arg2: !torch.vtensor<[2,8,8,16],f32>) -> !torch.vtensor<[2,8,8,16],f32> {
+ %none = torch.constant.none
+ %false = torch.constant.bool false
+ %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %none, %false, %false : !torch.vtensor<[2,8,8,16],f32>, !torch.vtensor<[2,8,8,16],f32>, !torch.vtensor<[2,8,8,16],f32>, !torch.none, !torch.bool, !torch.bool -> !torch.vtensor<[2,8,8,16],f32>, !torch.none, !torch.none
+ return %output : !torch.vtensor<[2,8,8,16],f32>
+}
+// CHECK-NOT: torch.aten.broadcast_to
+// CHECK: iree_linalg_ext.online_attention
+// CHECK-SAME: ins({{.*}} : tensor<2x8x8x16xf32>, tensor<2x8x8x16xf32>, tensor<2x8x8x16xf32>, f32)
+
+// -----
+
+// Test flex_attention with explicitly enabled GQA and independent key/value
+// head counts.
+// CHECK-LABEL: func.func @flex_attn_gqa_enabled_independent_kv_heads
+func.func @flex_attn_gqa_enabled_independent_kv_heads(%arg0: !torch.vtensor<[2,8,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,2,8,16],f32>) -> !torch.vtensor<[2,8,8,16],f32> {
+ %none = torch.constant.none
+ %false = torch.constant.bool false
+ %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %none, %false, %false {enable_gqa = true} : !torch.vtensor<[2,8,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,2,8,16],f32>, !torch.none, !torch.bool, !torch.bool -> !torch.vtensor<[2,8,8,16],f32>, !torch.none, !torch.none
+ return %output : !torch.vtensor<[2,8,8,16],f32>
+}
+// CHECK: torch.aten.broadcast_to {{.*}} -> !torch.vtensor<[2,4,2,8,16],f32>
+// CHECK: torch.prims.collapse {{.*}} -> !torch.vtensor<[2,8,8,16],f32>
+// CHECK: torch.aten.broadcast_to {{.*}} -> !torch.vtensor<[2,2,4,8,16],f32>
+// CHECK: torch.prims.collapse {{.*}} -> !torch.vtensor<[2,8,8,16],f32>
+// CHECK: iree_linalg_ext.online_attention
+// CHECK-SAME: ins({{.*}} : tensor<2x8x8x16xf32>, tensor<2x8x8x16xf32>, tensor<2x8x8x16xf32>, f32)
+
+// -----
+
+// Test flex_attention with explicitly disabled GQA and matching head counts.
+// CHECK-LABEL: func.func @flex_attn_gqa_disabled_matching_heads
+func.func @flex_attn_gqa_disabled_matching_heads(%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> !torch.vtensor<[2,4,8,16],f32> {
+ %none = torch.constant.none
+ %false = torch.constant.bool false
+ %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %none, %false, %false {enable_gqa = false} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.none, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.none, !torch.none
+ return %output : !torch.vtensor<[2,4,8,16],f32>
+}
+// CHECK-NOT: torch.aten.broadcast_to
+// CHECK: iree_linalg_ext.online_attention
+// CHECK-SAME: ins({{.*}} : tensor<2x4x8x16xf32>, tensor<2x4x8x16xf32>, tensor<2x4x8x16xf32>, f32)
+
+// -----
+
// CHECK-LABEL: func.func @argmax_2d_dim1
func.func @argmax_2d_dim1(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3],si64> {
%int1 = torch.constant.int 1