Simplify GPUTileSwizzleUtils and avoid creating unit dims. (#19105)
In `getIntrinsicSwizzle`, we had a slightly roundabout way of
constructing the swizzle from the `SingleSubgroupLayout`. We started
from the `thread` dims, which we used unconditionally even if they had
the value 1, leading to unit dims; and then we inserted the `element`
dims *on the inside*, which required custom manipulation of the
`swizzle` field. Now we just start from the `element` dims and work our
way outwards from there, which means we can reuse the same helper that
used to be named `unroll` and that we rename here to `expand` in
preparation for https://github.com/iree-org/iree/pull/19102, and which
we also move to be a `static` helper since it's no longer used outside
of this file.
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
index f6e9445..2a6b9c6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir
@@ -50,8 +50,8 @@
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
-// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x8x2x16xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
index 738bb6a..82eff59 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
@@ -49,6 +49,10 @@
// The size of the dimension.
int16_t size = 0;
+
+ // Support constructing from any size type.
+ template <typename T>
+ Dim(Kind kind, T size) : kind(kind), size(size) {}
};
using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
index ae9d5d9..1171b1e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp
@@ -3,85 +3,77 @@
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
-
namespace mlir::iree_compiler {
-// Given an `expandShape` vector-of-vectors describing the mapping from source
-// dimensions to expanded dimensions, returns the index of the first expanded
-// dimension corresponding to the given source dimension index.
-static int64_t
-getExpandedDimFirstIdx(const TileSwizzle::ExpandShapeType &expandShape,
- int64_t srcIndex) {
- int dstIndexFirst = 0;
- for (int i = 0; i < srcIndex; ++i) {
- dstIndexFirst += expandShape[i].size();
+using Kind = TileSwizzle::Dim::Kind;
+
+// Returns the index of the first destination dimension corresponding to the
+// given source dimension `srcIdx`.
+static int64_t expandedDimIdx(const TileSwizzle::ExpandShapeType &expandShape,
+ int srcIdx) {
+ int dstIdx = 0;
+ for (int i = 0; i < srcIdx; ++i) {
+ dstIdx += expandShape[i].size();
}
- return dstIndexFirst;
+ return dstIdx;
}
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
- TileSwizzle::Dim::Kind kind) {
- assert(unrollFactor > 1);
- int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
- TileSwizzle::Dim unrollDim;
- unrollDim.size = unrollFactor;
- unrollDim.kind = kind;
+// Pushes `dim` to the front of `swizzle.expandShape[srcIdx]`, and updates
+// `swizzle.permutation` to make the new dimension outer-most among the dims in
+// `swizzle.expandShape[srcIdx]`.
+//
+// This can be used to unroll a kernel with kind = CrossIntrinsic,
+// or to expand a kernel to multiple subgroups with kind = CrossThread.
+//
+// Example:
+// Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
+// Input srcIdx = 1
+// Input dim.size = 4
+// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
+//
+static void expand(TileSwizzle &swizzle, int srcIdx, TileSwizzle::Dim dim) {
+ int dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx);
// The new unrolling dimension is inserted at the start of the expandShape
- // dimensions group corresponding to srcIndex.
- swizzle.expandShape[srcIndex].insert(swizzle.expandShape[srcIndex].begin(),
- unrollDim);
+ // dimensions group corresponding to srcIdx.
+ swizzle.expandShape[srcIdx].insert(swizzle.expandShape[srcIdx].begin(), dim);
// Since we are not interleaving here, generating side-by-side copies of the
// original layout, the new unrolling dimension is the new outermost
// dimension. Existing entries get shifted to make room for it.
for (auto &p : swizzle.permutation) {
- p += (p >= dstIndexFirst);
+ p += (p >= dstIdx);
}
- swizzle.permutation.insert(swizzle.permutation.begin(), dstIndexFirst);
+ swizzle.permutation.insert(swizzle.permutation.begin(), dstIdx);
}
-void interleave(TileSwizzle &swizzle, int srcIndex,
- int expandedDimIndexToInterleaveAt) {
- // Compute which inner dimension to permute the current outer dimension into.
- int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
- int dstIndexToInterleaveAt = dstIndexFirst + expandedDimIndexToInterleaveAt;
-
+// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
+// move permutation[0], the outer-most dimension (which the unroll() function
+// created to be the unrolling dimension), to the inner dimension given by
+// `expandedIdx`.
+//
+// Example:
+// Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
+// Input srcIdx = 1
+// Input expandedIdx = 1
+// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
+//
+static void interleave(TileSwizzle &swizzle, int srcIdx, int expandedIdx) {
+ int dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx) + expandedIdx;
SmallVector<int64_t> outPermutation(swizzle.permutation.size());
// The leading dimension, permutation[0], gets moved inwards to the
- // position that we just computed, dstIndexToInterleaveAt.
- outPermutation[dstIndexToInterleaveAt] = swizzle.permutation[0];
+ // position that we just computed, dstIdx.
+ outPermutation[dstIdx] = swizzle.permutation[0];
// Outer dimensions get shifted outwards to fill the gap.
- for (int i = 0; i < dstIndexToInterleaveAt; ++i) {
+ for (int i = 0; i < dstIdx; ++i) {
outPermutation[i] = swizzle.permutation[i + 1];
}
- // Inner dimensions don't change. That is to say that we only interleave
- // at `targetInterleavedElements` granularity, we don't swizzle further
- // internally to that.
- for (int i = dstIndexToInterleaveAt + 1; i < outPermutation.size(); ++i) {
+ // Inner dimensions don't change.
+ for (int i = dstIdx + 1; i < outPermutation.size(); ++i) {
outPermutation[i] = swizzle.permutation[i];
}
swizzle.permutation = outPermutation;
}
-// Returns the permutation of indices that sorts `v` with the given comparator.
-template <template <typename U> class Comparator, typename T>
-static SmallVector<int64_t> getSortingPermutation(ArrayRef<T> v) {
- using P = std::pair<int64_t, T>;
- SmallVector<P> pairs;
- pairs.reserve(v.size());
- for (auto [i, x] : llvm::enumerate(v)) {
- pairs.push_back({i, x});
- }
- std::sort(pairs.begin(), pairs.end(),
- [](P p1, P p2) { return Comparator<T>{}(p1.second, p2.second); });
- SmallVector<int64_t> indices;
- for (auto p : pairs) {
- indices.push_back(p.first);
- }
- return indices;
-}
-
TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
IREE::GPU::MMAFragment fragment) {
auto layout = IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
@@ -95,57 +87,48 @@
std::swap(layout.element[0], layout.element[1]);
}
- // Initially populate swizzle.expandShape with just the thread sizes, no
- // shape expansion for now.
TileSwizzle swizzle;
- for (auto t : layout.thread) {
- TileSwizzle::Dim dim;
- dim.size = t;
- dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
- swizzle.expandShape.push_back({dim});
- }
- // The layout strides decide the initial swizzle.permutation.
- // Some WMMA intrinsics have tstrides=0 value. That always indicates an outer
- // dimension, so overwrite 0 with a large value to get the right order.
- SmallVector<int64_t, 2> order = layout.tstrides;
- for (auto &val : order) {
- val = (val == 0) ? INT64_MAX : val;
- }
- swizzle.permutation = getSortingPermutation<std::greater, int64_t>(order);
- // Deal with any element size greater than 1 by inserting it innermost.
- // Notice that this is similar to the unroll() function, just creating an
- // inner dimension instead of an outer dimension.
+ // There are two source dimensions, corresponding to the arrays in `layout`
+ // all having size 2. Let's just guard that assumption with one assert here.
+ assert(layout.thread.size() == 2);
+ swizzle.expandShape.resize(2);
+ // Expand the shape from inner-most to outer-most dimension, so that we can
+ // simply use the `expand` helper function, which creates new outer dims.
+ // `layout.element` dims are inner-most, so we add them first.
for (auto [i, e] : llvm::enumerate(layout.element)) {
if (e != 1) {
- TileSwizzle::Dim dim;
- dim.size = e;
- dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
- swizzle.expandShape[i].push_back(dim);
- int newIndex = getExpandedDimFirstIdx(swizzle.expandShape, i + 1) - 1;
- for (auto &p : swizzle.permutation) {
- p += (p >= newIndex);
- }
- swizzle.permutation.push_back(newIndex);
+ expand(swizzle, i, {Kind::Internal, e});
}
}
- // Deal with any outer size greater than 1 as just a call to unroll.
- // Iterate over dims in reverse order because we are creating a new outermost
- // dimension each time.
+ // Next come `layout.thread` dims.
+ for (auto [i, t] : llvm::enumerate(layout.thread)) {
+ if (t != 1) {
+ expand(swizzle, i, {Kind::CrossThread, t});
+ }
+ }
+ // `layout.thread` dims are special in that they come with `layout.tstrides`
+ // which may call for a swap in `swizzle.permutation`. We only need to worry
+ // about that when both `layout.thread` sizes are greater than 1, so we didn't
+ // skip them above. Note that this condition also implies that we don't need
+ // to worry about `layout.tstrides == 0` which only happens with
+ // `layout.thread == 1`.
+ if (layout.thread[0] != 1 && layout.thread[1] != 1 &&
+ layout.tstrides[0] > layout.tstrides[1]) {
+ std::swap(swizzle.permutation[0], swizzle.permutation[1]);
+ }
+ // Finally come `layout.outer` dims, added last so they are outer-most.
for (auto [i, o] : llvm::enumerate(layout.outer)) {
if (o != 1) {
- // `layout.outer` means additional Internal dimensions, just like
- // `layout.element`, just swizzled outermost.
- unroll(swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
+ expand(swizzle, i, {Kind::Internal, o});
}
}
-
return swizzle;
}
static int getInnermostNonInternalDimIdx(
const TileSwizzle::ExpandShapeDimVectorType &shape) {
for (int idx = shape.size() - 1; idx >= 0; --idx) {
- if (shape[idx].kind != TileSwizzle::Dim::Kind::Internal) {
+ if (shape[idx].kind != Kind::Internal) {
return idx;
}
}
@@ -156,22 +139,21 @@
TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
IREE::GPU::MMAFragment fragment) {
auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
- using Kind = TileSwizzle::Dim::Kind;
switch (fragment) {
case IREE::GPU::MMAFragment::Lhs:
// A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
// Unroll on K with interleaving, then on M.
if (mma.getUnrollK() > 1) {
- unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+ expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
int interleavingIdx =
getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
interleave(swizzle, 1, interleavingIdx);
}
if (mma.getUnrollM() > 1) {
- unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+ expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
}
if (mma.getUnrollMToSubgroups() > 1) {
- unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+ expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
}
break;
case IREE::GPU::MMAFragment::Rhs:
@@ -179,32 +161,32 @@
// source dimensions are N (index 0) and K (index 1).
// Unroll on K with interleaving, then on N.
if (mma.getUnrollK() > 1) {
- unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+ expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
int interleavingIdx =
getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
interleave(swizzle, 1, interleavingIdx);
}
if (mma.getUnrollN() > 1) {
- unroll(swizzle, 0, mma.getUnrollN(), Kind::CrossIntrinsic);
+ expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollN()});
}
if (mma.getUnrollNToSubgroups() > 1) {
- unroll(swizzle, 0, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+ expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
}
break;
case IREE::GPU::MMAFragment::Acc:
// C-matrix (accumulator). Source dimensions are M (index 0) and N (index
// 1). Unroll on N, then on M.
if (mma.getUnrollN() > 1) {
- unroll(swizzle, 1, mma.getUnrollN(), Kind::CrossIntrinsic);
+ expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollN()});
}
if (mma.getUnrollNToSubgroups() > 1) {
- unroll(swizzle, 1, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+ expand(swizzle, 1, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
}
if (mma.getUnrollM() > 1) {
- unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+ expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
}
if (mma.getUnrollMToSubgroups() > 1) {
- unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+ expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
}
break;
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
index 7cf2490..413f801 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h
@@ -19,40 +19,10 @@
IREE::GPU::MMAFragment fragment);
// Returns the swizzle for the full data-tiled-mma tile, including all the
-// relevant unrolling factors.
+// relevant unrolling and expansion factors.
TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
IREE::GPU::MMAFragment fragment);
-// Unrolls the dimension given by `srcIndex` by the given `unrollFactor`.
-// This is not interleaving layouts. The layout will consist of multiple copies
-// of the input tile, side by side.
-//
-// The enum parameter `kind` initializes the corresponding member on the newly
-// created TileSwizzle::Dim.
-//
-// Example:
-// Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
-// Input srcIndex = 1
-// Input unrollFactor = 4
-// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
-//
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
- TileSwizzle::Dim::Kind kind);
-
-// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
-// move permutation[0], the outer-most dimension (which the unroll() function
-// created to be the unrolling dimension), to the inner dimension given by
-// `expandedDimIndexToInterleaveAt`.
-//
-// Example:
-// Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
-// Input srcIndex = 1
-// Input expandedDimIndexToInterleaveAt = 1
-// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
-//
-void interleave(TileSwizzle &swizzle, int srcIndex,
- int expandedDimIndexToInterleaveAt);
-
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
index da7a3b3..9a2e0ad 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir
@@ -76,9 +76,9 @@
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
%lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
- // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, VECTORX], [1, 16]>>}}
%rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
- // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
+ // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, VECTORX], [1, 16]>>}}
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, VECTORX, LANEY], [1, 8, 2]>, <[ BATCHY, LANEX], [1, 16]>>}}
return %output : vector<16x16xf32>