[Codegen] Add transfer_write distribution pattern for nested layouts (#16402)
Note that this does not yet handle cases where the writes need to be
predicated on the thread id (i.e. write only if lane_id == 0).
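For reference, the added tests distribute a 16x16xf16 transfer_write with
batches_per_subgroup = [2, 2], threads_per_outer = [8, 1], and
elements_per_thread = [1, 8] into four vector<1x8xf16> writes (one per batch
pair), plus col-major, permuted-index, and combined read/write variants.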
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index a516148..35e819b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
@@ -64,6 +65,105 @@
.getResult();
}
+/// Given a set of base transfer |indices|, |offsets| for the batch/outer
+/// dimensions, and distributed warp and thread indices, computes the indices
+/// of the distributed transfer operation based on the |vectorLayout|.
+static SmallVector<Value> getTransferIndicesFromNestedLayout(
+ OpBuilder &b, ValueRange indices, ArrayRef<int64_t> offsets,
+ NestedLayoutAttr vectorLayout, AffineMap permutationMap,
+ ArrayRef<Value> warpIndices, ArrayRef<Value> threadIndices) {
+ auto isBroadcast = [](AffineExpr expr) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
+ return constExpr.getValue() == 0;
+ return false;
+ };
+ int64_t rank = vectorLayout.getBatchOrder().size();
+  // Permute the batch and outer vector offsets to match the order of
+  // the vector dimensions using the inverse of the batch/outer order.
+ SmallVector<int64_t> batchOffsets =
+ applyPermutation(ArrayRef<int64_t>(offsets.begin(), rank),
+ invertPermutationVector(vectorLayout.getBatchOrder()));
+ SmallVector<int64_t> outerVectorOffsets =
+ applyPermutation(ArrayRef<int64_t>(offsets.begin() + rank, rank),
+ invertPermutationVector(vectorLayout.getOuterOrder()));
+
+ SmallVector<Value> slicedIndices(indices.begin(), indices.end());
+ for (const auto &[i, dim] : llvm::enumerate(permutationMap.getResults())) {
+    // Broadcasted dimension offsets can be used as-is; the read index is
+    // invariant to the thread in such cases (and broadcasts are illegal for
+    // writes).
+ if (isBroadcast(dim)) {
+ continue;
+ }
+ unsigned pos = cast<AffineDimExpr>(dim).getPosition();
+ SmallVector<OpFoldResult> ids = {
+ warpIndices[i], b.getIndexAttr(batchOffsets[i]),
+ b.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
+ // The order in which a vector dimension is "tiled" is
+ // subgroups -> batches -> outer vectors -> threads -> elements
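+    // For illustration (values matching the row-major test layout added
+    // below), the resulting distributed index for one dimension is roughly
+    //   base + elements_per_thread *
+    //       (thread + threads_per_outer *
+    //           (outer + outers_per_batch *
+    //               (batch + batches_per_subgroup * subgroup)))
+    // so with threads_per_outer = 8 and elements_per_thread = 1, a batch
+    // offset of 1 yields base + thread_id + 8.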
+ SmallVector<int64_t> sizes = {vectorLayout.getSubgroupsPerWorkgroup()[i],
+ vectorLayout.getBatchesPerSubgroup()[i],
+ vectorLayout.getOutersPerBatch()[i],
+ vectorLayout.getThreadsPerOuter()[i]};
+ slicedIndices[pos] = linearizeIndex(b, indices[pos], ids, sizes,
+ vectorLayout.getElementsPerThread()[i]);
+ }
+ return slicedIndices;
+}
+
+static SmallVector<int64_t> getLoopOrder(NestedLayoutAttr vectorLayout) {
+ int64_t rank = vectorLayout.getBatchOrder().size();
+ // Let the unroll order first unroll the batch dimensions, then the
+ // outer vector dimensions. We unroll in the order specified by the
+ // layout.
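+  // For example, with rank = 2, batch_order = [1, 0], and an identity outer
+  // order (as in the row-major test layout), the loop order is
+  // [1, 0, 2, 3, 4, 5].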
+ SmallVector<int64_t> loopOrder;
+ int64_t base = 0;
+ for (auto b : vectorLayout.getBatchOrder()) {
+ loopOrder.push_back(base + b);
+ }
+ base += rank;
+ // We must unroll along the outer dimensions as well to match the rank
+ // requirements of vector transfer ops (<= memref rank up to broadcasts).
+ for (auto o : vectorLayout.getOuterOrder()) {
+ loopOrder.push_back(base + o);
+ }
+ base += rank;
+ for (int i = 0, e = rank; i < e; ++i) {
+ loopOrder.push_back(base + i);
+ }
+ return loopOrder;
+}
+
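+/// Computes the per-iteration tile shape by setting the batch and outer
+/// dimensions (the first 2 * rank entries of the distributed shape) to 1,
+/// leaving only the per-thread element dimensions; e.g. a distributed shape
+/// of 2x2x1x1x1x8 yields a tile shape of 1x1x1x1x1x8.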
+static SmallVector<int64_t>
+getElementVectorTileShape(NestedLayoutAttr vectorLayout) {
+ int64_t rank = vectorLayout.getBatchOrder().size();
+ SmallVector<int64_t> tileShape = vectorLayout.getDistributedShape();
+ for (int i = 0, e = rank * 2; i < e; ++i) {
+ tileShape[i] = 1;
+ }
+ return tileShape;
+}
+
+/// Computes the warp and thread indices for the given vector layout from a
+/// single linearized thread ID.
+static void populateWarpAndThreadIndices(RewriterBase &rewriter, Value threadId,
+ NestedLayoutAttr vectorLayout,
+ SmallVector<Value> &warpIndices,
+ SmallVector<Value> &threadIndices) {
+ int64_t rank = vectorLayout.getBatchOrder().size();
+  // The delinearized thread IDs are returned from outermost to innermost,
+  // i.e. before applying the dimension ordering described by the layout.
+ ValueRange threadIds = vectorLayout.computeThreadIds(threadId, rewriter);
+
+ // Subgroup and thread (lane) indices normalized to the order in which
+ // they are used by each dimension.
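+  // For example, with rank = 2 and identity subgroup/thread orders (as in the
+  // added tests), warpIndices = {threadIds[0], threadIds[1]} and
+  // threadIndices = {threadIds[2], threadIds[3]}.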
+ warpIndices =
+ llvm::to_vector(llvm::map_range(vectorLayout.getSubgroupOrder(),
+ [&](int64_t i) { return threadIds[i]; }));
+ threadIndices = llvm::to_vector(
+ llvm::map_range(vectorLayout.getThreadOrder(),
+ [&](int64_t i) { return threadIds[i + rank]; }));
+}
+
namespace {
/// Pattern to distribute `vector.transfer_read` ops with nested layouts.
@@ -93,35 +193,11 @@
return failure();
}
- // The delinearized thread IDs are returned from outer most to inner most,
- // i.e. before applying the layout described dimensions ordering.
- ValueRange threadIds = vectorLayout.computeThreadIds(threadId, rewriter);
-
SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
- SmallVector<int64_t> tileShape = vectorLayout.getDistributedShape();
+ SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
+ SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
int64_t rank = vectorLayout.getBatchOrder().size();
- // Let the unroll order first unroll the batch dimensions, then the
- // outer vector dimensions. We unroll in the order specified by the
- // layout.
- SmallVector<int64_t> loopOrder;
- int64_t base = 0;
- for (int64_t b : vectorLayout.getBatchOrder()) {
- loopOrder.push_back(base + b);
- tileShape[base + b] = 1;
- }
- base += rank;
- // We must unroll along the outer dimensions as well to match the rank
- // requirements of vector transfer ops (<= memref rank up to broadcasts).
- for (int64_t o : vectorLayout.getOuterOrder()) {
- loopOrder.push_back(base + o);
- tileShape[base + o] = 1;
- }
- base += rank;
- for (int i = 0, e = rank; i < e; ++i) {
- loopOrder.push_back(base + i);
- }
-
Type elementType = readOp.getSource().getType().getElementType();
auto vectorType = VectorType::get(distShape, elementType);
// The shape of the vector we read is pre-permutation. The permutation is
@@ -135,54 +211,18 @@
readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
VectorValue acc = cast<VectorValue>(zero);
- // Subgroup and thread (lane) indices normalized to the order in which
- // they are used by each dimension.
- SmallVector<Value> warpIndices = llvm::to_vector(
- llvm::map_range(vectorLayout.getSubgroupOrder(),
- [&](int64_t i) { return threadIds[i]; }));
- SmallVector<Value> threadIndices = llvm::to_vector(
- llvm::map_range(vectorLayout.getThreadOrder(),
- [&](int64_t i) { return threadIds[i + rank]; }));
-
- auto isBroadcast = [](AffineExpr expr) {
- if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
- return constExpr.getValue() == 0;
- return false;
- };
+ SmallVector<Value> warpIndices, threadIndices;
+ populateWarpAndThreadIndices(rewriter, threadId, vectorLayout, warpIndices,
+ threadIndices);
ValueRange indices = readOp.getIndices();
SmallVector<int64_t> strides(rank, 1);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
- // Permute the batch and outer vector offsets to match the order of
- // the vector dimensions using the inverse of the batch/offset order.
- SmallVector<int64_t> batchOffsets = applyPermutation(
- ArrayRef<int64_t>(offsets.begin(), rank),
- invertPermutationVector(vectorLayout.getBatchOrder()));
- SmallVector<int64_t> outerVectorOffsets = applyPermutation(
- ArrayRef<int64_t>(offsets.begin() + rank, rank),
- invertPermutationVector(vectorLayout.getOuterOrder()));
+ SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
+ rewriter, indices, offsets, vectorLayout, readOp.getPermutationMap(),
+ warpIndices, threadIndices);
- SmallVector<Value> slicedIndices(indices.begin(), indices.end());
- for (const auto &[i, dim] :
- llvm::enumerate(readOp.getPermutationMap().getResults())) {
- if (isBroadcast(dim))
- continue;
- unsigned pos = cast<AffineDimExpr>(dim).getPosition();
- SmallVector<OpFoldResult> ids = {
- warpIndices[i], rewriter.getIndexAttr(batchOffsets[i]),
- rewriter.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
- // The order in which a vector dimension is "tiled" is
- // subgroups -> batches -> outer vectors -> threads -> elements
- SmallVector<int64_t> sizes = {
- vectorLayout.getSubgroupsPerWorkgroup()[i],
- vectorLayout.getBatchesPerSubgroup()[i],
- vectorLayout.getOutersPerBatch()[i],
- vectorLayout.getThreadsPerOuter()[i]};
- slicedIndices[pos] =
- linearizeIndex(rewriter, indices[pos], ids, sizes,
- vectorLayout.getElementsPerThread()[i]);
- }
Value slicedRead = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
@@ -204,12 +244,83 @@
Value threadId;
};
+/// Pattern to distribute `vector.transfer_write` ops with nested layouts.
+struct DistributeTransferWriteNestedLayoutAttr final
+ : OpDistributionPattern<vector::TransferWriteOp> {
+ using OpDistributionPattern::OpDistributionPattern;
+
+ DistributeTransferWriteNestedLayoutAttr(MLIRContext *context, Value threadId)
+ : OpDistributionPattern(context), threadId(threadId) {}
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ DistributionSignature &signature,
+ PatternRewriter &rewriter) const override {
+ // TODO: Support masking.
+ if (writeOp.getMask()) {
+ return failure();
+ }
+ NestedLayoutAttr vectorLayout =
+ dyn_cast<NestedLayoutAttr>(signature[writeOp.getVector()]);
+ if (!vectorLayout) {
+ return failure();
+ }
+
+ if (!isa<MemRefType>(writeOp.getSource().getType())) {
+ return failure();
+ }
+
+ SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
+ SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
+ SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
+ int64_t rank = vectorLayout.getBatchOrder().size();
+
+ SmallVector<Value> warpIndices, threadIndices;
+ populateWarpAndThreadIndices(rewriter, threadId, vectorLayout, warpIndices,
+ threadIndices);
+
+ Value distributedVector =
+ getDistributed(rewriter, writeOp.getVector(), vectorLayout);
+
+ ValueRange indices = writeOp.getIndices();
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+ SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
+ rewriter, indices, offsets, vectorLayout, writeOp.getPermutationMap(),
+ warpIndices, threadIndices);
+
+      // Extract the "element vector" from the innermost dimensions. All outer
+      // dimensions are either unrolled or distributed such that this is a
+      // contiguous slice.
+ ArrayRef<int64_t> offsetArray(offsets);
+ Value slicedVector = rewriter.create<vector::ExtractOp>(
+ writeOp.getLoc(), distributedVector,
+ offsetArray.take_front(rank * 2));
+ // Transpose to the native dimension order.
+ if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
+ slicedVector = rewriter.create<vector::TransposeOp>(
+ slicedVector.getLoc(), slicedVector,
+ invertPermutationVector(vectorLayout.getElementOrder()));
+ }
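+      // For example, with the col-major test layout below (element_order =
+      // [1, 0]), the extracted vector<1x4xf16> slice is transposed to
+      // vector<4x1xf16> before being written back.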
+ rewriter.create<vector::TransferWriteOp>(
+ writeOp.getLoc(), slicedVector, writeOp.getSource(), slicedIndices,
+ writeOp.getPermutationMapAttr(), writeOp.getMask(),
+ writeOp.getInBoundsAttr());
+ }
+
+ rewriter.eraseOp(writeOp);
+ return success();
+ }
+
+ Value threadId;
+};
+
} // namespace
void populateGPUDistributeNestedLayoutAttrPatterns(
Value threadId, RewritePatternSet &patterns) {
- patterns.add<DistributeTransferReadNestedLayoutAttr>(patterns.getContext(),
- threadId);
+ patterns.add<DistributeTransferReadNestedLayoutAttr,
+ DistributeTransferWriteNestedLayoutAttr>(patterns.getContext(),
+ threadId);
}
}; // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index d11526b..92f88ed 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -46,7 +46,7 @@
// CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c8]
// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 1, 0, 0, 0, 0]
-// CHECK: iree_vector_ext.to_simd %10 : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
// -----
@@ -141,7 +141,7 @@
// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
// CHECK: %[[OFF0:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2, %[[I0]]]
// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[I1]]]
-// CHECK: %[[OFF1:.+]] = affine.apply #[[$MAP1]]()[%arg1]
+// CHECK: %[[OFF1:.+]] = affine.apply #[[$MAP1]]()[%[[I1]]]
// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[OFF1]]]
// CHECK: %[[OFF2:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#2, %[[I0]]]
// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF2]], %[[I1]]]
@@ -239,10 +239,10 @@
// CHECK: %[[LIN_ID0:.+]] = affine.apply #[[$MAP:.+]]()[%[[IDS]]#2, %[[I1]]]
// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
// CHECK: %[[I0_PLUS_8:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
-// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
// CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
-// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0]], %[[LIN_ID1]]], {{.*}} permutation_map = #[[$MAP1]]
-// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID1]]], %cst_0 {in_bounds = [true, true], permutation_map = #map1} : memref<32x32x32x32xf16>, vector<1x8xf16>
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID1]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID1]]], %cst_0 {in_bounds = [true, true], permutation_map = #map1} : memref<32x32x32x32xf16>, vector<1x8xf16>
// -----
@@ -325,3 +325,203 @@
// unique transfer read ops. The broadcasted dimension (2) CSEs the duplicate
// reads.
// CHECK-COUNT-60: vector.transfer_read
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [2, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [8, 1],
+ elements_per_thread = [1, 8],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @distribute_transfer_write_row_major
+func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0]
+ {in_bounds = [true, true],
+ "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+ : vector<16x16xf16>, memref<64x64xf16>
+ func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[SLICE:.+]] = vector.extract %{{.*}}[0, 0, 0, 0] : vector<1x8xf16> from vector<2x2x1x1x1x8xf16>
+// CHECK: vector.transfer_write %[[SLICE]], %{{.*}}[%[[IDS]]#2, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<64x64xf16>
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[IDS]]#2, %c8]
+// CHECK: %[[LANEX_PLUS_VECDIMX:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c0]
+// CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c8]
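+
+// Note: the 16x16 row-major write is decomposed into four vector<1x8xf16>
+// writes, one per (batch0, batch1) pair; the row index is the lane id (plus 8
+// for the second batch along dim 0) and the column index steps by 8 elements.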
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [1, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [4, 8],
+ elements_per_thread = [4, 1],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [1, 0]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @distribute_transfer_write_col_major
+func.func @distribute_transfer_write_col_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0]
+ {in_bounds = [true, true],
+ "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+ : vector<16x16xf16>, memref<64x64xf16>
+ func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c4, %c8) : index, index, index, index
+// CHECK: %[[LANEY:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
+// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEY]], %[[IDS]]#3]
+// CHECK: %[[LANEX:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEY]], %[[LANEX]]]
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [2, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [8, 1],
+ elements_per_thread = [1, 8],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 8)>
+
+func.func @distribute_transfer_write_row_major_with_nontrivial_index(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+ : vector<16x16xf16>, memref<32x32x32x32xf16>
+ func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: @distribute_transfer_write_row_major_with_nontrivial_index
+// CHECK-SAME: vector<16x16xf16>, %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[LIN_ID0:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[LIN_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [2, 2],
+ outers_per_batch = [1, 1],
+ threads_per_outer = [8, 1],
+ elements_per_thread = [1, 8],
+
+ subgroup_order = [0, 1],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
+
+func.func @distribute_transfer_read_write(%a: index, %b: index,
+ %arg0: memref<32x32x32x32xf16>,
+ %arg1: memref<32x32x32x32xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ "__vector_layout_test_anchor_result_0" = #layout_row_major}
+ : memref<32x32x32x32xf16>, vector<16x16xf16>
+ vector.transfer_write %root, %arg1[%c0, %c0, %a, %b]
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+ "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+ : vector<16x16xf16>, memref<32x32x32x32xf16>
+ return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK: %[[B00:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX:[a-zA-Z0-9]+]], %[[OFFSET0:[a-zA-Z0-9]+]]]
+// CHECK: %[[B10:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX]], %[[OFFSET1:[a-zA-Z0-9]+]]]
+// CHECK: %[[B01:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH:[a-zA-Z0-9]+]], %[[OFFSET0]]]
+// CHECK: %[[B11:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH]], %[[OFFSET1]]]
+// CHECK: vector.transfer_write %[[B00]], %{{.*}}[%c0, %c0, %[[LANEX]], %[[OFFSET0]]]
+// CHECK: vector.transfer_write %[[B10]], %{{.*}}[%c0, %c0, %[[LANEX]], %[[OFFSET1]]]
+// CHECK: vector.transfer_write %[[B01]], %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH]], %[[OFFSET0]]]
+// CHECK: vector.transfer_write %[[B11]], %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH]], %[[OFFSET1]]]
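+
+// Note: since the read and the write share the same layout and base indices,
+// the distributed writes reuse the index values computed for the distributed
+// reads, as checked above.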