Simplify GPUTileSwizzleUtils and avoid creating unit dims. (#19105)

In `getIntrinsicSwizzle`, we had a slightly roundabout way of
constructing the swizzle from the `SingleSubgroupLayout`. We started
from the `thread` dims, which we used unconditionally even if they had
the value 1, leading to unit dims; and then we inserted the `element`
dims *on the inside*, which required custom manipulation of the
`swizzle` field. Now we just start from the `element` dims and work our
way outwards from there, which means we can reuse the same helper that
used to be named `unroll` and that we rename here to `expand` in
preparation for https://github.com/iree-org/iree/pull/19102, and which
we also move to be a `static` helper since it's no longer used outside
of this file.

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
index f6e9445..2a6b9c6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
@@ -50,8 +50,8 @@
 // CHECK-DAG:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
 // CHECK-DAG:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG:   %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
-// CHECK-DAG:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
+// CHECK-DAG:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
+// CHECK-DAG:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
 // CHECK-DAG:   %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x8x2x16xf32>
 // CHECK:       %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
index 738bb6a..82eff59 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
@@ -49,6 +49,10 @@
 
     // The size of the dimension.
     int16_t size = 0;
+
+    // Support constructing from any size type.
+    template <typename T>
+    Dim(Kind kind, T size) : kind(kind), size(size) {}
   };
 
   using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
index ae9d5d9..1171b1e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
@@ -3,85 +3,77 @@
 // Licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
 #include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
