[DT][GPU] Permute cross-thread dims of TileSwizzle to outermost (#19734)
This PR adds a new constraint on the generated TileSwizzles during GPU
encoding materialization. The new constraint is that all `CrossThread`
dimensions in the ACC layout must come from the outermost dimensions
within their reassociation groups of the swizzle's expand_shape. For
example, consider the following TileSwizzle for an ACC layout generated
without this constraint:
```
expandShape = [[8, {4}, 4], [{4}, 2, {16}]],
permutation = [3, 0, 4, 1, 5, 2]
```
`CrossThread` dimensions are denoted with braces {}, and this
TileSwizzle does not follow the new constrain because the first `{4}`
dim of reassociation group 0 and the `{16}` dim of reassociation group 1
are not outermost. After adding the new constraint, this swizzle would
look like the following:
```
expandShape = [[{4}, 8, 4], [{4}, {16}, 2]],
permutation = [3, 1, 5, 0, 4, 2]
```
Now, all `CrossThread` dimensions are outermost within their
reassociation groups. The permutation has been adjusted accordingly so
that the result shape of the swizzled tile remains the same in both
cases.
The LHS and RHS swizzle layouts also have to be adjusted to match the
new ACC layout, but the CrossThread dimensions are not necessarily the
same between corresponding M and N tiles of the ACC swizzle and LHS/RHS
swizzles. Because of this, the LHS and RHS (currently only LHS needs
this) swizzle shapes must be expanded to match the dimensionality of the
ACC layout. This is the reason why some of the LHS layouts have
additional expansion after this PR.
The reason for adding this constraint is so that the swizzle operations
of unset_encoding operations are able to be fused into the thread loop
of their data tiled multi_mma operation. This constraint makes the
fusion possible by forcing the slices that are held by a thread at the
end of the multi_mma computation to be contiguous in the linear layout
tensor within each reassociation group. This matters because we need to
fuse the collapse_shape op of the unset_encoding into the thread loop,
which is only possible when the written slice is contiguous in the
result for each reassociation group.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>diff --git a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx1100.mlir
index a942751..7c091bc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx1100.mlir
@@ -50,7 +50,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x8x2x16xf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x8x2x16xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx908.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx908.mlir
index cf54195..363b2f9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx908.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx908.mlir
@@ -50,7 +50,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x4x4xi8>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x4x4x4x4xi8>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4x4xi8>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x2x4x16x4xi32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx90a.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx90a.mlir
index 86388f1..4448512 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx90a.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx90a.mlir
@@ -50,7 +50,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x4x2xbf16>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x4x4x4x2xbf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4x2xbf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
@@ -109,7 +109,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x2xf64>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x4x4x4x2xf64>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x4x16x2xf64>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x4x4x16xf64>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir
index bd70cad..405b58f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/gpu_materialize_encoding_gfx942.mlir
@@ -23,7 +23,7 @@
return
}
// CHECK-LABEL: func.func @empty_fill_encoding_unroll8x8x4_MFMA_F32_16x16x4_F32
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x33x8x4x16x4xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x33x8x4x4x4x4xf32>
// CHECK: %{{.+}} = linalg.fill ins({{.+}}) outs(%[[EMPTY]]
// -----
@@ -52,11 +52,11 @@
// CHECK-SAME: inner_tiles = [128, 16]
// CHECK-SAME: : tensor<255x513xf32> -> tensor<2x33x128x16xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x33x128x16xf32> into tensor<2x33x8x16x4x4xf32>
+// CHECK-SAME : tensor<2x33x128x16xf32> into tensor<2x33x4x8x4x4x4xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x33x8x16x4x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x33x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x33x4x8x4x4x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x33x8x4x4x4x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 6, 2, 4, 5]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -85,11 +85,11 @@
// CHECK-SAME: inner_tiles = [16, 16]
// CHECK-SAME: : tensor<255x513xf32> -> tensor<16x33x16x16xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<16x33x16x16xf32> into tensor<16x33x16x4x4xf32>
+// CHECK-SAME : tensor<16x33x16x16xf32> into tensor<16x33x4x4x4x4xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<16x33x16x4x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<16x33x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 4, 2, 3]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<16x33x4x4x4x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<16x33x4x4x4x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 5, 2, 3, 4]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -128,11 +128,11 @@
// CHECK-SAME: inner_tiles = [128, 16]
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x128x16xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<?x?x128x16xf32> into tensor<?x?x8x16x4x4xf32>
+// CHECK-SAME : tensor<?x?x128x16xf32> into tensor<?x?x4x8x4x4x4xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<?x?x8x16x4x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<?x?x4x8x4x4x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x4x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 6, 2, 4, 5]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -161,11 +161,11 @@
// CHECK-SAME: inner_tiles = [128, 16]
// CHECK-SAME: : tensor<255x513xf32> -> tensor<5x16x128x16xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x4x2x16x4x4xf32>
+// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x4x16x2x4x4xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x4x2x16x4x4xf32>)
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x4x16x2x4x4xf32>)
// CHECK-SAME: outs({{.*}} : tensor<5x16x4x2x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5]
+// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -227,11 +227,11 @@
// CHECK-SAME: inner_tiles = [128, 128]
// CHECK-SAME: : tensor<255x513xf32> -> tensor<2x5x128x128xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x4x2x16xf32>
+// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x4x8x4x4x16x2xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x4x8x4x4x16x2xf32>)
// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x2x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 5, 2, 6, 3, 7, 4]
+// CHECK-SAME: permutation = [0, 1, 5, 3, 7, 2, 6, 4]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -256,10 +256,10 @@
// CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%{{.+}} : tensor<2x5x4x8x2x4x16x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
+// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x4x4x16x2xf32>)
+// CHECK-SAME: permutation = [0, 1, 5, 3, 7, 2, 6, 4]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xf32> into tensor<2x5x128x128xf32>
+// CHECK-SAME: : tensor<2x5x4x8x4x4x16x2xf32> into tensor<2x5x128x128xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [0, 1]
@@ -299,10 +299,10 @@
// CHECK-LABEL: func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%{{.+}} : tensor<?x?x4x8x2x4x16x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x4x2x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
+// CHECK-SAME: outs({{.*}} : tensor<?x?x4x8x4x4x16x2xf32>)
+// CHECK-SAME: permutation = [0, 1, 5, 3, 7, 2, 6, 4]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<?x?x8x4x4x4x2x16xf32> into tensor<?x?x128x128xf32>
+// CHECK-SAME: : tensor<?x?x4x8x4x4x16x2xf32> into tensor<?x?x128x128xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [0, 1]
@@ -360,7 +360,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x4x4x4xf32>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
@@ -420,7 +420,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x4xf32>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x4x4x4xf32>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x4xf32>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
@@ -459,11 +459,11 @@
// CHECK-SAME: inner_tiles = [128, 64]
// CHECK-SAME: : tensor<255x513xi8> -> tensor<2x9x128x64xi8>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x9x128x64xi8> into tensor<2x9x8x16x2x4x8xi8>
+// CHECK-SAME : tensor<2x9x128x64xi8> into tensor<2x9x4x8x4x2x4x8xi8>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x9x8x16x2x4x8xi8>)
-// CHECK-SAME: outs({{.*}} : tensor<2x9x8x4x16x2x8xi8>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4, 6]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x9x4x8x4x2x4x8xi8>)
+// CHECK-SAME: outs({{.*}} : tensor<2x9x8x4x4x4x2x8xi8>)
+// CHECK-SAME: permutation = [0, 1, 3, 6, 2, 4, 5, 7]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -492,11 +492,11 @@
// CHECK-SAME: inner_tiles = [128, 64]
// CHECK-SAME: : tensor<255x513xi8> -> tensor<5x4x128x64xi8>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x4x2x16x2x4x8xi8>
+// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x4x16x2x2x4x8xi8>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x4x2x16x2x4x8xi8>)
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x4x16x2x2x4x8xi8>)
// CHECK-SAME: outs({{.*}} : tensor<5x4x4x2x4x16x2x8xi8>)
-// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5, 7]
+// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5, 7]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -525,11 +525,11 @@
// CHECK-SAME: inner_tiles = [128, 128]
// CHECK-SAME: : tensor<255x513xi32> -> tensor<2x5x128x128xi32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x4x2x16xi32>
+// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x4x8x4x4x16x2xi32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x4x8x4x4x16x2xi32>)
// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x2x4x16x4xi32>)
-// CHECK-SAME: permutation = [0, 1, 5, 2, 6, 3, 7, 4]
+// CHECK-SAME: permutation = [0, 1, 5, 3, 7, 2, 6, 4]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -554,10 +554,10 @@
// CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
// CHECK-SAME: ins(%{{.+}} : tensor<2x5x4x8x2x4x16x4xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xi32>)
-// CHECK-SAME: permutation = [0, 1, 3, 5, 7, 2, 4, 6]
+// CHECK-SAME: outs({{.*}} : tensor<2x5x4x8x4x4x16x2xi32>)
+// CHECK-SAME: permutation = [0, 1, 5, 3, 7, 2, 6, 4]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xi32> into tensor<2x5x128x128xi32>
+// CHECK-SAME: : tensor<2x5x4x8x4x4x16x2xi32> into tensor<2x5x128x128xi32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [0, 1]
@@ -616,7 +616,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x4x4x2x8xi8>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x2x8xi8>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x8x2x4x16x4xi32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
@@ -1122,7 +1122,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x8xf8E4M3FNUZ>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x4x4x2x8xf8E4M3FNUZ>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x8xf8E4M3FNUZ>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
@@ -1182,7 +1182,7 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x4xbf16>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x4x4x2x4xbf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x4xbf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x4x8x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
index 2829922..b51c385 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
@@ -6,6 +6,11 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+
+#define DEBUG_TYPE "gpu-tile-swizzle-utils"
namespace mlir::iree_compiler::IREE::GPU {
@@ -78,8 +83,8 @@
swizzle.permutation = outPermutation;
}
-TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
- IREE::GPU::MMAFragment fragment) {
+static TileSwizzle getIntrinsicSwizzleBeforeMovingCrossThreadOutermost(
+ IREE::GPU::MMAIntrinsic intrinsic, IREE::GPU::MMAFragment fragment) {
auto layout = IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
// MMASingleSubgroupLayout has non-transposed RHS.
@@ -140,9 +145,181 @@
return 0;
}
-TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
- IREE::GPU::MMAFragment fragment) {
- auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
+/// Moves all `Kind::CrossThread` dims of the Acc layout to the outermost
+/// within their expand shape reassociation groups. This only moves the cross
+/// thread dims of the Acc layout because we want to fuse the unset_encoding
+/// ops with the data tiled matmul. In order to do this, the sliced dimensions
+/// (CrossThread) for each thread need to be outermost in the final write out.
+///
+/// This transformation is for the Acc layout, but the Lhs and Rhs layouts need
+/// to be transformed too, because the layouts need to match the Acc for their
+/// respective M and N tile dimensions.
+///
+/// Example (CrossThread dims are denoted with surrounding {} braces):
+/// Input:
+/// Lhs layout:
+/// expandShape = [[8, {16}], [2, {4}, 4]],
+/// permutation = [0, 3, 1, 2, 4]
+/// Rhs layout:
+/// expandShape = [[{4}, 2, {16}], [2, {4}, 4]],
+/// permutation = [0, 1, 4, 2, 3, 5]
+/// Acc layout:
+/// expandShape = [[8, {4}, 4], [{4}, 2, {16}]],
+/// permutation = [3, 0, 4, 1, 5, 2]
+/// Output:
+/// Lhs layout:
+/// expandShape = [[{4}, 8, {4}], [2, {4}, 4]],
+/// permutation = [1, 4, 0, 2, 3, 5]
+/// Rhs layout:
+/// expandShape = [[{4}, {16}, 2], [2, {4}, 4]],
+/// permutation = [0, 2, 4, 1, 3, 5]
+/// Acc layout:
+/// expandShape = [[{4}, 8, 4], [{4}, {16}, 2]],
+/// permutation = [3, 1, 5, 0, 4, 2]
+static TileSwizzle moveCrossThreadOutermost(TileSwizzle swizzle,
+ TileSwizzle accSwizzle,
+ MMAFragment fragment) {
+ assert(accSwizzle.expandShape.size() == 2);
+ assert(swizzle.expandShape.size() == 2);
+ TileSwizzle::ExpandShapeType expandShape = swizzle.expandShape;
+ TileSwizzle::ExpandShapeType accExpandShape = accSwizzle.expandShape;
+ // We will construct the permutation on the flattened `expandShape` dims that
+ // will move the cross thread dims to outermost, and store the resulting
+ // permutation in `crossThreadToOuterPerm`. This will be used later to adjust
+ // the swizzle.permutation properly.
+ SmallVector<int64_t> crossThreadToOuterPerm;
+ int groupStartIdx = 0;
+ for (int accGroupIdx = 0; accGroupIdx < accExpandShape.size();
+ ++accGroupIdx) {
+ // Index for corresponding `swizzle` group is 0 for Lhs and Rhs fragments,
+ // since the 1 index of Rhs and Lhs are K dimensions.
+ int groupIdx = fragment == MMAFragment::Acc ? accGroupIdx : 0;
+ // Skip N group for Lhs.
+ if (fragment == MMAFragment::Lhs && accGroupIdx == 1) {
+ continue;
+ }
+ // Skip M group for Rhs.
+ if (fragment == MMAFragment::Rhs && accGroupIdx == 0) {
+ continue;
+ }
+ TileSwizzle::ExpandShapeDimVectorType group = expandShape[groupIdx];
+ TileSwizzle::ExpandShapeDimVectorType accGroup =
+ accExpandShape[accGroupIdx];
+ // The expanded shape of the `accSwizzle` group may not necessarily match
+ // the expanded shape of the `swizzle` group, so we may need to expand the
+ // dimensions of `group` further so that the is a corresponding dimension
+ // in `group` for every CrossThread dimension in `accGroup`.
+ //
+ // For example:
+ // Lhs swizzle.expandShape = [[8, {16}], [2, {4}, 4]],
+ // Acc swizzle.expandShape = [[8, {4}, 4], [{4}, 2, {16}]],
+ //
+ // For group 0 the Lhs swizzle.expandShape has shape [8, {16}], while the
+ // Acc swizzle.expandShape has shape [8, {4}, 4]. Since, the CrossThread dim
+ // of the Acc swizzle group does not appear in the Lhs swizzle group, we
+ // need to expand the Lhs swizzle group so we can permute the groups in the
+ // same way. In this case the Lhs group's {16} dim will be expanded, and the
+ // new Lhs swizzle group will be [8, {4}, {4}].
+ if (accGroup.size() != group.size()) {
+ SmallVector<int64_t> accGroupShape =
+ llvm::map_to_vector(accGroup, [](TileSwizzle::Dim d) {
+ return static_cast<int64_t>(d.size);
+ });
+ SmallVector<int64_t> groupShape =
+ llvm::map_to_vector(group, [](TileSwizzle::Dim d) {
+ return static_cast<int64_t>(d.size);
+ });
+ std::optional<SmallVector<ReassociationIndices>> groupReassociation =
+ getReassociationIndicesForCollapse(accGroupShape, groupShape);
+ // For the current MFMA layouts, there should always be a reassociation
+ // found, since the ACC layout is always an expanded form of the combined
+ // LHS and RHS layouts.
+ assert(groupReassociation.has_value() &&
+ "expected to find reassociation");
+ TileSwizzle::ExpandShapeDimVectorType expandedGroup;
+ for (auto [i, reassociation] : llvm::enumerate(*groupReassociation)) {
+ for (int64_t reInd : reassociation) {
+ expandedGroup.push_back(
+ TileSwizzle::Dim(group[i].kind, accGroup[reInd].size));
+ }
+ int expandedPermIdx;
+ for (auto [permIdx, permDim] : llvm::enumerate(swizzle.permutation)) {
+ if (permDim > i) {
+ permDim += reassociation.size() - 1;
+ }
+ if (permDim == i) {
+ expandedPermIdx = permIdx;
+ }
+ }
+ for (int j = 0, e = reassociation.size() - 1; j < e; ++j) {
+ swizzle.permutation.insert(
+ swizzle.permutation.begin() + expandedPermIdx + j + 1, i + j + 1);
+ }
+ }
+ swizzle.expandShape[groupIdx] = expandedGroup;
+ }
+
+ // At this point, the `group` and `accGroup` will have the same shape, so
+ // we can compute a permutation for `accGroup` that would move the Acc
+ // CrossThread dims outermost, and then use that exact permutation for the
+ // `group`. Compute the localized permutation within the acc reassociation
+ // group, and apply to the expandShape dims within the `group`.
+ SmallVector<int64_t> crossThreadInds;
+ SmallVector<int64_t> otherInds;
+ for (int64_t idx = 0; idx < accGroup.size(); ++idx) {
+ TileSwizzle::Dim dim = accGroup[idx];
+ if (dim.kind == TileSwizzle::Dim::Kind::CrossThread) {
+ crossThreadInds.push_back(idx);
+ } else {
+ otherInds.push_back(idx);
+ }
+ }
+ SmallVector<int64_t> groupPerm(crossThreadInds);
+ groupPerm.append(otherInds);
+ applyPermutationToVector(swizzle.expandShape[groupIdx], groupPerm);
+
+ // Append the group permutation to the global `crossThreadToOuterPerm`.
+ // `groupPerm` contains the local permutation within the expand shape
+ // reassociation group, so we need to convert to the global permutation
+ // indices when adding to the global crossThreadToOuterPerm.
+ for (int64_t idx : groupPerm) {
+ crossThreadToOuterPerm.push_back(idx + groupStartIdx);
+ }
+ groupStartIdx += expandShape[groupIdx].size();
+ }
+
+ // The matching groups bewteen `accSwizzle` and `swizzle` have now been
+ // permuted. For Lhs and Rhs fragments, we need to fill in the rest of the
+ // permutation from the skipped groups that don't appear in the `accSwizzle`.
+ if (fragment != MMAFragment::Acc) {
+ for (int64_t i = swizzle.expandShape.front().size();
+ i < swizzle.permutation.size(); ++i) {
+ crossThreadToOuterPerm.push_back(i);
+ }
+ }
+
+ // At this point, the expandShape dims have been permuted within their groups,
+ // but we still need to adjust the swizzle.permutation to preserve the result
+ // shape of the swizzle. We have the following permutations:
+ // - perm(originalSrc -> crossThreadOuterSrc)
+ // - perm(originalSrc -> result)
+ // And we want `perm(crossThreadOuterSrc -> result)`, so we need to take
+ // `inverse(perm(originalSrc -> crossThreadOuterSrc))`, and then apply
+ // `perm(originalSrc -> result)`.
+ SmallVector<int64_t> perm = invertPermutationVector(crossThreadToOuterPerm);
+ applyPermutationToVector(perm, swizzle.permutation);
+ swizzle.permutation = perm;
+ return swizzle;
+}
+
+/// Return the full swizzle without any reordering of CrossThread dims. The
+/// result of this function should be passed to moveCrossThreadOutermost to
+/// get the final swizzle.
+static TileSwizzle
+getSwizzleBeforeMovingCrossThreadOutermost(IREE::GPU::DataTiledMMAAttr mma,
+ IREE::GPU::MMAFragment fragment) {
+ auto swizzle = getIntrinsicSwizzleBeforeMovingCrossThreadOutermost(
+ mma.getIntrinsic().getValue(), fragment);
switch (fragment) {
case IREE::GPU::MMAFragment::Lhs:
// A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
@@ -197,4 +374,50 @@
return swizzle;
}
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
+ IREE::GPU::MMAFragment fragment) {
+ TileSwizzle swizzle =
+ getSwizzleBeforeMovingCrossThreadOutermost(mma, fragment);
+ // We want to move the CrossThread dims to be outermost in the source layout
+ // for the result. We need the transformations for the Lhs and Rhs to match
+ // with the Acc transformation, so we need to know what the acc swizzle is
+ // when moving CrossThread dims, even when the fragment is Lhs or Rhs.
+ TileSwizzle accSwizzle = swizzle;
+ if (fragment != IREE::GPU::MMAFragment::Acc) {
+ accSwizzle = getSwizzleBeforeMovingCrossThreadOutermost(
+ mma, IREE::GPU::MMAFragment::Acc);
+ }
+ LLVM_DEBUG(llvm::dbgs() << fragment
+ << " swizzle before moving CrossThread dims: "
+ << swizzle << "\n");
+ TileSwizzle crossThreadOuterSwizzle =
+ moveCrossThreadOutermost(swizzle, accSwizzle, fragment);
+ LLVM_DEBUG(llvm::dbgs() << fragment
+ << " swizzle after moving CrossThread dims: "
+ << crossThreadOuterSwizzle << "\n\n");
+ return crossThreadOuterSwizzle;
+}
+
+TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
+ IREE::GPU::MMAFragment fragment) {
+ auto swizzle =
+ getIntrinsicSwizzleBeforeMovingCrossThreadOutermost(intrinsic, fragment);
+ TileSwizzle accSwizzle = swizzle;
+ if (fragment != IREE::GPU::MMAFragment::Acc) {
+ accSwizzle = getIntrinsicSwizzleBeforeMovingCrossThreadOutermost(
+ intrinsic, IREE::GPU::MMAFragment::Acc);
+ }
+ LLVM_DEBUG(
+ llvm::dbgs() << fragment
+ << " intrinsic swizzle before moving CrossThread dims: "
+ << swizzle << "\n");
+ TileSwizzle crossThreadOuterSwizzle =
+ moveCrossThreadOutermost(swizzle, accSwizzle, fragment);
+ LLVM_DEBUG(
+ llvm::dbgs() << fragment
+ << " intrinsic swizzle after moving CrossThread dims: "
+ << crossThreadOuterSwizzle << "\n\n");
+ return crossThreadOuterSwizzle;
+}
+
} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
index c1b9fcc..a217ef1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
@@ -13,13 +13,23 @@
namespace mlir::iree_compiler::IREE::GPU {
-// Returns the TileSwizzle bringing a tile from row-major layout into the tiled
-// layout consumed by the given `intrinsic` and `fragment`.
+/// Returns the TileSwizzle bringing a tile from row-major layout into the tiled
+/// layout consumed by the given `intrinsic` and `fragment`.
+///
+/// **Important Note** This function builds an intrinsic swizzle, and then calls
+/// `moveCrossThreadOutermost` (see static funcion in GPUTileSwizzleUtils.cpp),
+/// which does the necessary expansion to make dimensionality consistent with
+/// the swizzles generated by `getSwizzle`. The order of the swizzle.expandShape
+/// Dims generated by `getSwizzle` and `getIntrinsicSwizzle` may be different,
+/// but the corresponding permutations are adjusted such that the order of
+/// dimensions are the same after the permutation is applied. When using this
+/// function, do not expect the ordering of dimensions before applying the
+/// swizzle.permutationto be consistent with swizzles from `getSwizzle`.
Codegen::TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
IREE::GPU::MMAFragment fragment);
-// Returns the swizzle for the full data-tiled-mma tile, including all the
-// relevant unrolling and expansion factors.
+/// Returns the swizzle for the full data-tiled-mma tile, including all the
+/// relevant unrolling and expansion factors.
Codegen::TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
IREE::GPU::MMAFragment fragment);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 74f3853..5c0f6a3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -635,6 +635,7 @@
// distribution-only thread dimensions, we need to get back to the intrinsic.
TileSwizzle intrinsicSwizzle =
getIntrinsicSwizzle(getIntrinsic().getValue(), fragment);
+
SmallVector<int64_t> intrinsicLayoutThreadSizes =
sliceSwizzledShape(intrinsicSwizzle, [](TileSwizzle::Dim d) {
return d.kind == TileSwizzle::Dim::Kind::CrossThread;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
index f4ba3ae..aa818d6 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
@@ -377,13 +377,13 @@
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
-func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<1x1x4x16xf32>, %rhs: tensor<1x1x4x16xf32>, %acc: tensor<1x1x4x16x4xf32>) -> tensor<1x1x4x16x4xf32>
+func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<1x1x4x4x4xf32>, %rhs: tensor<1x1x4x16xf32>, %acc: tensor<1x1x4x16x4xf32>) -> tensor<1x1x4x16x4xf32>
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>} {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
- } : tensor<1x1x4x16xf32>, tensor<1x1x4x16xf32> into tensor<1x1x4x16x4xf32>
+ } : tensor<1x1x4x4x4xf32>, tensor<1x1x4x16xf32> into tensor<1x1x4x16x4xf32>
return %0 : tensor<1x1x4x16x4xf32>
}
@@ -392,14 +392,15 @@
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
// CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x4x16x4xf32>)
+// CHECK-DAG: %[[LHS_IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (4, 4, 4)
+// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[LHS_IN_IDS]]#1, %[[LHS_IN_IDS]]#2, %[[LHS_IN_IDS]]#3] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1]
// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16)
-// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1]
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1]
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
-// CHECK-SAME: : tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32> into tensor<1x1x1x1x4xf32>
+// CHECK-SAME: : tensor<1x1x1x1x1xf32>, tensor<1x1x1x1xf32> into tensor<1x1x1x1x4xf32>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1]
// CHECK: mapping = [#gpu.thread<linear_dim_0>]
@@ -411,13 +412,13 @@
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
-func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled(%lhs: tensor<1x1x2x4x16x4xf32>, %rhs: tensor<1x1x2x4x16x4xf32>, %acc: tensor<1x1x2x2x4x16x4xf32>) -> tensor<1x1x2x2x4x16x4xf32>
+func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled(%lhs: tensor<1x1x2x4x4x4x4xf32>, %rhs: tensor<1x1x2x4x16x4xf32>, %acc: tensor<1x1x2x2x4x16x4xf32>) -> tensor<1x1x2x2x4x16x4xf32>
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>} {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, intrinsics_m = 2, intrinsics_n = 2, intrinsics_k = 4>
- } : tensor<1x1x2x4x16x4xf32>, tensor<1x1x2x4x16x4xf32> into tensor<1x1x2x2x4x16x4xf32>
+ } : tensor<1x1x2x4x4x4x4xf32>, tensor<1x1x2x4x16x4xf32> into tensor<1x1x2x2x4x16x4xf32>
return %0 : tensor<1x1x2x2x4x16x4xf32>
}
@@ -426,16 +427,17 @@
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
// CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>)
-// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16)
+// CHECK-DAG: %[[LHS_IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (4, 4, 4)
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
-// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: [0, 0, 0, %[[LHS_IN_IDS]]#1, %[[LHS_IN_IDS]]#2, %[[LHS_IN_IDS]]#3, 0] [1, 1, 2, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
+// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16)
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, intrinsics_m = 2, intrinsics_n = 2, intrinsics_k = 4>
-// CHECK-SAME: : tensor<1x1x2x1x1x4xf32>, tensor<1x1x2x1x1x4xf32> into tensor<1x1x2x2x1x1x4xf32>
+// CHECK-SAME: : tensor<1x1x2x1x1x1x4xf32>, tensor<1x1x2x1x1x4xf32> into tensor<1x1x2x2x1x1x4xf32>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]]
// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
// CHECK: mapping = [#gpu.thread<linear_dim_0>]
@@ -447,13 +449,13 @@
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
-func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor<1x1x2x4x16x4xf32>, %rhs: tensor<1x1x2x4x16x4xf32>, %acc: tensor<1x1x2x2x4x16x4xf32>) -> tensor<1x1x2x2x4x16x4xf32>
+func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor<1x1x2x4x4x4x4xf32>, %rhs: tensor<1x1x2x4x16x4xf32>, %acc: tensor<1x1x2x2x4x16x4xf32>) -> tensor<1x1x2x2x4x16x4xf32>
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64>} {
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, subgroups_m = 2, subgroups_n = 2, intrinsics_k = 4>
- } : tensor<1x1x2x4x16x4xf32>, tensor<1x1x2x4x16x4xf32> into tensor<1x1x2x2x4x16x4xf32>
+ } : tensor<1x1x2x4x4x4x4xf32>, tensor<1x1x2x4x16x4xf32> into tensor<1x1x2x2x4x16x4xf32>
return %0 : tensor<1x1x2x2x4x16x4xf32>
}
@@ -462,9 +464,10 @@
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
// CHECK: scf.forall (%[[THREAD_ID:.+]]) in (256) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>)
-// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (2, 4, 16)
+// CHECK-DAG: %[[LHS_IN_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (2, 4, 4, 4)
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
-// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: [0, 0, %[[LHS_IN_IDS]]#1, %[[LHS_IN_IDS]]#2, %[[LHS_IN_IDS]]#3, %[[LHS_IN_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
+// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (2, 4, 16)
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: %[[ACC_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (2, 2, 4, 16)
@@ -472,7 +475,7 @@
// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, subgroups_m = 2, subgroups_n = 2, intrinsics_k = 4>}
-// CHECK-SAME: : tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x1x1x4xf32> into tensor<1x1x1x1x1x1x4xf32>
+// CHECK-SAME: : tensor<1x1x1x1x1x1x4xf32>, tensor<1x1x1x1x1x4xf32> into tensor<1x1x1x1x1x1x4xf32>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
// CHECK: mapping = [#gpu.thread<linear_dim_0>]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 75dc8b2..a89cb39 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -711,10 +711,10 @@
func.func @multi_mma_data_tiled_unrolled_MFMA_F32_16x16x4_F32()
attributes {translation_info = #translation_info} {
%c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x1x8x4x16x4xf32>>
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x1x8x4x4x4x4xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x1x4x2x4x16x4xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>>
- %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0, 0], sizes = [4, 1, 8, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x1x8x4x16x4xf32>> -> tensor<4x1x8x4x16x4xf32>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0, 0, 0], sizes = [4, 1, 8, 4, 4, 4, 4], strides = [1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x1x8x4x4x4x4xf32>> -> tensor<4x1x8x4x4x4x4xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0, 0, 0, 0], sizes = [4, 1, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x1x4x2x4x16x4xf32>> -> tensor<4x1x4x2x4x16x4xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [4, 4, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>> -> tensor<4x4x8x4x2x4x16x4xf32>
%6 = iree_gpu.multi_mma %3, %4, %5 {
@@ -733,7 +733,7 @@
intrinsics_n = 2,
subgroups_n = 4,
intrinsics_k = 4>}
- : tensor<4x1x8x4x16x4xf32>, tensor<4x1x4x2x4x16x4xf32> into tensor<4x4x8x4x2x4x16x4xf32>
+ : tensor<4x1x8x4x4x4x4xf32>, tensor<4x1x4x2x4x16x4xf32> into tensor<4x4x8x4x2x4x16x4xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [4, 4, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<4x4x8x4x2x4x16x4xf32> -> !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>>
return
}
@@ -745,7 +745,7 @@
// CHECK-DAG: %[[BINDING_A:.+]] = hal.interface.binding.subspan {{.*}} binding(0)
// CHECK-DAG: %[[BINDING_B:.+]] = hal.interface.binding.subspan {{.*}} binding(1)
// CHECK-DAG: %[[BINDING_C:.+]] = hal.interface.binding.subspan {{.*}} binding(2)
-// CHECK-DAG: %[[A_ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x16x4xf32, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[A_ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x4x4xf32, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[B_ALLOC:.+]] = memref.alloc() : memref<1x1x4x2x4x16x4xf32, #gpu.address_space<workgroup>>
// CHECK: gpu.barrier
// CHECK-DAG: %[[A_GLOBAL_LOAD:.+]] = vector.transfer_read %[[BINDING_A]]{{.*}} vector<4xf32>
@@ -753,21 +753,21 @@
// CHECK-DAG: vector.transfer_write %[[A_GLOBAL_LOAD]], %[[A_ALLOC]]
// CHECK-DAG: vector.transfer_write %[[B_GLOBAL_LOAD]], %[[B_ALLOC]]
// CHECK: gpu.barrier
-// CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read %[[A_ALLOC]]{{.*}} vector<8x1x1x4xf32>
+// CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read %[[A_ALLOC]]{{.*}} vector<8x1x1x1x4xf32>
// CHECK-DAG: %[[B_READ:.+]] = vector.transfer_read %[[B_ALLOC]]{{.*}} vector<2x1x1x4xf32>
// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read %[[BINDING_C]]{{.*}} vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_00_0:.+]] = vector.extract %[[C_READ]][0, 0, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_01_0:.+]] = vector.extract %[[C_READ]][0, 1, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_70_0:.+]] = vector.extract %[[C_READ]][7, 0, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
// CHECK-DAG: %[[C_71_0:.+]] = vector.extract %[[C_READ]][7, 1, 0, 0] : vector<4xf32> from vector<8x2x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT00:.+]] = vector.extract %[[A_READ]][0, 0, 0, 0] : f32 from vector<8x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT01:.+]] = vector.extract %[[A_READ]][0, 0, 0, 1] : f32 from vector<8x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT02:.+]] = vector.extract %[[A_READ]][0, 0, 0, 2] : f32 from vector<8x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT03:.+]] = vector.extract %[[A_READ]][0, 0, 0, 3] : f32 from vector<8x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT70:.+]] = vector.extract %[[A_READ]][7, 0, 0, 0] : f32 from vector<8x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT71:.+]] = vector.extract %[[A_READ]][7, 0, 0, 1] : f32 from vector<8x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT72:.+]] = vector.extract %[[A_READ]][7, 0, 0, 2] : f32 from vector<8x1x1x4xf32>
-// CHECK-DAG: %[[A_EXTRACT73:.+]] = vector.extract %[[A_READ]][7, 0, 0, 3] : f32 from vector<8x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT00:.+]] = vector.extract %[[A_READ]][0, 0, 0, 0, 0] : f32 from vector<8x1x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT01:.+]] = vector.extract %[[A_READ]][0, 0, 0, 0, 1] : f32 from vector<8x1x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT02:.+]] = vector.extract %[[A_READ]][0, 0, 0, 0, 2] : f32 from vector<8x1x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT03:.+]] = vector.extract %[[A_READ]][0, 0, 0, 0, 3] : f32 from vector<8x1x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT70:.+]] = vector.extract %[[A_READ]][7, 0, 0, 0, 0] : f32 from vector<8x1x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT71:.+]] = vector.extract %[[A_READ]][7, 0, 0, 0, 1] : f32 from vector<8x1x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT72:.+]] = vector.extract %[[A_READ]][7, 0, 0, 0, 2] : f32 from vector<8x1x1x1x4xf32>
+// CHECK-DAG: %[[A_EXTRACT73:.+]] = vector.extract %[[A_READ]][7, 0, 0, 0, 3] : f32 from vector<8x1x1x1x4xf32>
// CHECK-DAG: %[[B_EXTRACT00:.+]] = vector.extract %[[B_READ]][0, 0, 0, 0] : f32 from vector<2x1x1x4xf32>
// CHECK-DAG: %[[B_EXTRACT01:.+]] = vector.extract %[[B_READ]][0, 0, 0, 1] : f32 from vector<2x1x1x4xf32>
// CHECK-DAG: %[[B_EXTRACT02:.+]] = vector.extract %[[B_READ]][0, 0, 0, 2] : f32 from vector<2x1x1x4xf32>