[Codegen] Add transfer read distribution pattern for nested layout (#16393)

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 808469c..fe92b60 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -56,6 +56,7 @@
         "GPUGeneralizeNamedOps.cpp",
         "GPULowerToUKernels.cpp",
         "GPUMultiBuffering.cpp",
+        "GPUNestedLayoutDistributionPatterns.cpp",
         "GPUPatterns.cpp",
         "GPUPipelining.cpp",
         "GPUReduceBankConflicts.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 7214b6c..cfd05aa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -54,6 +54,7 @@
     "GPUGeneralizeNamedOps.cpp"
     "GPULowerToUKernels.cpp"
     "GPUMultiBuffering.cpp"
+    "GPUNestedLayoutDistributionPatterns.cpp"
     "GPUPatterns.cpp"
     "GPUPipelining.cpp"
     "GPUReduceBankConflicts.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
new file mode 100644
index 0000000..a516148
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -0,0 +1,215 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
+#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
+#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+
+namespace mlir::iree_compiler {
+
+using namespace mlir::iree_compiler::IREE::VectorExt;
+using VectorValue = TypedValue<VectorType>;
+
+/// Helper to linearize the given |ids| with the dimension extents given as
+/// |sizes|. The linearized index is scaled by |elementCount| and the element
+/// |offset| is added. For example,
+///
+/// IDs = [d0, d1, d2, d3]
+/// sizes = [s0, s1, s2, s3]
+/// linear_index = d0 * (s1 * s2 * s3)
+///              + d1 * (s2 * s3)
+///              + d2 * (s3)
+///              + d3
+/// return element_index = linear_index * |elementCount| + |offset|;
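+///
+/// As an illustrative check against the row-major layout used in the lit
+/// tests (sizes = [1, 2, 1, 8], elementCount = 1 for the row dimension):
+/// with IDs = [0, 1, 0, tid] and offset = 0 this yields
+/// (1 * 8 + tid) * 1 + 0 = tid + 8, i.e. the affine_map<()[s0] -> (s0 + 8)>
+/// seen in the tests.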
+static Value linearizeIndex(OpBuilder &builder, Value offset,
+                            ArrayRef<OpFoldResult> ids, ArrayRef<int64_t> sizes,
+                            int64_t elementCount) {
+  SmallVector<AffineExpr> exprs(ids.size() + 1);
+  bindSymbolsList(builder.getContext(), MutableArrayRef{exprs});
+  AffineExpr idExpr = builder.getAffineConstantExpr(0);
+
+  for (int i = 0, e = ids.size(); i < e; ++i) {
+    if (sizes[i] > 1) {
+      // Multiply the accumulated index by the extent of this dimension
+      // (which varies faster than all previous dimensions) and add this
+      // dimension's id.
+      idExpr = idExpr * builder.getAffineConstantExpr(sizes[i]) + exprs[i];
+    }
+  }
+  idExpr = idExpr * builder.getAffineConstantExpr(elementCount);
+  idExpr = idExpr + exprs.back();
+  SmallVector<OpFoldResult> mapArgs(ids);
+  mapArgs.push_back(offset);
+  return affine::makeComposedAffineApply(
+             builder, offset.getLoc(),
+             AffineMap::get(0, mapArgs.size(), idExpr), mapArgs)
+      .getResult();
+}
+
+namespace {
+
+/// Pattern to distribute `vector.transfer_read` ops with nested layouts.
+struct DistributeTransferReadNestedLayoutAttr final
+    : OpDistributionPattern<vector::TransferReadOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+
+  DistributeTransferReadNestedLayoutAttr(MLIRContext *context, Value threadId)
+      : OpDistributionPattern(context), threadId(threadId) {}
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Support masking.
+    if (readOp.getMask()) {
+      return failure();
+    }
+    NestedLayoutAttr vectorLayout =
+        dyn_cast<NestedLayoutAttr>(signature[readOp.getResult()]);
+    if (!vectorLayout) {
+      return failure();
+    }
+
+    // Guard on memref sources for distribution. In isolation, this pattern
+    // is agnostic to whether the source is a tensor or a memref.
+    if (!isa<MemRefType>(readOp.getSource().getType())) {
+      return failure();
+    }
+
+    // The delinearized thread IDs are returned from outermost to innermost,
+    // i.e. before applying the dimension ordering described by the layout.
+    ValueRange threadIds = vectorLayout.computeThreadIds(threadId, rewriter);
+
+    SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
+    SmallVector<int64_t> tileShape = vectorLayout.getDistributedShape();
+    int64_t rank = vectorLayout.getBatchOrder().size();
+
+    // Let the unroll order first unroll the batch dimensions, then the
+    // outer vector dimensions. We unroll in the order specified by the
+    // layout.
+    SmallVector<int64_t> loopOrder;
+    int64_t base = 0;
+    for (int64_t b : vectorLayout.getBatchOrder()) {
+      loopOrder.push_back(base + b);
+      tileShape[base + b] = 1;
+    }
+    base += rank;
+    // We must unroll along the outer dimensions as well to match the rank
+    // requirements of vector transfer ops (<= memref rank up to broadcasts).
+    for (int64_t o : vectorLayout.getOuterOrder()) {
+      loopOrder.push_back(base + o);
+      tileShape[base + o] = 1;
+    }
+    base += rank;
+    for (int i = 0, e = rank; i < e; ++i) {
+      loopOrder.push_back(base + i);
+    }
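+    // For example, for the 2-D row-major layout in the lit tests (batches
+    // [2, 2], outers [1, 1], elements [1, 8], batch_order [1, 0]) this
+    // yields loopOrder = [1, 0, 2, 3, 4, 5] and
+    // tileShape = [1, 1, 1, 1, 1, 8].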
+
+    Type elementType = readOp.getSource().getType().getElementType();
+    auto vectorType = VectorType::get(distShape, elementType);
+    // The shape of the vector we read is the pre-permutation shape; the
+    // element-order permutation is applied as a transpose on the read result.
+    auto innerVectorType =
+        VectorType::get(vectorLayout.getElementsPerThread(), elementType);
+
+    // Initialize the full distributed vector for unrolling the batch/outer
+    // vector dimensions.
+    Value zero = rewriter.create<arith::ConstantOp>(
+        readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+    VectorValue acc = cast<VectorValue>(zero);
+
+    // Subgroup and thread (lane) indices normalized to the order in which
+    // they are used by each dimension.
+    SmallVector<Value> warpIndices = llvm::to_vector(
+        llvm::map_range(vectorLayout.getSubgroupOrder(),
+                        [&](int64_t i) { return threadIds[i]; }));
+    SmallVector<Value> threadIndices = llvm::to_vector(
+        llvm::map_range(vectorLayout.getThreadOrder(),
+                        [&](int64_t i) { return threadIds[i + rank]; }));
+
+    auto isBroadcast = [](AffineExpr expr) {
+      if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
+        return constExpr.getValue() == 0;
+      return false;
+    };
+
+    ValueRange indices = readOp.getIndices();
+    SmallVector<int64_t> strides(rank, 1);
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(distShape, tileShape, loopOrder)) {
+      // Permute the batch and outer vector offsets to match the order of
+      // the vector dimensions using the inverse of the batch/outer order.
+      SmallVector<int64_t> batchOffsets = applyPermutation(
+          ArrayRef<int64_t>(offsets.begin(), rank),
+          invertPermutationVector(vectorLayout.getBatchOrder()));
+      SmallVector<int64_t> outerVectorOffsets = applyPermutation(
+          ArrayRef<int64_t>(offsets.begin() + rank, rank),
+          invertPermutationVector(vectorLayout.getOuterOrder()));
+
+      SmallVector<Value> slicedIndices(indices.begin(), indices.end());
+      for (const auto &[i, dim] :
+           llvm::enumerate(readOp.getPermutationMap().getResults())) {
+        if (isBroadcast(dim))
+          continue;
+        unsigned pos = cast<AffineDimExpr>(dim).getPosition();
+        SmallVector<OpFoldResult> ids = {
+            warpIndices[i], rewriter.getIndexAttr(batchOffsets[i]),
+            rewriter.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
+        // The order in which a vector dimension is "tiled" is
+        // subgroups -> batches -> outer vectors -> threads -> elements
+        SmallVector<int64_t> sizes = {
+            vectorLayout.getSubgroupsPerWorkgroup()[i],
+            vectorLayout.getBatchesPerSubgroup()[i],
+            vectorLayout.getOutersPerBatch()[i],
+            vectorLayout.getThreadsPerOuter()[i]};
+        slicedIndices[pos] =
+            linearizeIndex(rewriter, indices[pos], ids, sizes,
+                           vectorLayout.getElementsPerThread()[i]);
+      }
+      Value slicedRead = rewriter.create<vector::TransferReadOp>(
+          readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
+          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
+          readOp.getInBoundsAttr());
+      // Transpose to the element order.
+      if (!isIdentityPermutation(vectorLayout.getElementOrder())) {
+        slicedRead = rewriter.create<vector::TransposeOp>(
+            slicedRead.getLoc(), slicedRead, vectorLayout.getElementOrder());
+      }
+
+      acc = rewriter.create<vector::InsertStridedSliceOp>(
+          readOp.getLoc(), slicedRead, acc, offsets, strides);
+    }
+
+    replaceOpWithDistributedValues(rewriter, readOp, acc);
+    return success();
+  }
+
+  Value threadId;
+};
+
+} // namespace
+
+void populateGPUDistributeNestedLayoutAttrPatterns(
+    Value threadId, RewritePatternSet &patterns) {
+  patterns.add<DistributeTransferReadNestedLayoutAttr>(patterns.getContext(),
+                                                       threadId);
+}
+
+} // namespace mlir::iree_compiler
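
The index arithmetic in linearizeIndex above can be sanity-checked with a
small scalar model. This is an illustrative sketch only (linearizeIndexModel
is a hypothetical stand-in, not part of this patch); it reproduces the
(s0 + 8) row offset and the %c8 column offset that the row-major lit test
below expects.

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of linearizeIndex: fold the ids from slowest to fastest
// varying dimension, then scale by the element count and add the base offset.
static int64_t linearizeIndexModel(int64_t offset,
                                   const std::vector<int64_t> &ids,
                                   const std::vector<int64_t> &sizes,
                                   int64_t elementCount) {
  int64_t linear = 0;
  for (size_t i = 0; i < ids.size(); ++i)
    if (sizes[i] > 1)
      linear = linear * sizes[i] + ids[i];
  return linear * elementCount + offset;
}

int main() {
  // Row dimension of the row-major test layout: [subgroups, batches, outers,
  // threads] = [1, 2, 1, 8], elements_per_thread = 1.
  for (int64_t tid = 0; tid < 8; ++tid) {
    assert(linearizeIndexModel(0, {0, 0, 0, tid}, {1, 2, 1, 8}, 1) == tid);
    assert(linearizeIndexModel(0, {0, 1, 0, tid}, {1, 2, 1, 8}, 1) == tid + 8);
  }
  // Column dimension: sizes = [1, 2, 1, 1], elements_per_thread = 8, so the
  // second column batch starts at element 8.
  assert(linearizeIndexModel(0, {0, 1, 0, 0}, {1, 2, 1, 1}, 8) == 8);
  return 0;
}
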
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
index 727eaaa..5475f03 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h
@@ -37,6 +37,9 @@
 void populateGPUReductionDistributionPatterns(RewritePatternSet &patterns,
                                               int64_t maxBitsPerShuffle = 32);
 
+void populateGPUDistributeNestedLayoutAttrPatterns(Value threadId,
+                                                   RewritePatternSet &patterns);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_CODEGEN_COMMON_GPUPATTERNS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 42a3c0a..d21e6a7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -23,6 +23,7 @@
             "gpu_distribute_shared_memory.mlir",
             "gpu_generalize_named_ops.mlir",
             "gpu_lower_to_ukernels.mlir",
+            "gpu_nested_layout_vector_distribution.mlir",
             "gpu_pipeline.mlir",
             "gpu_tensor_alloc.mlir",
             "gpu_tensor_tile.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index ddc8f47..2891b8a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -19,6 +19,7 @@
     "gpu_distribute_shared_memory.mlir"
     "gpu_generalize_named_ops.mlir"
     "gpu_lower_to_ukernels.mlir"
+    "gpu_nested_layout_vector_distribution.mlir"
     "gpu_pipeline.mlir"
     "gpu_tensor_alloc.mlir"
     "gpu_tensor_tile.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
new file mode 100644
index 0000000..d11526b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -0,0 +1,327 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize --cse %s | FileCheck %s
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [2, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [8, 1],
+  elements_per_thread     = [1, 8],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [0, 1]
+>
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-LABEL: @distribute_transfer_read_row_major
+func.func @distribute_transfer_read_row_major(%arg0: memref<4x4xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0], %cst
+          {in_bounds = [false, false],
+           "__vector_layout_test_anchor_result_0" = #layout_row_major}
+                  : memref<4x4xf16>, vector<16x16xf16>
+  func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: %[[ACC:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x1x1x1x8xf16>
+// CHECK: %[[IDX:.+]] = gpu.thread_id  x
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %[[IDX]] into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c0], {{.*}} : memref<4x4xf16>, vector<1x8xf16>
+// CHECK: vector.insert_strided_slice %{{.*}}, %[[ACC]] {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<2x2x1x1x1x8xf16>
+// CHECK: vector.transfer_read %arg0[%[[IDS]]#2, %c8]
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+// CHECK: %[[ID_PLUS_BATCH1:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c0]
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [0, 1, 0, 0, 0, 0]
+// CHECK: vector.transfer_read %arg0[%[[ID_PLUS_BATCH1]], %c8]
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 1, 0, 0, 0, 0]
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x2x1x1x1x8xf16> -> vector<16x16xf16>
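+
+// Note: the distributed accumulator type is batches (2x2) x outers (1x1) x
+// elements (1x8) of #layout_row_major, i.e. vector<2x2x1x1x1x8xf16>; each
+// unrolled 1x8 read lands at its (batch, outer) offset above.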
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [1, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [4, 8],
+  elements_per_thread     = [4, 1],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [1, 0]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @distribute_transfer_read_col_major
+func.func @distribute_transfer_read_col_major(%arg0: memref<32x32xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0], %cst
+          {in_bounds = [true, true],
+           "__vector_layout_test_anchor_result_0" = #layout_col_major}
+                  : memref<32x32xf16>, vector<16x16xf16>
+  func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c4, %c8) : index, index, index, index
+// CHECK: %[[LANEY:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2]
+// CHECK: %[[RD00:.+]] = vector.transfer_read %arg0[%[[LANEY:.+]], %[[IDS]]#3], {{.*}} : memref<32x32xf16>, vector<4x1xf16>
+// CHECK: %[[ELEM_ORDER:.+]] = vector.transpose %[[RD00]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
+// CHECK: vector.insert_strided_slice %[[ELEM_ORDER]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x4xf16> into vector<2x1x1x1x1x4xf16>
+// CHECK: %[[LANEX_PLUS_BATCH:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#3]
+// CHECK: vector.transfer_read %arg0[%[[LANEY]], %[[LANEX_PLUS_BATCH]]], %{{.*}} {in_bounds = [true, true]} : memref<32x32xf16>, vector<4x1xf16>
+// CHECK: vector.transpose %{{.*}}, [1, 0] : vector<4x1xf16> to vector<1x4xf16>
+// CHECK: vector.insert_strided_slice {{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x1x1x1x1x4xf16> -> vector<16x16xf16>
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [2, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [8, 1],
+  elements_per_thread     = [1, 8],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [0, 1]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 8)>
+
+func.func @distribute_transfer_read_row_major_with_nontrivial_index(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+           "__vector_layout_test_anchor_result_0" = #layout_row_major}
+                  : memref<32x32x32x32xf16>, vector<16x16xf16>
+  func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+// CHECK-LABEL: @distribute_transfer_read_row_major_with_nontrivial_index
+// CHECK-SAME:    %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[OFF0:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#2, %[[I0]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[I1]]]
+// CHECK: %[[OFF1:.+]] = affine.apply #[[$MAP1]]()[%arg1]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF0]], %[[OFF1]]]
+// CHECK: %[[OFF2:.+]] = affine.apply #[[$MAP2]]()[%[[IDS]]#2, %[[I0]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF2]], %[[I1]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[OFF2]], %[[OFF1]]]
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [1, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [4, 8],
+  elements_per_thread     = [4, 1],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [1, 0]
+>
+
+func.func @distribute_transfer_read_col_major_with_broadcast(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (0, 0)>,
+           "__vector_layout_test_anchor_result_0" = #layout_col_major}
+                  : memref<32x32x32x32xf16>, vector<16x16xf16>
+  func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (0, 0)>
+
+// CHECK-LABEL: @distribute_transfer_read_col_major_with_broadcast
+// CHECK-SAME:    %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[BROADCAST_READ:.+]] = vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[I1]]], %{{.*}} permutation_map = #[[$MAP]]
+// CHECK: %[[UNIT:.+]] = vector.transpose %[[BROADCAST_READ]], [1, 0] : vector<4x1xf16> to vector<1x4xf16>
+// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [0, 0, 0, 0, 0, 0]
+// CHECK: vector.insert_strided_slice %[[UNIT]], %{{.*}} {offsets = [1, 0, 0, 0, 0, 0]
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [2, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [8, 1],
+  elements_per_thread     = [1, 8],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [1, 0],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [0, 1]
+>
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 8)>
+
+func.func @distribute_transfer_read_row_major_transpose(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+           "__vector_layout_test_anchor_result_0" = #layout_row_major}
+                  : memref<32x32x32x32xf16>, vector<16x16xf16>
+  func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: @distribute_transfer_read_row_major_transpose
+// CHECK-SAME:    %[[I0:.+]]: index, %[[I1:.+]]: index
+
+// CHECK: %[[IDS:.+]]:4 = affine.delinearize_index %{{.*}} into (%c1, %c1, %c8, %c1) : index, index, index, index
+// CHECK: %[[LIN_ID0:.+]] = affine.apply #[[$MAP:.+]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.transfer_read %{{.*}}[%c0, %c0, %[[I0]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[I0_PLUS_8:.+]] = affine.apply #[[$MAP2]]()[%[[I0]]]
+// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID0]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: %[[LIN_ID1:.+]] = affine.apply #[[$MAP3]]()[%[[IDS]]#2, %[[I1]]]
+// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0]], %[[LIN_ID1]]], {{.*}} permutation_map = #[[$MAP1]]
+// CHECK: vector.transfer_read %arg2[%c0, %c0, %[[I0_PLUS_8]], %[[LIN_ID1]]], {{.*}} permutation_map = #[[$MAP1]]
+
+// -----
+
+#layout_col_major = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup    = [1, 2],
+  outers_per_batch        = [1, 1],
+  threads_per_outer       = [4, 8],
+  elements_per_thread     = [4, 1],
+
+  subgroup_order          = [0, 1],
+  batch_order             = [0, 1],
+  outer_order             = [0, 1],
+  thread_order            = [0, 1],
+  element_order           = [0, 1]
+>
+
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+
+// CHECK-LABEL: @distribute_transfer_read_col_major_transpose
+func.func @distribute_transfer_read_col_major_transpose(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+           "__vector_layout_test_anchor_result_0" = #layout_col_major}
+                  : memref<32x32x32x32xf16>, vector<16x16xf16>
+  func.return %root : vector<16x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: vector.transfer_read {{.*}} permutation_map = #[[$MAP2]]
+// CHECK: vector.transfer_read {{.*}} permutation_map = #[[$MAP2]]
+
+// -----
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [7, 3, 1, 1],
+  batches_per_subgroup    = [3, 5, 2, 1],
+  outers_per_batch        = [1, 1, 2, 4],
+  threads_per_outer       = [1, 1, 2, 2],
+  elements_per_thread     = [1, 1, 1, 2],
+
+  subgroup_order          = [1, 0, 2, 3],
+  batch_order             = [1, 2, 3, 0],
+  outer_order             = [0, 3, 1, 2],
+  thread_order            = [0, 1, 3, 2],
+  element_order           = [0, 1, 2, 3]
+>
+
+func.func @distribute_transfer_read_row_major_with_permutations(%a: index, %b: index, %arg0: memref<32x32x32x32xf16>) -> vector<21x15x8x16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0, %a, %b], %cst
+          {in_bounds = [true, true, true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d3, 0, d1)>,
+           "__vector_layout_test_anchor_result_0" = #layout}
+                  : memref<32x32x32x32xf16>, vector<21x15x8x16xf16>
+  func.return %root : vector<21x15x8x16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+// CHECK-LABEL: @distribute_transfer_read_row_major_with_permutations
+
+// Verify that there are (batch0: 3) * (batch1: 5) * (outer3: 4) = 60 unique
+// transfer read ops; reads that differ only along the broadcasted dimension
+// (2) are deduplicated by CSE.
+// CHECK-COUNT-60: vector.transfer_read
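
As a cross-check of the CHECK-COUNT-60 above (an illustrative sketch, not part
of the patch): the pattern emits one read per batch x outer offset, and CSE
merges reads that differ only along the broadcasted result dimension.

#include <cassert>
#include <cstdint>

int main() {
  // Per result dimension of #layout: batches_per_subgroup and
  // outers_per_batch; result dimension 2 is the broadcast (constant 0) in
  // the permutation map.
  const int64_t batches[4] = {3, 5, 2, 1};
  const int64_t outers[4] = {1, 1, 2, 4};
  const int broadcastDim = 2;

  int64_t total = 1, unique = 1;
  for (int d = 0; d < 4; ++d) {
    total *= batches[d] * outers[d];
    if (d != broadcastDim)
      unique *= batches[d] * outers[d];
  }
  assert(total == 240);  // reads emitted by the unrolled loop
  assert(unique == 60);  // reads surviving CSE, i.e. CHECK-COUNT-60
  return 0;
}
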
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index c801ad2..1ca1dd2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -966,6 +966,7 @@
   populateGPUDistributionPatterns(patterns);
   populateGPUDistributionLayoutAttrPatterns(laneId, patterns);
   populateGPUReductionDistributionPatterns(patterns);
+  populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns);
   distributeVectorOps(target, patterns, options);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 82ea1fb..31044bc 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -733,6 +733,7 @@
         ":IREEVectorExtIncGen",
         ":IREEVectorExtInterfacesIncGen",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
     ],
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index 05e331c..c12f9eb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -293,6 +293,12 @@
     `>`
   }];
 
+  let extraClassDeclaration = [{
+    // Returns the subgroup/lane ids delinearized from a single linearized
+    // thread ID.
+    ValueRange computeThreadIds(Value threadId, RewriterBase &rewriter) const;
+  }];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h
index 71e9f0c..59bcd1a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.h
@@ -9,6 +9,8 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Value.h>
 
 /// Include the generated interface declarations.
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtAttrInterfaces.h.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
index 6e965d1..82d810b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/CMakeLists.txt
@@ -12,6 +12,7 @@
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRAffineDialect
 )
 
 iree_dialects_target_includes(IREEVectorExtDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 6b80529..033e0ff 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -4,12 +4,16 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include <numeric>
+
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include <numeric>
 
 using namespace mlir;
 
@@ -264,6 +268,23 @@
   return success();
 }
 
+/// Given a single flat thread ID, compute the indices of the distributed
+/// dimensions (subgroup and thread ids).
+ValueRange NestedLayoutAttr::computeThreadIds(Value threadId,
+                                              RewriterBase &rewriter) const {
+  SmallVector<OpFoldResult> basis;
+  for (int64_t subgroupDim : getSubgroupOrder()) {
+    basis.push_back(
+        rewriter.getIndexAttr(getSubgroupsPerWorkgroup()[subgroupDim]));
+  }
+  for (int64_t threadDim : getThreadOrder()) {
+    basis.push_back(rewriter.getIndexAttr(getThreadsPerOuter()[threadDim]));
+  }
+
+  auto delinearized = rewriter.create<mlir::affine::AffineDelinearizeIndexOp>(
+      threadId.getLoc(), threadId, basis);
+  return delinearized->getResults();
+}
+
 } // namespace mlir::iree_compiler::IREE::VectorExt
 
 using namespace mlir::iree_compiler::IREE::VectorExt;
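
For intuition, computeThreadIds above can be modeled with plain integer
arithmetic (an illustrative sketch; delinearize below is a hypothetical
stand-in for affine.delinearize_index, not part of this patch). For the
row-major test layout the basis is (1, 1, 8, 1), so a flat thread id in
[0, 8) maps to (0, 0, tid, 0), matching the
affine.delinearize_index ... into (%c1, %c1, %c8, %c1) checks in the tests.

#include <cassert>
#include <cstdint>
#include <vector>

// Mixed-radix delinearization, outermost basis element first.
static std::vector<int64_t> delinearize(int64_t id,
                                        const std::vector<int64_t> &basis) {
  std::vector<int64_t> out(basis.size());
  for (int i = static_cast<int>(basis.size()) - 1; i >= 0; --i) {
    out[i] = id % basis[i];
    id /= basis[i];
  }
  return out;
}

int main() {
  // Basis = subgroups_per_workgroup (in subgroup order) followed by
  // threads_per_outer (in thread order): [1, 1] ++ [8, 1].
  for (int64_t tid = 0; tid < 8; ++tid) {
    std::vector<int64_t> ids = delinearize(tid, {1, 1, 8, 1});
    assert(ids[0] == 0 && ids[1] == 0 && ids[2] == tid && ids[3] == 0);
  }
  return 0;
}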