[Codegen] Add transfer_write distribution pattern for nested layouts (#16402)

Note that this does not yet handle cases where the writes need to be
predicated on the thread id (i.e. write only if lane_id == 0).
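
The new pattern covers writes where every participating lane stores the slice
it owns. For illustration only, the missing predicated case would roughly take
the shape below; %lane_id, %slice, %dest, and the indices are placeholder names
for this sketch and are not what the pattern emits today:

  // Hypothetical predicated write (not generated by this patch);
  // all names below are placeholders.
  %is_lane0 = arith.cmpi eq, %lane_id, %c0 : index
  scf.if %is_lane0 {
    vector.transfer_write %slice, %dest[%i, %j] {in_bounds = [true, true]}
        : vector<1x8xf16>, memref<64x64xf16>
  }

Such a guard would be needed when the layout maps more than one lane to the
same slice, so that only one of them performs the store.
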
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index a516148..35e819b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
@@ -64,6 +65,105 @@
       .getResult();
 }
 
+/// Given a set of base transfer |indices|, |offsets| for the batch/outer
+/// dimensions, and distributed warp and thread indices, computes the indices
+/// of the distributed transfer operation based on the |vectorLayout|.
+static SmallVector<Value> getTransferIndicesFromNestedLayout(
+    OpBuilder &b, ValueRange indices, ArrayRef<int64_t> offsets,
+    NestedLayoutAttr vectorLayout, AffineMap permutationMap,
+    ArrayRef<Value> warpIndices, ArrayRef<Value> threadIndices) {
+  auto isBroadcast = [](AffineExpr expr) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
+      return constExpr.getValue() == 0;
+    return false;
+  };
+  int64_t rank = vectorLayout.getBatchOrder().size();
+  // Permute the batch and outer vector offsets to match the order of
+  // the vector dimensions using the inverse of the batch/outer order.
+  SmallVector<int64_t> batchOffsets =
+      applyPermutation(ArrayRef<int64_t>(offsets.begin(), rank),
+                       invertPermutationVector(vectorLayout.getBatchOrder()));
+  SmallVector<int64_t> outerVectorOffsets =
+      applyPermutation(ArrayRef<int64_t>(offsets.begin() + rank, rank),
+                       invertPermutationVector(vectorLayout.getOuterOrder()));
+
+  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
+  for (const auto &[i, dim] : llvm::enumerate(permutationMap.getResults())) {
+    // Broadcasted dimension offsets can be used as-is; the read index is
+    // invariant to the thread in such cases (and illegal for writes).
+    if (isBroadcast(dim)) {
+      continue;
+    }
+    unsigned pos = cast<AffineDimExpr>(dim).getPosition();
+    SmallVector<OpFoldResult> ids = {
+        warpIndices[i], b.getIndexAttr(batchOffsets[i]),
+        b.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
+    // The order in which a vector dimension is "tiled" is
+    // subgroups -> batches -> outer vectors -> threads -> elements
+    SmallVector<int64_t> sizes = {vectorLayout.getSubgroupsPerWorkgroup()[i],
+                                  vectorLayout.getBatchesPerSubgroup()[i],
+                                  vectorLayout.getOutersPerBatch()[i],
+                                  vectorLayout.getThreadsPerOuter()[i]};
+    slicedIndices[pos] = linearizeIndex(b, indices[pos], ids, sizes,
+                                        vectorLayout.getElementsPerThread()[i]);
+  }
+  return slicedIndices;
+}
+
+static SmallVector<int64_t> getLoopOrder(NestedLayoutAttr vectorLayout) {
+  int64_t rank = vectorLayout.getBatchOrder().size();
+  // Let the unroll order first unroll the batch dimensions, then the
+  // outer vector dimensions. We unroll in the order specified by the
+  // layout.
+  SmallVector<int64_t> loopOrder;
+  int64_t base = 0;
+  for (auto b : vectorLayout.getBatchOrder()) {
+    loopOrder.push_back(base + b);
+  }
+  base += rank;
+  // We must unroll along the outer dimensions as well to match the rank
+  // requirements of vector transfer ops (<= memref rank up to broadcasts).
+  for (auto o : vectorLayout.getOuterOrder()) {
+    loopOrder.push_back(base + o);
+  }
+  base += rank;
+  for (int i = 0, e = rank; i < e; ++i) {
+    loopOrder.push_back(base + i);
+  }
+  return loopOrder;
+}
+
+static SmallVector<int64_t>
+getElementVectorTileShape(NestedLayoutAttr vectorLayout) {
+  int64_t rank = vectorLayout.getBatchOrder().size();
+  SmallVector<int64_t> tileShape = vectorLayout.getDistributedShape();
+  for (int i = 0, e = rank * 2; i < e; ++i) {
+    tileShape[i] = 1;
+  }
+  return tileShape;
+}
+
+/// Computes the warp and thread indices for the given vector layout from a
+/// single linearized thread ID.
+static void populateWarpAndThreadIndices(RewriterBase &rewriter, Value threadId,
+                                         NestedLayoutAttr vectorLayout,
+                                         SmallVector<Value> &warpIndices,
+                                         SmallVector<Value> &threadIndices) {
+  int64_t rank = vectorLayout.getBatchOrder().size();
+  // The delinearized thread IDs are returned from outermost to innermost,
+  // i.e. before applying the dimension ordering described by the layout.
+  ValueRange threadIds = vectorLayout.computeThreadIds(threadId, rewriter);
+
+  // Subgroup and thread (lane) indices normalized to the order in which
+  // they are used by each dimension.
+  warpIndices =
+      llvm::to_vector(llvm::map_range(vectorLayout.getSubgroupOrder(),
+                                      [&](int64_t i) { return threadIds[i]; }));
+  threadIndices = llvm::to_vector(
+      llvm::map_range(vectorLayout.getThreadOrder(),
+                      [&](int64_t i) { return threadIds[i + rank]; }));
+}
+
 namespace {
 
 /// Pattern to distribute `vector.transfer_read` ops with nested layouts.
@@ -93,35 +193,11 @@
       return failure();
     }
 
-    // The delinearized thread IDs are returned from outer most to inner most,
-    // i.e. before applying the layout described dimensions ordering.
-    ValueRange threadIds = vectorLayout.computeThreadIds(threadId, rewriter);
-
     SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
-    SmallVector<int64_t> tileShape = vectorLayout.getDistributedShape();
+    SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
+    SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
     int64_t rank = vectorLayout.getBatchOrder().size();
 
-    // Let the unroll order first unroll the batch dimensions, then the
-    // outer vector dimensions. We unroll in the order specified by the
-    // layout.
-    SmallVector<int64_t> loopOrder;
-    int64_t base = 0;
-    for (int64_t b : vectorLayout.getBatchOrder()) {
-      loopOrder.push_back(base + b);
-      tileShape[base + b] = 1;
-    }
-    base += rank;
-    // We must unroll along the outer dimensions as well to match the rank
-    // requirements of vector transfer ops (<= memref rank up to broadcasts).
-    for (int64_t o : vectorLayout.getOuterOrder()) {
-      loopOrder.push_back(base + o);
-      tileShape[base + o] = 1;
-    }
-    base += rank;
-    for (int i = 0, e = rank; i < e; ++i) {
-      loopOrder.push_back(base + i);
-    }
-
     Type elementType = readOp.getSource().getType().getElementType();
     auto vectorType = VectorType::get(distShape, elementType);
     // The shape of the vector we read is pre-permutation. The permutation is
@@ -135,54 +211,18 @@
         readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
     VectorValue acc = cast<VectorValue>(zero);
 
-    // Subgroup and thread (lane) indices normalized to the order in which
-    // they are used by each dimension.
-    SmallVector<Value> warpIndices = llvm::to_vector(
-        llvm::map_range(vectorLayout.getSubgroupOrder(),
-                        [&](int64_t i) { return threadIds[i]; }));
-    SmallVector<Value> threadIndices = llvm::to_vector(
-        llvm::map_range(vectorLayout.getThreadOrder(),
-                        [&](int64_t i) { return threadIds[i + rank]; }));
-
-    auto isBroadcast = [](AffineExpr expr) {
-      if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
-        return constExpr.getValue() == 0;
-      return false;
-    };
+    SmallVector<Value> warpIndices, threadIndices;
+    populateWarpAndThreadIndices(rewriter, threadId, vectorLayout, warpIndices,
+                                 threadIndices);
 
     ValueRange indices = readOp.getIndices();
     SmallVector<int64_t> strides(rank, 1);
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
-      // Permute the batch and outer vector offsets to match the order of
-      // the vector dimensions using the inverse of the batch/offset order.
-      SmallVector<int64_t> batchOffsets = applyPermutation(
-          ArrayRef<int64_t>(offsets.begin(), rank),
-          invertPermutationVector(vectorLayout.getBatchOrder()));
-      SmallVector<int64_t> outerVectorOffsets = applyPermutation(
-          ArrayRef<int64_t>(offsets.begin() + rank, rank),
-          invertPermutationVector(vectorLayout.getOuterOrder()));
+      SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
+          rewriter, indices, offsets, vectorLayout, readOp.getPermutationMap(),
+          warpIndices, threadIndices);
 
-      SmallVector<Value> slicedIndices(indices.begin(), indices.end());
-      for (const auto &[i, dim] :
-           llvm::enumerate(readOp.getPermutationMap().getResults())) {
-        if (isBroadcast(dim))
-          continue;
-        unsigned pos = cast<AffineDimExpr>(dim).getPosition();
-        SmallVector<OpFoldResult> ids = {
-            warpIndices[i], rewriter.getIndexAttr(batchOffsets[i]),
-            rewriter.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
-        // The order in which a vector dimension is "tiled" is
-        // subgroups -> batches -> outer vectors -> threads -> elements
-        SmallVector<int64_t> sizes = {
-            vectorLayout.getSubgroupsPerWorkgroup()[i],
-            vectorLayout.getBatchesPerSubgroup()[i],
-            vectorLayout.getOutersPerBatch()[i],
-            vectorLayout.getThreadsPerOuter()[i]};
-        slicedIndices[pos] =
-            linearizeIndex(rewriter, indices[pos], ids, sizes,
-                           vectorLayout.getElementsPerThread()[i]);
-      }
       Value slicedRead = rewriter.create<vector::TransferReadOp>(
           readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
@@ -204,12 +244,83 @@
   Value threadId;
 };
 
+/// Pattern to distribute `vector.transfer_write` ops with nested layouts.
+struct DistributeTransferWriteNestedLayoutAttr final
+    : OpDistributionPattern<vector::TransferWriteOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+
+  DistributeTransferWriteNestedLayoutAttr(MLIRContext *context, Value threadId)
+      : OpDistributionPattern(context), threadId(threadId) {}
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Support masking.
+    if (writeOp.getMask()) {
+      return failure();
+    }
+    NestedLayoutAttr vectorLayout =
+        dyn_cast<NestedLayoutAttr>(signature[writeOp.getVector()]);
+    if (!vectorLayout) {
+      return failure();
+    }
+
+    if (!isa<MemRefType>(writeOp.getSource().getType())) {
+      return failure();
+    }
+
+    SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
+    SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
+    SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
+    int64_t rank = vectorLayout.getBatchOrder().size();
+
+    SmallVector<Value> warpIndices, threadIndices;
+    populateWarpAndThreadIndices(rewriter, threadId, vectorLayout, warpIndices,
+                                 threadIndices);
+
+    Value distributedVector =
+        getDistributed(rewriter, writeOp.getVector(), vectorLayout);
+
+    ValueRange indices = writeOp.getIndices();
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+      SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
+          rewriter, indices, offsets, vectorLayout, writeOp.getPermutationMap(),
+          warpIndices, threadIndices);
+
+      // Extract the "element vector" from the innermost dimensions. All outer
+      // dimensions are either unrolled or distributed such that this is a
+      // contiguous slice.
+      ArrayRef<int64_t> offsetArray(offsets);
+      Value slicedVector = rewriter.create<vector::ExtractOp>(
+          writeOp.getLoc(), distributedVector,
+          offsetArray.take_front(rank * 2));
+      // Transpose to the native dimension order.
+      if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
+        slicedVector = rewriter.create<vector::TransposeOp>(
+            slicedVector.getLoc(), slicedVector,
+            invertPermutationVector(vectorLayout.getElementOrder()));
+      }
+      rewriter.create<vector::TransferWriteOp>(
+          writeOp.getLoc(), slicedVector, writeOp.getSource(), slicedIndices,
+          writeOp.getPermutationMapAttr(), writeOp.getMask(),
+          writeOp.getInBoundsAttr());
+    }
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+
+  Value threadId;
+};
+
 } // namespace
 
 void populateGPUDistributeNestedLayoutAttrPatterns(
     Value threadId, RewritePatternSet &patterns) {
-  patterns.add<DistributeTransferReadNestedLayoutAttr>(patterns.getContext(),
-                                                       threadId);
+  patterns.add<DistributeTransferReadNestedLayoutAttr,
+               DistributeTransferWriteNestedLayoutAttr>(patterns.getContext(),
+                                                        threadId);
 }
 
 }; // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index d11526b..92f88ed 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -46,7 +46,7 @@
 // CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
 // CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c8]
 // CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 1, 0, 0, 0, 0]
-// CHECK: iree_vector_ext.to_simd %10 : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
 
 // -----
 
@@ -141,7 +141,7 @@
 // CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
 // CHECK: %[[OFF0:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2, %[[I0]]]
 // CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[I1]]]
-// CHECK: %[[OFF1:.+]] = affine.apply #[[$MAP1]]()[%arg1]
+// CHECK: %[[OFF1:.+]] = affine.apply #[[$MAP1]]()[%[[I1]]]
 // CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[OFF1]]]
 // CHECK: %[[OFF2:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#2, %[[I0]]]
 // CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF2]], %[[I1]]]
@@ -239,10 +239,10 @@
 // CHECK: %[[LIN_ID0:.+]] = affine.apply #[[$MAP:.+]]()[%[[IDS]]#2, %[[I1]]]
 // CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
 // CHECK: %[[I0_PLUS_8:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
-// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
 // CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
-// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0]], %[[LIN_ID1]]], {{.*}} permutation_map = #[[$MAP1]]
-// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID1]]], %cst_0 {in_bounds = [true, true], permutation_map = #map1} : memref<32x32x32x32xf16>, vector<1x8xf16>
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID1]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID1]]], %{{.*}} {in_bounds = [true, true], permutation_map = #[[$MAP1]]} : memref<32x32x32x32xf16>, vector<1x8xf16>
 
 // -----
 
@@ -325,3 +325,203 @@
 // unique transfer read ops. The broadcasted dimension (2) CSEs the duplicate
 // reads.
 // CHECK-COUNT-60: vector.transfer_read
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [2, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [8, 1],
+  elements_per_thread     = [1, 8],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [0, 1]
+>
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @distribute_transfer_write_row_major
+func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %root, %alloc[%c0, %c0]
+          {in_bounds = [true, true],
+           "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+                  : vector<16x16xf16>, memref<64x64xf16>
+  func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[SLICE:.+]] = vector.extract %{{.*}}[0, 0, 0, 0] : vector<1x8xf16> from vector<2x2x1x1x1x8xf16>
+// CHECK: vector.transfer_write %[[SLICE]], %{{.*}}[%[[IDS]]#2, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<64x64xf16>
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[IDS]]#2, %c8]
+// CHECK: %[[LANEX_PLUS_VECDIMX:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c0]
+// CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEX_PLUS_VECDIMX]], %c8]
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [1, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [4, 8],
+  elements_per_thread     = [4, 1],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [1, 0]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @distribute_transfer_write_col_major
+func.func @distribute_transfer_write_col_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %root, %alloc[%c0, %c0]
+          {in_bounds = [true, true],
+           "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+                  : vector<16x16xf16>, memref<64x64xf16>
+  func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c4, %c8) : index, index, index, index
+// CHECK: %[[LANEY:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
+// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEY]], %[[IDS]]#3]
+// CHECK: %[[LANEX:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<1x4xf16> to vector<4x1xf16>
+// CHECK: vector.transfer_write %{{.*}}[%[[LANEY]], %[[LANEX]]]
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [2, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [8, 1],
+  elements_per_thread     = [1, 8],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [0, 1]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 8)>
+
+func.func @distribute_transfer_write_row_major_with_nontrivial_index(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+           "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+                  : vector<16x16xf16>, memref<32x32x32x32xf16>
+  func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: @distribute_transfer_write_row_major_with_nontrivial_index
+// CHECK-SAME:    vector<16x16xf16>, %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[LIN_ID0:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.extract %{{.*}}[0, 0, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
+// CHECK: vector.extract %{{.*}}[1, 0, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID0]]] {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[LIN_ID2:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.extract %{{.*}}[0, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.extract %{{.*}}[1, 1, 0, 0]
+// CHECK: vector.transfer_write %{{.*}}[%c0, %c0, %[[LIN_ID1]], %[[LIN_ID2]]] {{.*}} permutation_map = #[[$MAP1]]
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [2, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [8, 1],
+  elements_per_thread     = [1, 8],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [0, 1]
+>
+
+func.func @distribute_transfer_read_write(%a: index, %b: index,
+                                          %arg0: memref<32x32x32x32xf16>,
+                                          %arg1: memref<32x32x32x32xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+           "__vector_layout_test_anchor_result_0" = #layout_row_major}
+                  : memref<32x32x32x32xf16>, vector<16x16xf16>
+  vector.transfer_write %root, %arg1[%c0, %c0, %a, %b]
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+           "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+                  : vector<16x16xf16>, memref<32x32x32x32xf16>
+  return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: %[[B00:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX:[a-zA-Z0-9]+]], %[[OFFSET0:[a-zA-Z0-9]+]]]
+// CHECK: %[[B10:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX]], %[[OFFSET1:[a-zA-Z0-9]+]]]
+// CHECK: %[[B01:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH:[a-zA-Z0-9]+]], %[[OFFSET0]]]
+// CHECK: %[[B11:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH]], %[[OFFSET1]]]
+// CHECK: vector.transfer_write %[[B00]], %{{.*}}[%c0, %c0, %[[LANEX]], %[[OFFSET0]]]
+// CHECK: vector.transfer_write %[[B10]], %{{.*}}[%c0, %c0, %[[LANEX]], %[[OFFSET1]]]
+// CHECK: vector.transfer_write %[[B01]], %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH]], %[[OFFSET0]]]
+// CHECK: vector.transfer_write %[[B11]], %{{.*}}[%c0, %c0, %[[LANEX_PLUS_BATCH]], %[[OFFSET1]]]