GPU data tiling: Refine tile dimensions, more preparation for thread distribution. (#18556)
* The unrolling factors in `DataTiledMMAAttr` get split between plain
unrolling and unroll-to-subgroups (see the first sketch below).
* The dimensions in `TileSwizzle` get an enum telling whether they are
internal, cross-thread, or cross-intrinsic (second sketch below).
* `getSwizzle` is moved to GPUTileSwizzleUtils, as it is going to be
used in codegen outside of MaterializeEncoding (third sketch below).
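
As an illustration of the first point, here is a minimal C++ sketch (not part
of this change) of building the attribute with the split unrolling factors. It
mirrors the MFMA_F32_16x16x4_F32 tile choice made in
GPUMaterializeEncoding.cpp and assumes an `MLIRContext` with the IREE GPU
dialect loaded.

```cpp
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::GPU;

// Builds a data-tiled MMA layout where the N unrolling is split between
// same-thread unrolling (unroll_n = 2) and distribution across subgroups
// (unroll_n_to_subgroups = 4).
static DataTiledMMAAttr makeExampleMma(MLIRContext *ctx) {
  return DataTiledMMAAttr::get(
      ctx, MMAIntrinsicAttr::get(ctx, MMAIntrinsic::MFMA_F32_16x16x4_F32),
      /*unroll_m=*/8, /*unroll_m_to_subgroups=*/1,
      /*unroll_n=*/2, /*unroll_n_to_subgroups=*/4, /*unroll_k=*/4);
}
```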
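Second sketch (sizes made up for the example): what the new per-dimension
`TileSwizzle::Dim::Kind` tag looks like, printed with the `operator<<`
overloads added in TileSwizzle.cpp.

```cpp
#include "iree/compiler/Codegen/Common/TileSwizzle.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir::iree_compiler;

static void printExampleSwizzle() {
  // One dimension distributed across the threads of a subgroup...
  TileSwizzle::Dim lanes;
  lanes.size = 16;
  lanes.kind = TileSwizzle::Dim::Kind::CrossThread;
  // ...and one dimension internal to a single intrinsic on one thread.
  TileSwizzle::Dim vec;
  vec.size = 4;
  vec.kind = TileSwizzle::Dim::Kind::Internal;
  TileSwizzle swizzle;
  swizzle.expandShape.push_back({lanes});
  swizzle.expandShape.push_back({vec});
  swizzle.permutation = {1, 0};
  // Prints: {expandShape = [[16(CrossThread)], [4(Internal)]], swizzle = [1, 0]}
  llvm::errs() << swizzle << "\n";
}
```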
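Third sketch: the call pattern enabled by the move, so that codegen outside of
MaterializeEncoding can query the full data-tiled tile swizzle. `mma` is
assumed to be an attribute like the one built above.

```cpp
#include "iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h"

using namespace mlir::iree_compiler;

static TileSwizzle getAccSwizzle(IREE::GPU::DataTiledMMAAttr mma) {
  // The result includes the unrolling factors relevant to this fragment, with
  // each added dimension tagged CrossIntrinsic (same-thread unrolling) or
  // CrossThread (unroll-to-subgroups).
  return getSwizzle(mma, IREE::GPU::MMAFragment::Acc);
}
```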
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index aaa4ad3..0cc1f28 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -143,6 +143,7 @@
"TileDispatchUsingForall.cpp",
"TileDispatchUsingInterface.cpp",
"TileSizeSelection.cpp",
+ "TileSwizzle.cpp",
"TypePropagationPass.cpp",
"UserConfig.cpp",
"VectorizeMemrefCopy.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index d828e08..3a94dcd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -135,6 +135,7 @@
"TileDispatchUsingForall.cpp"
"TileDispatchUsingInterface.cpp"
"TileSizeSelection.cpp"
+ "TileSwizzle.cpp"
"TypePropagationPass.cpp"
"UserConfig.cpp"
"VectorizeMemrefCopy.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index 9b8468b..96d7594 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -210,10 +210,12 @@
}
SmallVector<int64_t>
-getExpandedTileShape(SmallVector<SmallVector<int64_t>> expandShape) {
+getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) {
SmallVector<int64_t> result;
- for (auto expandShapeDim : expandShape) {
- result.append(expandShapeDim);
+ for (auto e : expandShape) {
+ for (auto d : e) {
+ result.push_back(d.size);
+ }
}
return result;
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index b7d75c9..2905028 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -143,7 +143,7 @@
/// Concatenates the vectors.
SmallVector<int64_t>
-getExpandedTileShape(SmallVector<SmallVector<int64_t>> expandShape);
+getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 4ffa632..1174e81 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -37,84 +37,6 @@
#define GEN_PASS_DEF_GPUMATERIALIZEDEVICEENCODINGPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
-/// Returns the index of the dimension whose flattened size (flattening inner
-/// dimensions into it) matches the given `targetSize`. This is used to compute
-/// interleaving indices.
-///
-/// Example:
-/// Input shape = [16, 8, 4, 4]
-/// Input targetSize = 16
-/// -> Return 2, because the tail of the shape starting at index 2 is [4, 4],
-/// whose product equals targetSize.
-static int64_t getDimIdxForTargetSize(ArrayRef<int64_t> shape,
- int64_t targetSize) {
- int interleaveAt = 0;
- int size = 1;
- for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) {
- assert(size <= targetSize);
- assert((targetSize % size) == 0);
- if (size == targetSize) {
- break;
- }
- size *= shape[interleaveAt];
- }
- return interleaveAt;
-}
-
-/// Generates the swizzle for the full data-tiled-mma tile, including all the
-/// relevant unrolling factors.
-static TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
- IREE::GPU::MMAFragment fragment) {
- auto [AType, BType, CType] = mma.getABCElementTypes();
- int ABits = AType.getIntOrFloatBitWidth();
- int BBits = BType.getIntOrFloatBitWidth();
- // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded.
- const int targetPreferredLoadBitWidth = 128;
- auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
- switch (fragment) {
- case IREE::GPU::MMAFragment::Lhs:
- // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
- // Unroll on K with interleaving, then on M.
- if (mma.getUnrollK() > 1) {
- unroll(swizzle, 1, mma.getUnrollK());
- int interleavingIdx = getDimIdxForTargetSize(
- swizzle.expandShape[1],
- targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits));
- interleave(swizzle, 1, interleavingIdx);
- }
- if (mma.getUnrollM() > 1) {
- unroll(swizzle, 0, mma.getUnrollM());
- }
- break;
- case IREE::GPU::MMAFragment::Rhs:
- // B-matrix (RHS). Since the pack ops already took care of transposing B,
- // source dimensions are N (index 0) and K (index 1).
- // Unroll on K with interleaving, then on N.
- if (mma.getUnrollK() > 1) {
- unroll(swizzle, 1, mma.getUnrollK());
- int interleavingIdx = getDimIdxForTargetSize(
- swizzle.expandShape[1],
- targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits));
- interleave(swizzle, 1, interleavingIdx);
- }
- if (mma.getUnrollN() > 1) {
- unroll(swizzle, 0, mma.getUnrollN());
- }
- break;
- case IREE::GPU::MMAFragment::Acc:
- // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
- // 1). Unroll on N, then on M.
- if (mma.getUnrollN() > 1) {
- unroll(swizzle, 1, mma.getUnrollN());
- }
- if (mma.getUnrollM() > 1) {
- unroll(swizzle, 0, mma.getUnrollM());
- }
- break;
- }
- return swizzle;
-}
-
static bool hasIntrinsic(IREE::GPU::TargetAttr target,
IREE::GPU::MMAIntrinsic intrinsic) {
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -133,13 +55,16 @@
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];
- auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollN,
+ auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollMToThreads,
+ int unrollN, int unrollNToThreads,
int unrollK) -> std::optional<DataTiledMMAAttr> {
if (!hasIntrinsic(target, intrinsic)) {
return std::nullopt;
}
auto candidate = DataTiledMMAAttr::get(
- ctx, MMAIntrinsicAttr::get(ctx, intrinsic), unrollM, unrollN, unrollK);
+ ctx, MMAIntrinsicAttr::get(ctx, intrinsic), /*unroll_m=*/unrollM,
+ /*unroll_m_to_subgroups=*/unrollMToThreads, /*unroll_n=*/unrollN,
+ /*unroll_n_to_subgroups=*/unrollNToThreads, /*unroll_k=*/unrollK);
auto [candidateLhs, candidateRhs, candidateOut] =
candidate.getABCElementTypes();
if (candidateLhs != lhs || candidateRhs != rhs || candidateOut != out) {
@@ -147,13 +72,13 @@
}
return candidate;
};
- if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 8, 4)) {
+ if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 1, 2, 4, 4)) {
return m;
}
- if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 8, 2)) {
+ if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 1, 2, 4, 2)) {
return m;
}
- if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 8, 2)) {
+ if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 1, 2, 4, 2)) {
return m;
}
// Fallback - no architecture-optimized tile size for this case.
@@ -220,7 +145,7 @@
SmallVector<ReassociationIndices>
getReassociationIndices(int outerDims,
- SmallVector<SmallVector<int64_t>> expandShape) {
+ const TileSwizzle::ExpandShapeType &expandShape) {
SmallVector<ReassociationIndices> result;
int expandedIdx = 0;
for (int i = 0; i < outerDims; ++i) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
index b225e69..94335c4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
@@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
namespace mlir::iree_compiler {
@@ -13,7 +12,7 @@
// dimensions to expanded dimensions, returns the index of the first expanded
// dimension corresponding to the given source dimension index.
static int64_t
-getExpandedDimFirstIdx(const SmallVector<SmallVector<int64_t>> &expandShape,
+getExpandedDimFirstIdx(const TileSwizzle::ExpandShapeType &expandShape,
int64_t srcIndex) {
int dstIndexFirst = 0;
for (int i = 0; i < srcIndex; ++i) {
@@ -22,14 +21,17 @@
return dstIndexFirst;
}
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor) {
+void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
+ TileSwizzle::Dim::Kind kind) {
assert(unrollFactor > 1);
int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
-
+ TileSwizzle::Dim unrollDim;
+ unrollDim.size = unrollFactor;
+ unrollDim.kind = kind;
// The new unrolling dimension is inserted at the start of the expandShape
// dimensions group corresponding to srcIndex.
swizzle.expandShape[srcIndex].insert(swizzle.expandShape[srcIndex].begin(),
- unrollFactor);
+ unrollDim);
// Since we are not interleaving here, generating side-by-side copies of the
// original layout, the new unrolling dimension is the new outermost
// dimension. Existing entries get shifted to make room for it.
@@ -97,7 +99,10 @@
// shape expansion for now.
TileSwizzle swizzle;
for (auto t : layout.thread) {
- swizzle.expandShape.push_back({t});
+ TileSwizzle::Dim dim;
+ dim.size = t;
+ dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
+ swizzle.expandShape.push_back({dim});
}
// The layout strides decide the initial swizzle.permutation.
// Some WMMA intrinsics have tstrides=0 values, assert on that as that
@@ -112,9 +117,12 @@
// Deal with any element size greater than 1 by inserting it innermost.
// Notice that this is similar to the unroll() function, just creating an
// inner dimension instead of an outer dimension.
- for (int i = 0; i < layout.element.size(); ++i) {
- if (layout.element[i] != 1) {
- swizzle.expandShape[i].push_back(layout.element[i]);
+ for (auto [i, e] : llvm::enumerate(layout.element)) {
+ if (e != 1) {
+ TileSwizzle::Dim dim;
+ dim.size = e;
+ dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
+ swizzle.expandShape[i].push_back(dim);
int newIndex = getExpandedDimFirstIdx(swizzle.expandShape, i + 1) - 1;
for (auto &p : swizzle.permutation) {
p += (p >= newIndex);
@@ -125,13 +133,105 @@
// Deal with any outer size greater than 1 as just a call to unroll.
// Iterate over dims in reverse order because we are creating a new outermost
// dimension each time.
- for (int i = layout.outer.size() - 1; i >= 0; --i) {
- if (layout.outer[i] != 1) {
- unroll(swizzle, i, layout.outer[i]);
+ for (auto [i, o] : llvm::enumerate(layout.outer)) {
+ if (o != 1) {
+ // `layout.outer` means additional Internal dimensions, just like
+ // `layout.element`, just swizzled outermost.
+ unroll(swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
}
}
return swizzle;
}
+// Returns the index of the dimension whose flattened size (flattening inner
+// dimensions into it) matches the given `targetSize`. This is used to compute
+// interleaving indices.
+//
+// Example:
+// Input shape = [16, 8, 4, 4]
+// Input targetSize = 16
+// -> Return 2, because the tail of the shape starting at index 2 is [4, 4],
+// whose product equals targetSize.
+static int64_t
+getDimIdxForTargetSize(const TileSwizzle::ExpandShapeDimVectorType &shape,
+ int64_t targetSize) {
+ int interleaveAt = 0;
+ int size = 1;
+ for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) {
+ assert(size <= targetSize);
+ assert((targetSize % size) == 0);
+ if (size == targetSize) {
+ break;
+ }
+ size *= shape[interleaveAt].size;
+ }
+ return interleaveAt;
+}
+
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
+ IREE::GPU::MMAFragment fragment) {
+ auto [AType, BType, CType] = mma.getABCElementTypes();
+ int ABits = AType.getIntOrFloatBitWidth();
+ int BBits = BType.getIntOrFloatBitWidth();
+ // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded.
+ const int targetPreferredLoadBitWidth = 128;
+ auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
+ using Kind = TileSwizzle::Dim::Kind;
+ switch (fragment) {
+ case IREE::GPU::MMAFragment::Lhs:
+ // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
+ // Unroll on K with interleaving, then on M.
+ if (mma.getUnrollK() > 1) {
+ unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+ int interleavingIdx = getDimIdxForTargetSize(
+ swizzle.expandShape[1],
+ targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits));
+ interleave(swizzle, 1, interleavingIdx);
+ }
+ if (mma.getUnrollM() > 1) {
+ unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+ }
+ if (mma.getUnrollMToSubgroups() > 1) {
+ unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+ }
+ break;
+ case IREE::GPU::MMAFragment::Rhs:
+ // B-matrix (RHS). Since the pack ops already took care of transposing B,
+ // source dimensions are N (index 0) and K (index 1).
+ // Unroll on K with interleaving, then on N.
+ if (mma.getUnrollK() > 1) {
+ unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+ int interleavingIdx = getDimIdxForTargetSize(
+ swizzle.expandShape[1],
+ targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits));
+ interleave(swizzle, 1, interleavingIdx);
+ }
+ if (mma.getUnrollN() > 1) {
+ unroll(swizzle, 0, mma.getUnrollN(), Kind::CrossIntrinsic);
+ }
+ if (mma.getUnrollNToSubgroups() > 1) {
+ unroll(swizzle, 0, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+ }
+ break;
+ case IREE::GPU::MMAFragment::Acc:
+ // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
+ // 1). Unroll on N, then on M.
+ if (mma.getUnrollN() > 1) {
+ unroll(swizzle, 1, mma.getUnrollN(), Kind::CrossIntrinsic);
+ }
+ if (mma.getUnrollNToSubgroups() > 1) {
+ unroll(swizzle, 1, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+ }
+ if (mma.getUnrollM() > 1) {
+ unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+ }
+ if (mma.getUnrollMToSubgroups() > 1) {
+ unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+ }
+ break;
+ }
+ return swizzle;
+}
+
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
index fc5af79..fc79bf0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_GPU_GPUTILESWIZZLEUTILS_H_
#include "iree/compiler/Codegen/Common/TileSwizzle.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
namespace mlir::iree_compiler {
@@ -17,17 +18,26 @@
TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
IREE::GPU::MMAFragment fragment);
+// Returns the swizzle for the full data-tiled-mma tile, including all the
+// relevant unrolling factors.
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
+ IREE::GPU::MMAFragment fragment);
+
// Unrolls the dimension given by `srcIndex` by the given `unrollFactor`.
// This is not interleaving layouts. The layout will consist of multiple copies
// of the input tile, side by side.
//
+// The enum parameter `kind` initializes the corresponding member on the newly
+// created TileSwizzle::Dim.
+//
// Example:
// Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
// Input srcIndex = 1
// Input unrollFactor = 4
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
//
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor);
+void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
+ TileSwizzle::Dim::Kind kind);
// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
// move permutation[0], the outer-most dimension (which the unroll() function
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
index ac28e84..8e5927b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
@@ -128,11 +128,11 @@
// CHECK-SAME: inner_tiles = [128, 16]
// CHECK-SAME: : tensor<255x513xf32> -> tensor<5x16x128x16xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x8x16x4x4xf32>
+// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x4x2x16x4x4xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x8x16x4x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<5x16x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x4x2x16x4x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<5x16x4x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -161,11 +161,11 @@
// CHECK-SAME: inner_tiles = [128, 128]
// CHECK-SAME: : tensor<255x513xf32> -> tensor<2x5x128x128xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x8x16xf32>
+// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x4x2x16xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -189,11 +189,11 @@
// CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x8x16xf32> into tensor<2x5x128x128xf32>
+// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xf32> into tensor<2x5x128x128xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [0, 1]
@@ -232,11 +232,11 @@
}
// CHECK-LABEL: func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<?x?x8x8x4x16x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x8x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<?x?x8x4x2x4x16x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x4x2x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<?x?x8x4x4x8x16xf32> into tensor<?x?x128x128xf32>
+// CHECK-SAME: : tensor<?x?x8x4x4x4x2x16xf32> into tensor<?x?x128x128xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [0, 1]
@@ -295,12 +295,12 @@
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xf32>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4xf32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 8, unroll_k = 4>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
@@ -365,11 +365,11 @@
// CHECK-SAME: inner_tiles = [128, 64]
// CHECK-SAME: : tensor<255x513xi8> -> tensor<5x4x128x64xi8>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x8x16x2x4x8xi8>
+// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x4x2x16x2x4x8xi8>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x8x16x2x4x8xi8>)
-// CHECK-SAME: outs({{.*}} : tensor<5x4x8x4x16x2x8xi8>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4, 6]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x4x2x16x2x4x8xi8>)
+// CHECK-SAME: outs({{.*}} : tensor<5x4x4x2x4x16x2x8xi8>)
+// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5, 7]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -398,11 +398,11 @@
// CHECK-SAME: inner_tiles = [128, 128]
// CHECK-SAME: : tensor<255x513xi32> -> tensor<2x5x128x128xi32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x8x16xi32>
+// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x4x2x16xi32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
// CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
// -----
@@ -426,11 +426,11 @@
// CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x8x16xi32> into tensor<2x5x128x128xi32>
+// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xi32> into tensor<2x5x128x128xi32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [0, 1]
@@ -490,10 +490,10 @@
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xi32>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x2x8xi8>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xi32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 8, unroll_k = 2>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp
new file mode 100644
index 0000000..7ae46e6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp
@@ -0,0 +1,46 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TileSwizzle.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir::iree_compiler {
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ TileSwizzle::Dim::Kind kind) {
+ switch (kind) {
+ case TileSwizzle::Dim::Kind::Internal:
+ return os << "Internal";
+ case TileSwizzle::Dim::Kind::CrossThread:
+ return os << "CrossThread";
+ case TileSwizzle::Dim::Kind::CrossIntrinsic:
+ return os << "CrossIntrinsic";
+ }
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, TileSwizzle::Dim dim) {
+ return os << dim.size << "(" << dim.kind << ")";
+}
+
+static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const TileSwizzle::ExpandShapeDimVectorType &expandShapeDimVector) {
+ os << "[";
+ llvm::interleaveComma(expandShapeDimVector, os);
+ return os << "]";
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const TileSwizzle &swizzle) {
+ os << "{expandShape = [";
+ llvm::interleaveComma(swizzle.expandShape, os);
+ os << "], swizzle = [";
+ llvm::interleaveComma(swizzle.permutation, os);
+ os << "]}";
+ return os;
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
index b908ae4..738bb6a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
@@ -9,6 +9,7 @@
#include <cstdint>
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir::iree_compiler {
@@ -16,13 +17,50 @@
// pair of ops performing a change of layout within the tiles. This is used
// on GPU, where the tiles themselves can have an arbitrary layout.
struct TileSwizzle {
+ struct Dim {
+ // Describes what varies across this dimension.
+ enum class Kind : int8_t {
+ // This dimension is internal to one intrinsic on one thread. This
+ // is only seen for intrinsic operands that are themselves vectors.
+ // For example, with AMD MFMA, for the MFMA_F32_16x16x4_F32 intrinsic,
+ // the C-matrix operand is a vector of 4 floats already at the level of
+ // one intrinsic on one thread. That dimension of size 4 is 'Internal'.
+ Internal,
+ // This dimension is internal to one intrinsic, but is across threads.
+ // For example, with AMD MFMA, for the MFMA_F32_16x16x4_F32 intrinsic,
+ // the A-matrix tile has shape 16x4, and these two dimensions of size 16
+ // and 4 are 'CrossThread': neither is visible at the single-thread level
+ // (in the intrinsic itself, the A-matrix operand is a single scalar) but
+ // as we move along these dimensions, we are moving over the 64 threads
+ // of the subgroup.
+ //
+ // Another example of cross-thread dimensions is in kernels that are
+ // "unrolled" across subgroups. Such dimensions are cross-subgroup, so in
+ // particular they are cross-thread.
+ CrossThread,
+ // This dimension is across intrinsics, as in, actual instructions in the
+ // generated code. In other words, it is an actual unrolling factor,
+ // resulting in this many more instructions being generated and executed
+ // on each thread/subgroup.
+ CrossIntrinsic
+ };
+
+ Kind kind = Kind::Internal;
+
+ // The size of the dimension.
+ int16_t size = 0;
+ };
+
+ using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;
+ using ExpandShapeType = llvm::SmallVector<ExpandShapeDimVectorType>;
+
// This vector-of-vectors contains all the information needed to generate
// a `tensor.expand_shape` creating additional internal dimensions into the
// tile. For example, expandShape = [[16], [4, 2]] means that the original
// tile shape [16, 8] gets expanded such that the first dimension 16 is left
// unchanged, and the second dimension 8 gets split into two internal dims
// of size 4 and 2.
- llvm::SmallVector<llvm::SmallVector<int64_t>> expandShape;
+ ExpandShapeType expandShape;
// This permutation vector applies to the expanded dimensions and is used
// to generate a `linalg.transpose` changing the layout of the tile. For
// example, permutation[0] dictates which of the expanded dimensions becomes
@@ -30,6 +68,14 @@
llvm::SmallVector<int64_t> permutation;
};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ TileSwizzle::Dim::Kind kind);
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, TileSwizzle::Dim dim);
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const TileSwizzle &swizzle);
+
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_TILESWIZZLE_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 07cd27d..b8b8806 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -898,7 +898,8 @@
std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
MLIRContext *ctx = getContext();
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
- return {opaqueLayout.mSize * getUnrollM(), opaqueLayout.nSize * getUnrollN(),
+ return {opaqueLayout.mSize * getUnrollM() * getUnrollMToSubgroups(),
+ opaqueLayout.nSize * getUnrollN() * getUnrollNToSubgroups(),
opaqueLayout.kSize * getUnrollK()};
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index be04d19..d3dc53a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -255,9 +255,11 @@
let parameters = (ins
"::mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr":$intrinsic,
- "int64_t":$unroll_m,
- "int64_t":$unroll_n,
- "int64_t":$unroll_k
+ DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, on the same thread.">:$unroll_m,
+ DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, distributed across this many more threads.">:$unroll_m_to_subgroups,
+ DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, on the same thread.">:$unroll_n,
+ DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, distributed across this many more threads.">:$unroll_n_to_subgroups,
+ DefaultValuedParameter<"int64_t", "1", "Unrolling along the K dimension, on the same thread, with interleaved layout.">:$unroll_k
);
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
index 046a3c8..0ebe947 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
@@ -29,21 +29,21 @@
module {
func.func @test_data_tiled_mfma_f32_16x16x4_f32() attributes {
- mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>} {
+ mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_subgroups = 2, unroll_k = 1>} {
return
}
}
// CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x4_f32
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_subgroups = 2>
module {
func.func @test_data_tiled_mfma_f32_16x16x16_f16() attributes {
- mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n = 1, unroll_k = 1>} {
+ mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n_to_subgroups = 2, unroll_k = 2>} {
return
}
}
// CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x16_f16
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_n_to_subgroups = 2, unroll_k = 2>
module {
func.func @test_data_tiled_mfma_i32_16x16x32_i8() attributes {
@@ -52,7 +52,7 @@
}
}
// CHECK-LABEL: func @test_data_tiled_mfma_i32_16x16x32_i8
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8>
module {
func.func @test_any_lowering_config() attributes {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index b174f8f..0aa922e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -227,7 +227,7 @@
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
indexing_maps = #contraction_accesses,
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
- kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+ kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
} : tensor<?x?x4x16x1x1xf32>, tensor<?x?x4x16x1x1xf32> into tensor<?x?x4x16x4x1xf32>
return %0 : tensor<?x?x4x16x4x1xf32>
}
@@ -240,7 +240,7 @@
// CHECK: iree_gpu.multi_mma %arg0, %arg1, %arg2
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
// CHECK-SAME: : tensor<?x?x4x16x1x1xf32>, tensor<?x?x4x16x1x1xf32> into tensor<?x?x4x16x4x1xf32>
// -----
@@ -272,6 +272,34 @@
// -----
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+func.func @data_tiled_2x2x4_tensor_multi_mma(%lhs: tensor<?x?x2x4x16x1x4xf32>, %rhs: tensor<?x?x2x4x16x1x4xf32>, %acc: tensor<?x?x2x2x4x16x4x1xf32>) -> tensor<?x?x2x2x4x16x4x1xf32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+ kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m_to_subgroups = 2, unroll_n_to_subgroups = 2, unroll_k = 4>
+ } : tensor<?x?x2x4x16x1x4xf32>, tensor<?x?x2x4x16x1x4xf32> into tensor<?x?x2x2x4x16x4x1xf32>
+ return %0 : tensor<?x?x2x2x4x16x4x1xf32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @data_tiled_2x2x4_tensor_multi_mma
+// CHECK: iree_gpu.multi_mma %arg0, %arg1, %arg2
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m_to_subgroups = 2, unroll_n_to_subgroups = 2, unroll_k = 4>
+// CHECK-SAME: : tensor<?x?x2x4x16x1x4xf32>, tensor<?x?x2x4x16x1x4xf32> into tensor<?x?x2x2x4x16x4x1xf32>
+
+
+// -----
+
func.func @tensor_barrier(%input: tensor<?xf16>) -> tensor<?xf16> {
%out = iree_gpu.value_barrier %input : tensor<?xf16>
return %out : tensor<?xf16>