-
 namespace mlir::iree_compiler {
 
-// Given an `expandShape` vector-of-vectors describing the mapping from source
-// dimensions to expanded dimensions, returns the index of the first expanded
-// dimension corresponding to the given source dimension index.
-static int64_t
-getExpandedDimFirstIdx(const TileSwizzle::ExpandShapeType &expandShape,
-                       int64_t srcIndex) {
-  int dstIndexFirst = 0;
-  for (int i = 0; i < srcIndex; ++i) {
-    dstIndexFirst += expandShape[i].size();
+using Kind = TileSwizzle::Dim::Kind;
+
+// Returns the index of the first destination dimension corresponding to the
+// given source dimension `srcIdx`.
+static int64_t expandedDimIdx(const TileSwizzle::ExpandShapeType &expandShape,
+                              int srcIdx) {
+  int dstIdx = 0;
+  for (int i = 0; i < srcIdx; ++i) {
+    dstIdx += expandShape[i].size();
   }
-  return dstIndexFirst;
+  return dstIdx;
 }
 
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
-            TileSwizzle::Dim::Kind kind) {
-  assert(unrollFactor > 1);
-  int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
-  TileSwizzle::Dim unrollDim;
-  unrollDim.size = unrollFactor;
-  unrollDim.kind = kind;
+// Pushes `dim` to the front of `swizzle.expandShape[srcIdx]`, and updates
+// `swizzle.permutation` to make the new dimension outer-most among the dims in
+// `swizzle.expandShape[srcIdx]`.
+//
+// This can be used to unroll a kernel with kind = CrossIntrinsic,
+// or to expand a kernel to multiple subgroups with kind = CrossThread.
+//
+// Example:
+//    Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
+//    Input srcIdx = 1
+//    Input dim.size = 4
+// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
+//
+static void expand(TileSwizzle &swizzle, int srcIdx, TileSwizzle::Dim dim) {
+  int dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx);
   // The new unrolling dimension is inserted at the start of the expandShape
-  // dimensions group corresponding to srcIndex.
-  swizzle.expandShape[srcIndex].insert(swizzle.expandShape[srcIndex].begin(),
-                                       unrollDim);
+  // dimensions group corresponding to srcIdx.
+  swizzle.expandShape[srcIdx].insert(swizzle.expandShape[srcIdx].begin(), dim);
   // Since we are not interleaving here, generating side-by-side copies of the
   // original layout, the new unrolling dimension is the new outermost
   // dimension. Existing entries get shifted to make room for it.
   for (auto &p : swizzle.permutation) {
-    p += (p >= dstIndexFirst);
+    p += (p >= dstIdx);
   }
-  swizzle.permutation.insert(swizzle.permutation.begin(), dstIndexFirst);
+  swizzle.permutation.insert(swizzle.permutation.begin(), dstIdx);
 }
 
-void interleave(TileSwizzle &swizzle, int srcIndex,
-                int expandedDimIndexToInterleaveAt) {
-  // Compute which inner dimension to permute the current outer dimension into.
-  int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
-  int dstIndexToInterleaveAt = dstIndexFirst + expandedDimIndexToInterleaveAt;
-
+// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
+// move permutation[0], the outer-most dimension (which the unroll() function
+// created to be the unrolling dimension), to the inner dimension given by
+// `expandedIdx`.
+//
+// Example:
+//    Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
+//    Input srcIdx = 1
+//    Input expandedIdx = 1
+// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
+//
+static void interleave(TileSwizzle &swizzle, int srcIdx, int expandedIdx) {
+  int dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx) + expandedIdx;
   SmallVector<int64_t> outPermutation(swizzle.permutation.size());
   // The leading dimension, permutation[0], gets moved inwards to the
-  // position that we just computed, dstIndexToInterleaveAt.
-  outPermutation[dstIndexToInterleaveAt] = swizzle.permutation[0];
+  // position that we just computed, dstIdx.
+  outPermutation[dstIdx] = swizzle.permutation[0];
   // Outer dimensions get shifted outwards to fill the gap.
-  for (int i = 0; i < dstIndexToInterleaveAt; ++i) {
+  for (int i = 0; i < dstIdx; ++i) {
     outPermutation[i] = swizzle.permutation[i + 1];
   }
-  // Inner dimensions don't change. That is to say that we only interleave
-  // at `targetInterleavedElements` granularity, we don't swizzle further
-  // internally to that.
-  for (int i = dstIndexToInterleaveAt + 1; i < outPermutation.size(); ++i) {
+  // Inner dimensions don't change.
+  for (int i = dstIdx + 1; i < outPermutation.size(); ++i) {
     outPermutation[i] = swizzle.permutation[i];
   }
   swizzle.permutation = outPermutation;
 }
 
-// Returns the permutation of indices that sorts `v` with the given comparator.
-template <template <typename U> class Comparator, typename T>
-static SmallVector<int64_t> getSortingPermutation(ArrayRef<T> v) {
-  using P = std::pair<int64_t, T>;
-  SmallVector<P> pairs;
-  pairs.reserve(v.size());
-  for (auto [i, x] : llvm::enumerate(v)) {
-    pairs.push_back({i, x});
-  }
-  std::sort(pairs.begin(), pairs.end(),
-            [](P p1, P p2) { return Comparator<T>{}(p1.second, p2.second); });
-  SmallVector<int64_t> indices;
-  for (auto p : pairs) {
-    indices.push_back(p.first);
-  }
-  return indices;
-}
-
 TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
                                 IREE::GPU::MMAFragment fragment) {
   auto layout = IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
@@ -95,57 +87,48 @@
     std::swap(layout.element[0], layout.element[1]);
   }
 
-  // Initially populate swizzle.expandShape with just the thread sizes, no
-  // shape expansion for now.
   TileSwizzle swizzle;
-  for (auto t : layout.thread) {
-    TileSwizzle::Dim dim;
-    dim.size = t;
-    dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
-    swizzle.expandShape.push_back({dim});
-  }
-  // The layout strides decide the initial swizzle.permutation.
-  // Some WMMA intrinsics have tstrides=0 value. That always indicates an outer
-  // dimension, so overwrite 0 with a large value to get the right order.
-  SmallVector<int64_t, 2> order = layout.tstrides;
-  for (auto &val : order) {
-    val = (val == 0) ? INT64_MAX : val;
-  }
-  swizzle.permutation = getSortingPermutation<std::greater, int64_t>(order);
-  // Deal with any element size greater than 1 by inserting it innermost.
-  // Notice that this is similar to the unroll() function, just creating an
-  // inner dimension instead of an outer dimension.
+  // There are two source dimensions, corresponding to the arrays in `layout`
+  // all having size 2. Let's just guard that assumption with one assert here.
+  assert(layout.thread.size() == 2);
+  swizzle.expandShape.resize(2);
+  // Expand the shape from inner-most to outer-most dimension, so that we can
+  // simply use the `expand` helper function, which creates new outer dims.
+  // `layout.element` dims are inner-most, so we add them first.
   for (auto [i, e] : llvm::enumerate(layout.element)) {
     if (e != 1) {
-      TileSwizzle::Dim dim;
-      dim.size = e;
-      dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
-      swizzle.expandShape[i].push_back(dim);
-      int newIndex = getExpandedDimFirstIdx(swizzle.expandShape, i + 1) - 1;
-      for (auto &p : swizzle.permutation) {
-        p += (p >= newIndex);
-      }
-      swizzle.permutation.push_back(newIndex);
+      expand(swizzle, i, {Kind::Internal, e});
     }
   }
-  // Deal with any outer size greater than 1 as just a call to unroll.
-  // Iterate over dims in reverse order because we are creating a new outermost
-  // dimension each time.
+  // Next come `layout.thread` dims.
+  for (auto [i, t] : llvm::enumerate(layout.thread)) {
+    if (t != 1) {
+      expand(swizzle, i, {Kind::CrossThread, t});
+    }
+  }
+  // `layout.thread` dims are special in that they come with `layout.tstrides`
+  // which may call for a swap in `swizzle.permutation`. We only need to worry
+  // about that when both `layout.thread` sizes are greater than 1, so we didn't
+  // skip them above. Note that this condition also implies that we don't need
+  // to worry about `layout.tstrides == 0` which only happens with
+  // `layout.thread == 1`.
+  if (layout.thread[0] != 1 && layout.thread[1] != 1 &&
+      layout.tstrides[0] > layout.tstrides[1]) {
+    std::swap(swizzle.permutation[0], swizzle.permutation[1]);
+  }
+  // Finally come `layout.outer` dims, added last so they are outer-most.
   for (auto [i, o] : llvm::enumerate(layout.outer)) {
     if (o != 1) {
-      // `layout.outer` means additional Internal dimensions, just like
-      // `layout.element`, just swizzled outermost.
-      unroll(swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
+      expand(swizzle, i, {Kind::Internal, o});
     }
   }
-
   return swizzle;
 }
 
 static int getInnermostNonInternalDimIdx(
     const TileSwizzle::ExpandShapeDimVectorType &shape) {
   for (int idx = shape.size() - 1; idx >= 0; --idx) {
-    if (shape[idx].kind != TileSwizzle::Dim::Kind::Internal) {
+    if (shape[idx].kind != Kind::Internal) {
       return idx;
     }
   }
@@ -156,22 +139,21 @@
 TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
                        IREE::GPU::MMAFragment fragment) {
   auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
-  using Kind = TileSwizzle::Dim::Kind;
   switch (fragment) {
   case IREE::GPU::MMAFragment::Lhs:
     // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
     // Unroll on K with interleaving, then on M.
     if (mma.getUnrollK() > 1) {
-      unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+      expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
       int interleavingIdx =
           getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
       interleave(swizzle, 1, interleavingIdx);
     }
     if (mma.getUnrollM() > 1) {
-      unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+      expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
     }
     if (mma.getUnrollMToSubgroups() > 1) {
-      unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+      expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
     }
     break;
   case IREE::GPU::MMAFragment::Rhs:
@@ -179,32 +161,32 @@
     // source dimensions are N (index 0) and K (index 1).
     // Unroll on K with interleaving, then on N.
     if (mma.getUnrollK() > 1) {
-      unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+      expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
       int interleavingIdx =
           getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
       interleave(swizzle, 1, interleavingIdx);
     }
     if (mma.getUnrollN() > 1) {
-      unroll(swizzle, 0, mma.getUnrollN(), Kind::CrossIntrinsic);
+      expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollN()});
     }
     if (mma.getUnrollNToSubgroups() > 1) {
-      unroll(swizzle, 0, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+      expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
     }
     break;
   case IREE::GPU::MMAFragment::Acc:
     // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
     // 1). Unroll on N, then on M.
     if (mma.getUnrollN() > 1) {
-      unroll(swizzle, 1, mma.getUnrollN(), Kind::CrossIntrinsic);
+      expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollN()});
     }
     if (mma.getUnrollNToSubgroups() > 1) {
-      unroll(swizzle, 1, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+      expand(swizzle, 1, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
     }
     if (mma.getUnrollM() > 1) {
-      unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+      expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
     }
     if (mma.getUnrollMToSubgroups() > 1) {
-      unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+      expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
     }
     break;
   }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
index 7cf2490..413f801 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
@@ -19,40 +19,10 @@
                                 IREE::GPU::MMAFragment fragment);
 
 // Returns the swizzle for the full data-tiled-mma tile, including all the
-// relevant unrolling factors.
+// relevant unrolling and expansion factors.
 TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
                        IREE::GPU::MMAFragment fragment);
 
-// Unrolls the dimension given by `srcIndex` by the given `unrollFactor`.
-// This is not interleaving layouts. The layout will consist of multiple copies
-// of the input tile, side by side.
-//
-// The enum parameter `kind` initializes the corresponding member on the newly
-// created TileSwizzle::Dim.
-//
-// Example:
-//    Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
-//    Input srcIndex = 1
-//    Input unrollFactor = 4
-// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
-//
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
-            TileSwizzle::Dim::Kind kind);
-
-// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
-// move permutation[0], the outer-most dimension (which the unroll() function
-// created to be the unrolling dimension), to the inner dimension given by
-// `expandedDimIndexToInterleaveAt`.
-//
-// Example:
-//    Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
-//    Input srcIndex = 1
-//    Input expandedDimIndexToInterleaveAt = 1
-// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
-//
-void interleave(TileSwizzle &swizzle, int srcIndex,
-                int expandedDimIndexToInterleaveAt);
-
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
index da7a3b3..9a2e0ad 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
@@ -76,9 +76,9 @@
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0.0 : f16
     %lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  LANEY,  VECTORX], [1, 1, 16]>>}}
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  VECTORX], [1, 16]>>}}
     %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  LANEY,  VECTORX], [1, 1, 16]>>}}
+    // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  LANEX], [1, 16]>, <[ BATCHY,  VECTORX], [1, 16]>>}}
     %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
     // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX,  VECTORX,  LANEY], [1, 8, 2]>, <[ BATCHY,  LANEX], [1, 16]>>}}
     return %output : vector<16x16xf32>