[VectorDistribution] Add patterns for distributing transfer_read/transfer_write (#16115)

This patch adds distribution patterns for transfer_read/transfer_write
that lower them to per-thread vector.load/vector.store ops.

These distribution patterns lower transfer_read/transfer_write in one
shot, which differs from how transfer_read/transfer_write are lowered in
upstream MLIR: the upstream patterns unroll one dimension at a time and
apply themselves recursively. We can do the lowering in one shot because
the layout attribute already defines the iteration space for it.
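
For example, with the row-major layout used in the tests below,

  #iree_vector_ext.layout<<[BATCHX, LANEY], [2, 8]>, <[BATCHY, LANEX, VECTORX], [2, 1, 8]>>

a vector.transfer_read producing vector<16x16xf16> is lowered in one shot
to four per-thread vector.load ops of vector<8xf16> (one per
(BATCHX, BATCHY) iteration), with the memory indices computed from
gpu.thread_id values via affine.apply.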
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 15583b5..84dce8d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -8,7 +8,9 @@
 #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
 #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Rewrite/PatternApplicator.h"
@@ -20,6 +22,93 @@
 
 namespace {
 
+/// Given a LayoutAttr, find the shape of the given layout dimension. It is
+/// expected that the layout has at most one instance of the requested
+/// dimension. Example:
+///   LayoutAttr: <<BATCHX: 4>, <BATCHY: 4, LANEX: 4>>
+///   dim: BATCHX
+///   output: 4
+static std::optional<int64_t> findDimShape(LayoutAttr layout,
+                                           LayoutDimension dim) {
+  for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
+    if (std::optional<int64_t> shape = dimLayout.getShape(dim)) {
+      return shape;
+    }
+  }
+  return std::nullopt;
+}
+
+/// Given the state of the iterator, compute the indices of the original vector
+/// that the current iterator state is iterating over. These indices are
+/// parameterized by the thread grid.
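+/// Example: for a per-dim layout <[BATCHX, LANEY], [2, 8]>, the index along
+/// that dim is computed as 8 * pos(BATCHX) + thread_id_y.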
+static SmallVector<Value> computeSIMDIndex(const LayoutIterator::State &state,
+                                           LayoutAttr layout,
+                                           ArrayRef<Value> threadGrid,
+                                           RewriterBase &rewriter) {
+  MLIRContext *ctx = layout.getContext();
+  AffineExpr threadX, threadY, threadZ;
+  bindSymbols(ctx, threadX, threadY, threadZ);
+
+  SmallVector<Value> simdIndex;
+  // Calculate the index for each dim separately.
+  for (PerDimLayoutAttr dimLayout : layout.getLayouts()) {
+    AffineExpr offset = getAffineConstantExpr(0, ctx);
+    AffineExpr stride = getAffineConstantExpr(1, ctx);
+    for (auto [label, shape] : llvm::reverse(
+             llvm::zip(dimLayout.getLabels(), dimLayout.getShapes()))) {
+      int64_t position = state.lookup(label.getValue()).getPosition();
+
+      switch (label.getValue()) {
+      case LayoutDimension::LANEX:
+        offset = offset + stride * threadX;
+        break;
+      case LayoutDimension::LANEY:
+        offset = offset + stride * threadY;
+        break;
+      case LayoutDimension::LANEZ:
+        offset = offset + stride * threadZ;
+        break;
+      default:
+        offset = offset + stride * getAffineConstantExpr(position, ctx);
+        break;
+      }
+      stride = stride * getAffineConstantExpr(shape, ctx);
+    }
+
+    // Compute the index for the dim.
+    AffineMap indexMap = AffineMap::get(0, 3, offset);
+    Value index = rewriter.create<affine::AffineApplyOp>(
+        rewriter.getUnknownLoc(), indexMap, threadGrid);
+    simdIndex.push_back(index);
+  }
+
+  return simdIndex;
+}
+
+/// Given the state of the iterator, compute the indices of the distributed
+/// vector that the current iterator state is iterating over. The indices
+/// are not parameterized by thread, and it is expected that the indices for
+/// all threads are the same.
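+/// Example: for the layout <<[BATCHX, LANEY], [2, 8]>,
+/// <[BATCHY, LANEX, VECTORX], [2, 1, 8]>>, the returned offsets are
+/// (pos(BATCHX), pos(BATCHY), pos(VECTORX)).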
+static SmallVector<int64_t> computeSIMTIndex(const LayoutIterator::State &state,
+                                             LayoutAttr layout) {
+  constexpr LayoutDimension labels[] = {
+      LayoutDimension::BATCHX, LayoutDimension::BATCHY,
+      LayoutDimension::VECTORZ, LayoutDimension::VECTORY,
+      LayoutDimension::VECTORX};
+
+  SmallVector<int64_t> offset;
+  for (LayoutDimension label : labels) {
+    std::optional shape = findDimShape(layout, label);
+    if (!shape) {
+      continue;
+    }
+    // Get current position for the label.
+    int64_t position = state.lookup(label).getPosition();
+    offset.push_back(position);
+  }
+  return offset;
+}
+
 struct DistributeConstants final : OpDistributionPattern<arith::ConstantOp> {
   using OpDistributionPattern::OpDistributionPattern;
 
@@ -99,6 +188,208 @@
   }
 };
 
+/// Given a projected permutation, get a reduced permutation, i.e. without
+/// the projected dimensions.
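+/// Example: affine_map<(d0, d1, d2, d3) -> (d3, d2)> reduces to [1, 0].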
+static SmallVector<int64_t> getReducedPermutation(AffineMap permutationMap) {
+  assert(permutationMap.isProjectedPermutation() &&
+         "permutation map should be a projected permutation.");
+  // TODO: The permutation map may also have broadcasting. Currently, we do not
+  // handle it. This can be fixed by adding a "BROADCAST" dimension in the
+  // layout.
+
+  SmallVector<int64_t> permutation;
+  permutation.reserve(permutationMap.getNumResults());
+
+  unsigned leadingUnitDims =
+      permutationMap.getNumDims() - permutationMap.getNumResults();
+  for (AffineExpr dim : permutationMap.getResults()) {
+    // Get this dim's position in the permutation map.
+    auto dimExpr = dyn_cast<AffineDimExpr>(dim);
+    if (!dimExpr) {
+      llvm::report_fatal_error("permutation map is not a projected "
+                               "permutation.");
+    }
+
+    unsigned pos = dimExpr.getPosition();
+    assert(pos >= leadingUnitDims && "invalid permutation map");
+    pos -= leadingUnitDims;
+    permutation.push_back(pos);
+  }
+  return permutation;
+}
+
+template <typename OpTy>
+struct DistributeXferLayoutAttr : OpDistributionPattern<OpTy> {
+  static_assert(std::is_same<OpTy, vector::TransferReadOp>::value ||
+                    std::is_same<OpTy, vector::TransferWriteOp>::value,
+                "expected vector::TransferReadOp or vector::TransferWriteOp");
+
+  DistributeXferLayoutAttr(MLIRContext *context, ArrayRef<Value> threadGrid,
+                           PatternBenefit benefit = 1)
+      : OpDistributionPattern<OpTy>(context, benefit), threadGrid(threadGrid) {}
+
+  VectorValue accessMemory(OpTy xferOp, VectorValue accumulator,
+                           LayoutAttr vectorLayout,
+                           PatternRewriter &rewriter) const {
+    // The permutation map needs special handling during lowering: when
+    // accessing memory we use memoryLayout, because that is how the data is
+    // laid out in memory, while the data is stored in the vector according
+    // to vectorLayout.
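+    // For example, a transposed access with permutation map
+    // (d0, d1, d2, d3) -> (d3, d2) yields the reduced permutation [1, 0], so
+    // the memoryLayout is the vectorLayout with its two dims swapped.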
+    SmallVector<int64_t> permutation =
+        getReducedPermutation(xferOp.getPermutationMap());
+    LayoutAttr memoryLayout =
+        cast<LayoutAttr>(vectorLayout.permute(permutation));
+
+    int loadWidth = getLoadStoreWidth(memoryLayout);
+    DenseMap<LayoutDimension, int64_t> steps;
+    steps[LayoutDimension::VECTORX] = loadWidth;
+    LayoutIterator iterator(vectorLayout, steps);
+
+    iterator.apply([&](const LayoutIterator::State &state) {
+      SmallVector<Value> memoryIndices =
+          getMemoryIndices(state, memoryLayout, xferOp.getIndices(), rewriter);
+      SmallVector<int64_t> accIndices = computeSIMTIndex(state, vectorLayout);
+      accumulator = accessUnit(xferOp, memoryIndices, accIndices, accumulator,
+                               vectorLayout, memoryLayout, rewriter);
+    });
+
+    return accumulator;
+  }
+
+  SmallVector<Value> getMemoryIndices(const LayoutIterator::State &state,
+                                      LayoutAttr memoryLayout,
+                                      SmallVector<Value> indices,
+                                      RewriterBase &rewriter) const {
+    SmallVector<Value> simdIndices =
+        computeSIMDIndex(state, memoryLayout, threadGrid, rewriter);
+    SmallVector<Value> memoryIndices(indices);
+
+    // The transfer op may have leading dims that are projected out by its
+    // permutation map and thus not covered by the layout; only the trailing
+    // indices get a SIMD offset added.
+    int leadingProjectedDims = memoryIndices.size() - simdIndices.size();
+    for (int i = leadingProjectedDims, e = memoryIndices.size(); i < e; ++i) {
+      memoryIndices[i] = rewriter.create<arith::AddIOp>(
+          rewriter.getUnknownLoc(), memoryIndices[i],
+          simdIndices[i - leadingProjectedDims]);
+    }
+
+    return memoryIndices;
+  }
+
+  virtual VectorValue accessUnit(OpTy xferOp, SmallVector<Value> &memoryIndices,
+                                 SmallVector<int64_t> &accIndices,
+                                 VectorValue accumulator,
+                                 LayoutAttr vectorLayout,
+                                 LayoutAttr memoryLayout,
+                                 PatternRewriter &rewriter) const = 0;
+
+  int getLoadStoreWidth(LayoutAttr layout) const {
+    PerDimLayoutAttr fastestChanging = layout.getLayouts().back();
+    if (std::optional<int64_t> width =
+            fastestChanging.getShape(LayoutDimension::VECTORX)) {
+      return *width;
+    }
+    return 1;
+  }
+
+  SmallVector<Value> threadGrid;
+};
+
+struct DistributeTransferReadLayoutAttr final
+    : DistributeXferLayoutAttr<vector::TransferReadOp> {
+  using DistributeXferLayoutAttr::DistributeXferLayoutAttr;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    VectorValue vector = readOp.getVector();
+    LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.results[0]);
+    if (!vectorLayout) {
+      return failure();
+    }
+
+    // TODO: Return failure if we need masking.
+
+    Type elementType = readOp.getSource().getType().getElementType();
+    auto vectorType = VectorType::get(
+        vectorLayout.getDistributedShape(vector.getType()), elementType);
+    Value zero = rewriter.create<arith::ConstantOp>(
+        readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+    VectorValue acc = cast<VectorValue>(zero);
+
+    VectorValue readVec = accessMemory(readOp, acc, vectorLayout, rewriter);
+
+    replaceOpWithDistributedValues(rewriter, readOp, readVec);
+    return success();
+  }
+
+  VectorValue accessUnit(vector::TransferReadOp readOp,
+                         SmallVector<Value> &memoryIndices,
+                         SmallVector<int64_t> &accIndices,
+                         VectorValue accumulator, LayoutAttr vectorLayout,
+                         LayoutAttr memoryLayout,
+                         PatternRewriter &rewriter) const override {
+    auto unitType = VectorType::get({getLoadStoreWidth(memoryLayout)},
+                                    accumulator.getType().getElementType());
+    VectorValue load = rewriter.create<vector::LoadOp>(
+        readOp.getLoc(), unitType, readOp.getSource(), memoryIndices);
+    return rewriter.create<vector::InsertStridedSliceOp>(
+        readOp.getLoc(), load, accumulator, accIndices,
+        SmallVector<int64_t>{1});
+  }
+};
+
+struct DistributeTransferWriteLayoutAttr final
+    : DistributeXferLayoutAttr<vector::TransferWriteOp> {
+  using DistributeXferLayoutAttr::DistributeXferLayoutAttr;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    VectorValue vector = writeOp.getVector();
+    LayoutAttr vectorLayout = dyn_cast<LayoutAttr>(signature.operands[0]);
+    if (!vectorLayout) {
+      return failure();
+    }
+
+    // TODO: Return failure if we need masking.
+
+    Type elementType = writeOp.getSource().getType().getElementType();
+    auto vectorType = VectorType::get(
+        vectorLayout.getDistributedShape(vector.getType()), elementType);
+    Value zero = rewriter.create<arith::ConstantOp>(
+        writeOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+    VectorValue acc = cast<VectorValue>(zero);
+
+    accessMemory(writeOp, acc, vectorLayout, rewriter);
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+
+  VectorValue accessUnit(vector::TransferWriteOp writeOp,
+                         SmallVector<Value> &memoryIndices,
+                         SmallVector<int64_t> &accIndices,
+                         VectorValue accumulator, LayoutAttr vectorLayout,
+                         LayoutAttr memoryLayout,
+                         PatternRewriter &rewriter) const override {
+    int width = getLoadStoreWidth(memoryLayout);
+
+    SmallVector<int64_t> strides(accIndices.size(), 1);
+    SmallVector<int64_t> shapes(accIndices.size(), 1);
+    shapes[shapes.size() - 1] = width;
+    Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+        writeOp.getLoc(), getDistributed(rewriter, accumulator, vectorLayout),
+        accIndices, shapes, strides);
+    result = rewriter.create<vector::ExtractOp>(
+        writeOp.getLoc(), result,
+        SmallVector<int64_t>(accIndices.size() - 1, 0));
+    rewriter.create<vector::StoreOp>(writeOp.getLoc(), result,
+                                     writeOp.getSource(), memoryIndices);
+
+    return accumulator;
+  }
+};
+
 }; // namespace
 
 void populateGPUDistributionPatterns(RewritePatternSet &patterns) {
@@ -108,4 +399,11 @@
                DistributeElementwise<arith::AddFOp>>(patterns.getContext());
 }
 
+void populateGPUDistributionLayoutAttrPatterns(ArrayRef<Value> threadGrid,
+                                               RewritePatternSet &patterns) {
+  patterns
+      .add<DistributeTransferReadLayoutAttr, DistributeTransferWriteLayoutAttr>(
+          patterns.getContext(), threadGrid);
+}
+
 }; // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
index c6da6ca..cdda85a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
@@ -31,6 +31,9 @@
 
 void populateGPUDistributionPatterns(RewritePatternSet &patterns);
 
+void populateGPUDistributionLayoutAttrPatterns(ArrayRef<Value> threadGrid,
+                                               RewritePatternSet &patterns);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_CODEGEN_COMMON_GPUPATTERNS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index 53ddddf..a7ab386 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -41,3 +41,175 @@
     transform.yield
   }
 }
+
+// -----
+
+#layout_row_major = #iree_vector_ext.layout<<[BATCHX, LANEY], [2, 8]>, <[BATCHY, LANEX, VECTORX], [2, 1, 8]>>
+#layout_col_major = #iree_vector_ext.layout<<[BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[BATCHY, LANEX], [2, 8]>>
+
+builtin.module attributes { transform.with_named_sequence } {
+  // CHECK-LABEL: @distribute_transfer_read_row_major
+  func.func @distribute_transfer_read_row_major(%alloc: memref<4x4xf16>) -> vector<16x16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f16
+    %root = vector.transfer_read %alloc[%c0, %c0], %cst
+            {in_bounds = [false, false],
+             "__vector_layout_test_anchor_result_0" = #layout_row_major}
+                    : memref<4x4xf16>, vector<16x16xf16>
+    // CHECK-COUNT-4: vector.load {{.*}}, vector<8xf16>
+    func.return %root : vector<16x16xf16>
+  }
+
+  // CHECK-LABEL: @distribute_transfer_read_col_major
+  func.func @distribute_transfer_read_col_major(%alloc: memref<32x32xf16>) -> vector<16x16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f16
+    %root = vector.transfer_read %alloc[%c0, %c0], %cst
+            {in_bounds = [true, true],
+             "__vector_layout_test_anchor_result_0" = #layout_col_major}
+                    : memref<32x32xf16>, vector<16x16xf16>
+    // CHECK-COUNT-8: vector.load {{.*}}, vector<1xf16>
+    func.return %root : vector<16x16xf16>
+  }
+
+  // CHECK-LABEL: @distribute_transfer_read_row_major_with_broadcast
+  func.func @distribute_transfer_read_row_major_with_broadcast(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f16
+    %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+             "__vector_layout_test_anchor_result_0" = #layout_row_major}
+                    : memref<32x32x32x32xf16>, vector<16x16xf16>
+    // CHECK-COUNT-4: vector.load {{.*}}, vector<8xf16>
+    func.return %root : vector<16x16xf16>
+  }
+
+  // CHECK-LABEL: @distribute_transfer_read_col_major_with_broadcast
+  func.func @distribute_transfer_read_col_major_with_broadcast(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f16
+    %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+             "__vector_layout_test_anchor_result_0" = #layout_col_major}
+                    : memref<32x32x32x32xf16>, vector<16x16xf16>
+    // CHECK-COUNT-8: vector.load {{.*}}, vector<1xf16>
+    func.return %root : vector<16x16xf16>
+  }
+
+  // CHECK-LABEL: @distribute_transfer_read_row_major_transpose
+  func.func @distribute_transfer_read_row_major_transpose(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f16
+    %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+             "__vector_layout_test_anchor_result_0" = #layout_row_major}
+                    : memref<32x32x32x32xf16>, vector<16x16xf16>
+    // CHECK-COUNT-32: vector.load {{.*}}, vector<1xf16>
+    func.return %root : vector<16x16xf16>
+  }
+
+  // CHECK-LABEL: @distribute_transfer_read_col_major_transpose
+  func.func @distribute_transfer_read_col_major_transpose(%a: index, %b: index, %alloc: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f16
+    %root = vector.transfer_read %alloc[%c0, %c0, %a, %b], %cst
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+             "__vector_layout_test_anchor_result_0" = #layout_col_major}
+                    : memref<32x32x32x32xf16>, vector<16x16xf16>
+    // CHECK-COUNT-2: vector.load {{.*}}, vector<4xf16>
+    func.return %root : vector<16x16xf16>
+  }
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout_row_major = #iree_vector_ext.layout<<[BATCHX, LANEY], [2, 8]>, <[BATCHY, LANEX, VECTORX], [2, 1, 8]>>
+#layout_col_major = #iree_vector_ext.layout<<[BATCHX, LANEY, VECTORX], [1, 4, 4]>, <[BATCHY, LANEX], [2, 8]>>
+
+builtin.module attributes { transform.with_named_sequence } {
+  // CHECK-LABEL: @distribute_transfer_write_row_major
+  func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %root, %alloc[%c0, %c0]
+            {in_bounds = [true, true],
+             "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+                    : vector<16x16xf16>, memref<64x64xf16>
+    // CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
+    func.return
+  }
+
+  // CHECK-LABEL: @distribute_transfer_write_col_major
+  func.func @distribute_transfer_write_col_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %root, %alloc[%c0, %c0]
+            {in_bounds = [true, true],
+             "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+                    : vector<16x16xf16>, memref<64x64xf16>
+    // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
+    func.return
+  }
+
+  // CHECK-LABEL: @distribute_transfer_write_row_major_with_broadcast
+  func.func @distribute_transfer_write_row_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+             "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+                    : vector<16x16xf16>, memref<32x32x32x32xf16>
+    // CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
+    func.return
+  }
+
+  // CHECK-LABEL: @distribute_transfer_write_col_major_with_broadcast
+  func.func @distribute_transfer_write_col_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+             "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+                    : vector<16x16xf16>, memref<32x32x32x32xf16>
+    // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
+    func.return
+  }
+
+  // CHECK-LABEL: @distribute_transfer_write_row_major_transpose
+  func.func @distribute_transfer_write_row_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+             "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+                    : vector<16x16xf16>, memref<32x32x32x32xf16>
+    // CHECK-COUNT-32: vector.store {{.*}}, vector<1xf16>
+    func.return
+  }
+
+  // CHECK-LABEL: @distribute_transfer_write_col_major_transpose
+  func.func @distribute_transfer_write_col_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
+            {in_bounds = [true, true],
+             permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+             "__vector_layout_test_anchor_operand_0" = #layout_col_major}
+                    : vector<16x16xf16>, memref<32x32x32x32xf16>
+    // CHECK-COUNT-2: vector.store {{.*}}, vector<4xf16>
+    func.return
+  }
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 38af96b..4eec9cb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -1018,7 +1018,16 @@
     transform::TransformState &state) {
   TestVectorLayoutOptions options(target);
   RewritePatternSet patterns(target.getContext());
+
+  rewriter.setInsertionPointToStart(&target.getBody().front());
+  SmallVector<Value> threadGrid = {
+      rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::x),
+      rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::y),
+      rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::z),
+  };
+
   populateGPUDistributionPatterns(patterns);
+  populateGPUDistributionLayoutAttrPatterns(threadGrid, patterns);
   distributeVectorOps(target, patterns, options);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
index e1d490b..f13d2f0 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h
@@ -42,6 +42,8 @@
     return position < other.position;
   }
 
+  int64_t getPosition() const { return position; }
+
 private:
   int64_t position, stride;
 };