[GPU][DT] Fix matmul narrow dim selection (#21764)
The old logic:
```
if (ShapedType::isDynamic(n) || m < n) { ... }
if (ShapedType::isDynamic(m) || n < m) { ... }
```
could incorrectly select the narrow dimension when `m` is dynamic
(represented by the sentinel `INT64_MIN`). That case should fall through to
the second `if` and pick `N`, but it is accidentally captured by the first
`if`, because `m < n` evaluates to true for a dynamic `m`: `INT64_MIN`
compares less than any static size.
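
To see the failure mode concretely, here is a minimal standalone sketch (no MLIR dependency). `kDynamic` mirrors the value of `mlir::ShapedType::kDynamic`; `Narrow`, `pickNarrowDim`, and `powerOf2Ceil` are illustrative stand-ins, not the IREE API:
```
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Dynamic dims are encoded as INT64_MIN, so `m < n` is vacuously true
// whenever m is dynamic -- the comparison the old first `if` tripped over.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct Narrow { enum class Dim { M, N } dim; int64_t size; };

static int64_t powerOf2Ceil(int64_t v) {
  int64_t p = 1;
  while (p < v) p <<= 1;
  return p;
}

// The fixed ordering: handle dynamic dims explicitly before comparing.
std::optional<Narrow> pickNarrowDim(int64_t m, int64_t n) {
  if (m == kDynamic && n == kDynamic) return std::nullopt;
  if (m == kDynamic) return Narrow{Narrow::Dim::N, powerOf2Ceil(n)};
  if (n == kDynamic) return Narrow{Narrow::Dim::M, powerOf2Ceil(m)};
  if (n < m) return Narrow{Narrow::Dim::N, powerOf2Ceil(n)};
  if (m < n) return Narrow{Narrow::Dim::M, powerOf2Ceil(m)};
  return std::nullopt; // static and equal: no narrow dimension
}

int main() {
  assert(kDynamic < 513); // why the old first `if` misfired
  // Dynamic M, static N = 513: must pick N (the old code picked M).
  assert(pickNarrowDim(kDynamic, 513)->dim == Narrow::Dim::N);
  // Static M = 255, dynamic N: picks M.
  assert(pickNarrowDim(255, kDynamic)->dim == Narrow::Dim::M);
  // Both static: the smaller dimension wins; equal sizes: none.
  assert(pickNarrowDim(255, 513)->dim == Narrow::Dim::M);
  assert(!pickNarrowDim(128, 128));
  return 0;
}
```
Reordering the checks so both dynamic cases are handled before the `<` comparisons makes the sentinel harmless, which is exactly what the `Utils.cpp` hunk below does.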
This PR also fixes an `iterationSizes` issue that caused compilation
failures in llama with data tiling.
---------
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_gfx942.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_gfx942.mlir
index 18fa640..0a4946f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_gfx942.mlir
@@ -238,6 +238,54 @@
#encoding = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32],
user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iteration_sizes = [?, 513, ?]>
+func.func @set_encoding_ACC_dynamic_M_MFMA_F32_16x16x4_F32(%arg0 : tensor<?x513xf32>) -> tensor<?x513xf32, #encoding> {
+ %0 = iree_encoding.set_encoding %arg0 : tensor<?x513xf32> -> tensor<?x513xf32, #encoding>
+ return %0 : tensor<?x513xf32, #encoding>
+}
+
+// CHECK-LABEL: func.func @set_encoding_ACC_dynamic_M_MFMA_F32_16x16x4_F32
+// CHECK: %[[PACK:.*]] = linalg.pack %{{.+}} padding_value(%{{.+}} : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1]
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [128, 128]
+// CHECK-SAME: : tensor<?x513xf32> -> tensor<?x5x128x128xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
+// CHECK-SAME:  : tensor<?x5x128x128xf32> into tensor<?x5x4x4x2x4x16x8xf32>
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<?x5x4x4x2x4x16x8xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?x5x4x2x8x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 4, 7, 3, 6, 5]
+// CHECK: return %[[TRANSPOSE]]
+
+// -----
+
+#encoding = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32],
+ user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iteration_sizes = [255, ?, ?]>
+func.func @set_encoding_ACC_dynamic_N_MFMA_F32_16x16x4_F32(%arg0 : tensor<255x?xf32>) -> tensor<255x?xf32, #encoding> {
+ %0 = iree_encoding.set_encoding %arg0 : tensor<255x?xf32> -> tensor<255x?xf32, #encoding>
+ return %0 : tensor<255x?xf32, #encoding>
+}
+
+// CHECK-LABEL: func.func @set_encoding_ACC_dynamic_N_MFMA_F32_16x16x4_F32
+// CHECK: %[[PACK:.*]] = linalg.pack %{{.+}} padding_value(%{{.+}} : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1]
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [128, 128]
+// CHECK-SAME: : tensor<255x?xf32> -> tensor<2x?x128x128xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
+// CHECK-SAME:  : tensor<2x?x128x128xf32> into tensor<2x?x4x8x4x4x16x2xf32>
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x?x4x8x4x4x16x2xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x?x4x8x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 5, 3, 7, 2, 6, 4]
+// CHECK: return %[[TRANSPOSE]]
+
+// -----
+
+#encoding = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32],
+ user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iteration_sizes = [255, 513, ?]>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.cpp
index b6b9ab0..db9c016 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.cpp
@@ -62,17 +62,30 @@
// and vecmat, so set to 1 if empty.
const int64_t m = cDims.m.empty() ? 1 : iterationSizes[cDims.m[0]];
const int64_t n = cDims.n.empty() ? 1 : iterationSizes[cDims.n[0]];
+
+ // If both dimensions are dynamic, return empty.
if (ShapedType::isDynamic(m) && ShapedType::isDynamic(n)) {
return {};
}
- if (ShapedType::isDynamic(n) || m < n) {
- return {MatmulNarrowDim::Dim::M,
- static_cast<int64_t>(llvm::PowerOf2Ceil(m))};
- }
- if (ShapedType::isDynamic(m) || n < m) {
+ // If only one dimension is dynamic, pick the other as the narrow dimension.
+ if (ShapedType::isDynamic(m)) {
return {MatmulNarrowDim::Dim::N,
static_cast<int64_t>(llvm::PowerOf2Ceil(n))};
}
+ if (ShapedType::isDynamic(n)) {
+ return {MatmulNarrowDim::Dim::M,
+ static_cast<int64_t>(llvm::PowerOf2Ceil(m))};
+ }
+  // If both dimensions are static, pick the smaller one.
+ if (n < m) {
+ return {MatmulNarrowDim::Dim::N,
+ static_cast<int64_t>(llvm::PowerOf2Ceil(n))};
+ }
+ if (m < n) {
+ return {MatmulNarrowDim::Dim::M,
+ static_cast<int64_t>(llvm::PowerOf2Ceil(m))};
+ }
+ // If dimensions are static and equal, return empty.
return {};
}