[GPU] Allow swizzle promotion for transpose linalg.generic producers (preshuffling prs 1/3) (#24086)
This PR is the first in a series of PRs meant to enable preshuffling for
scaled gemms.
Currently, operands to be swizzled are skipped when producing
SwizzleHintOps if their producer is a linalg::LinalgOp. This means that
transpose linalg.generic ops feeding into a matmul with SwizzleOperand
promotion would never get XOR swizzle hints applied to their output
buffer.
This change adds an exception: transpose linalg.generic producers
(identified via linalg::isaTransposeOpInterface) no longer exit early.
For a GEMM dispatch where the transpose is folded into memory access,
this does not produce any extra allocations: bufferization correctly
recognizes that the `tensor.empty` associated with the TransposeOp and the
SwizzleHintOp are redundant and removes the copy.
All other linalg producers, including non-transpose linalg.generic ops,
continue to take the early return with the lowering_config annotation, as
before.
Made-with: Cursor
---------
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir
index 7b8fab4..faa208c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir
@@ -394,6 +394,104 @@
// -----
+#lowering_config_swizzle_transpose = #iree_gpu.lowering_config<{
+ promote_operands = [0, 1],
+ promotion_types = [
+ #iree_gpu.swizzle_operand<copy_config = #iree_gpu.derived_thread_config, swizzle = #iree_codegen.xor_shuffle<256, 32>>,
+ #iree_gpu.swizzle_operand<copy_config = #iree_gpu.derived_thread_config, swizzle = #iree_codegen.xor_shuffle<256, 32>>]}>
+
+#transpose_map = affine_map<(d0, d1) -> (d1, d0)>
+#identity_map = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @swizzle_operand_transpose_producer(
+ %a_transposed: tensor<64x32xf32>, %b: tensor<64x128xf32>) -> tensor<32x128xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %empty_a = tensor.empty() : tensor<32x64xf32>
+ %transpose_a = linalg.generic {
+ indexing_maps = [#transpose_map, #identity_map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%a_transposed : tensor<64x32xf32>) outs(%empty_a : tensor<32x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<32x64xf32>
+ %empty = tensor.empty() : tensor<32x128xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32>
+ %mm = linalg.matmul {lowering_config = #lowering_config_swizzle_transpose}
+ ins(%transpose_a, %b : tensor<32x64xf32>, tensor<64x128xf32>)
+ outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32>
+ return %mm : tensor<32x128xf32>
+}
+
+// Transpose linalg.generic producers do not take the early return that
+// stamps a lowering_config on the producer — they fall through to the swizzle
+// promotion path so XOR swizzle hints are applied to their output buffer.
+// CHECK-LABEL: func.func @swizzle_operand_transpose_producer
+// CHECK-SAME: %[[A_T:[A-Za-z0-9]+]]: tensor<64x32xf32>
+// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<64x128xf32>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.generic {{.*}} ins(%[[A_T]] : tensor<64x32xf32>)
+// CHECK: %[[EMPTY_A:.+]] = tensor.empty() : tensor<2048xf32>
+// CHECK: %[[SWIZZLE_A:.+]] = iree_codegen.swizzle_hint %[[EMPTY_A]][#iree_codegen.xor_shuffle<256, 32>] : tensor<2048xf32>
+// CHECK: %[[EXPAND_A:.+]] = tensor.expand_shape %[[SWIZZLE_A]] {{\[\[}}0, 1{{\]\]}} output_shape [32, 64] : tensor<2048xf32> into tensor<32x64xf32>
+// CHECK: %[[COPY_A:.+]] = linalg.copy
+// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
+// CHECK-SAME: ins(%[[TRANSPOSE]] : tensor<32x64xf32>) outs(%[[EXPAND_A]] : tensor<32x64xf32>)
+// CHECK: %[[EMPTY_B:.+]] = tensor.empty() : tensor<8192xf32>
+// CHECK: %[[SWIZZLE_B:.+]] = iree_codegen.swizzle_hint %[[EMPTY_B]][#iree_codegen.xor_shuffle<256, 32>] : tensor<8192xf32>
+// CHECK: %[[EXPAND_B:.+]] = tensor.expand_shape %[[SWIZZLE_B]] {{\[\[}}0, 1{{\]\]}} output_shape [64, 128] : tensor<8192xf32> into tensor<64x128xf32>
+// CHECK: %[[COPY_B:.+]] = linalg.copy
+// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
+// CHECK-SAME: ins(%[[B]] : tensor<64x128xf32>) outs(%[[EXPAND_B]] : tensor<64x128xf32>)
+// CHECK: linalg.matmul {{.*}} ins(%[[COPY_A]], %[[COPY_B]] : tensor<32x64xf32>, tensor<64x128xf32>)
+
+// -----
+
+#lowering_config_swizzle_non_transpose = #iree_gpu.lowering_config<{
+ promote_operands = [0, 1],
+ promotion_types = [
+ #iree_gpu.swizzle_operand<copy_config = #iree_gpu.derived_thread_config, swizzle = #iree_codegen.xor_shuffle<256, 32>>,
+ #iree_gpu.swizzle_operand<copy_config = #iree_gpu.derived_thread_config, swizzle = #iree_codegen.xor_shuffle<256, 32>>]}>
+
+#elementwise_map = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @swizzle_operand_non_transpose_generic_producer(
+ %a_raw: tensor<32x64xf32>, %b: tensor<64x128xf32>) -> tensor<32x128xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %empty_a = tensor.empty() : tensor<32x64xf32>
+ %negated_a = linalg.generic {
+ indexing_maps = [#elementwise_map, #elementwise_map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%a_raw : tensor<32x64xf32>) outs(%empty_a : tensor<32x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %neg = arith.negf %in : f32
+ linalg.yield %neg : f32
+ } -> tensor<32x64xf32>
+ %empty = tensor.empty() : tensor<32x128xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32>
+ %mm = linalg.matmul {lowering_config = #lowering_config_swizzle_non_transpose}
+ ins(%negated_a, %b : tensor<32x64xf32>, tensor<64x128xf32>)
+ outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32>
+ return %mm : tensor<32x128xf32>
+}
+
+// Non-transpose linalg.generic producers get a lowering_config stamped on them
+// directly and skip swizzle promotion — no swizzle_hint is created for them.
+// CHECK-LABEL: func.func @swizzle_operand_non_transpose_generic_producer
+// CHECK-SAME: %[[A_RAW:[A-Za-z0-9]+]]: tensor<32x64xf32>
+// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<64x128xf32>
+// CHECK: %[[NEGATED:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[A_RAW]] : tensor<32x64xf32>)
+// CHECK: lowering_config = #iree_gpu.derived_thread_config
+// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} tensor<2048xf32>
+// CHECK: %[[EMPTY_B:.+]] = tensor.empty() : tensor<8192xf32>
+// CHECK: %[[SWIZZLE_B:.+]] = iree_codegen.swizzle_hint %[[EMPTY_B]][#iree_codegen.xor_shuffle<256, 32>] : tensor<8192xf32>
+// CHECK: %[[EXPAND_B:.+]] = tensor.expand_shape %[[SWIZZLE_B]] {{\[\[}}0, 1{{\]\]}} output_shape [64, 128] : tensor<8192xf32> into tensor<64x128xf32>
+// CHECK: %[[COPY_B:.+]] = linalg.copy
+// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
+// CHECK-SAME: ins(%[[B]] : tensor<64x128xf32>) outs(%[[EXPAND_B]] : tensor<64x128xf32>)
+// CHECK: linalg.matmul {{.*}} ins(%[[NEGATED]], %[[COPY_B]] : tensor<32x64xf32>, tensor<64x128xf32>)
+
+// -----
+
// Im2colOp has no DMA conversion path in GPUConvertToCoalescedDMA, so
// promotionImpl must never stamp use_global_load_dma on it — it always falls
// back to derived_thread_config regardless of the requested promotion type.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp
index 17ec044..edc9d63 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp
@@ -76,8 +76,13 @@
}
if (isa<linalg::LinalgOp>(producer.getOperation())) {
- setLoweringConfig(producer, attr);
- return operand.get();
+ // Don't skip promotion for transpose producers — they need to go through
+ // the swizzle path so XOR swizzle hints are applied.
+ auto generic = dyn_cast<linalg::GenericOp>(producer.getOperation());
+ if (!generic || !linalg::isaTransposeOpInterface(generic)) {
+ setLoweringConfig(producer, attr);
+ return operand.get();
+ }
}
// Im2colOp has no DMA conversion path in GPUConvertToCoalescedDMA, so
// always use derived_thread_config regardless of the requested attr.