[VectorExt] Add support for projecting nested layouts (#16528)

This adds two new fields to the layout to track which ids in the basis
are correspond to the subgroups_per_workgroup and threads_per_outer. The
reason these masks are required is because when performing a
rank-reducing projection, we lose necessary degrees of freedom
(dimensions) that can indicate how the data in the underlying vector is
replicated across threads. The simplest example is reducing a 2x2 vector
with a grid of 2x2 threads. In this case there are two ways the reduced
data can be replicated across threads.

```
vector<2> = s0, s1, s0, s1
vector<2> = s0, s0, s1, s1
```

however we only have one valid layout based on the constraints on the
thread basis/thread counts. Assuming all other dimensions are 1, it is
required that thread_count == vector size == 2. Similarly, the total
number of threads in the basis must be equal to the flat total number of
threads (in this case 4). Since there is only one valid layout, there is
no way to differentiate between these two distributed cases.

The active_id masks fix this by further decoupling the thread basis
(i.e. how to delinearize the flat thread id to a set of ids used by the
layout) from the actual vector shape. Handling projections then is
simply a matter of masking off the ids of the basis according to the
projected dims.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index 2497666..cd5b57a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -324,6 +324,45 @@
 
 // -----
 
+#layout = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1],
+  batches_per_subgroup    = [1],
+  outers_per_batch        = [1],
+  threads_per_outer       = [4],
+  elements_per_thread     = [4],
+
+  subgroup_basis          = [1],
+  thread_basis            = [4, 16],
+  thread_active_ids       = [true, false]
+>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
+
+// CHECK-LABEL: @distribute_transfer_read_broadcast
+func.func @distribute_transfer_read_broadcast(%arg0: memref<32x32xf16>) -> vector<16xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %root = vector.transfer_read %arg0[%c0, %c0], %cst
+          {in_bounds = [true],
+           "__vector_layout_test_anchor_result_0" = #layout}
+                  : memref<32x32xf16>, vector<16xf16>
+  func.return %root : vector<16xf16>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: %[[IDS:.+]]:3 = affine.delinearize_index %{{.*}} into (%c1, %c4, %c16) : index, index, index
+// CHECK: %[[LANEY:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#1]
+// CHECK: %[[RD:.+]] = vector.transfer_read %{{.*}}[%c0, %[[LANEY:.+]]], {{.*}} : memref<32x32xf16>, vector<4xf16>
+
+// -----
+
 #layout_row_major = #iree_vector_ext.nested_layout<
   subgroups_per_workgroup = [1, 1],
   batches_per_subgroup    = [2, 2],
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
index 5fbff7f..720eb6d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
@@ -19,7 +19,7 @@
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -45,7 +45,7 @@
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -74,7 +74,7 @@
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -107,7 +107,7 @@
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -142,7 +142,7 @@
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -221,3 +221,79 @@
     transform.yield
   }
 }
+
+// -----
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [1, 1],
+  outers_per_batch = [1, 1],
+  threads_per_outer = [4, 16],
+  elements_per_thread = [4, 1],
+
+  subgroup_basis = [1, 1],
+  thread_basis   = [4, 16]
+>
+
+// Propagate and enforce through reduction along dim 0.
+// The printing of this layout is too long for a remark. Just verify that
+// the subgroup/thread bases are what we expect.
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @reduction(%arr: memref<16x16xf16>, %arr2: memref<16xf16>, %a: vector<16xf16>, %b: vector<16xf16>) -> vector<16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %cst0_1 = arith.constant dense<0.0> : vector<16xf16>
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+    %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{thread_basis = [4, 16]}}
+    %root_red = vector.multi_reduction<add>, %root, %cst0_1 [0]  : vector<16x16xf16> to vector<16xf16>
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+    %c = arith.mulf %root_red, %a : vector<16xf16>
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [false, true], thread_basis = [4, 16], thread_active_ids= [false, true]}}
+    func.return %c : vector<16xf16>
+  }
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [1, 1],
+  outers_per_batch = [1, 1],
+  threads_per_outer = [4, 16],
+  elements_per_thread = [4, 1],
+
+  subgroup_basis = [1, 1],
+  thread_basis   = [4, 16]
+>
+
+// Propagate and enforce through reduction along dim 1.
+// The printing of this layout is too long for a remark. Just verify that
+// the subgroup/thread bases are what we expect.
+builtin.module attributes { transform.with_named_sequence } {
+  func.func @reduction(%arr: memref<16x16xf16>, %arr2: memref<16xf16>, %a: vector<16xf16>, %b: vector<16xf16>) -> vector<16xf16> {
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.0 : f16
+    %cst0_1 = arith.constant dense<0.0> : vector<16xf16>
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+    %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{thread_basis = [4, 16]}}
+    %root_red = vector.multi_reduction<add>, %root, %cst0_1 [1]  : vector<16x16xf16> to vector<16xf16>
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+    %c = arith.mulf %root_red, %a : vector<16xf16>
+    // expected-remark @above {{subgroup_basis = [1, 1], subgroup_active_ids= [true, false], thread_basis = [4, 16], thread_active_ids= [true, false]}}
+    func.return %c : vector<16xf16>
+  }
+
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 752e8dd..9dc6abe 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -435,10 +436,11 @@
     applyPermutationToVector(elementOrder, permute);
   }
 
-  return NestedLayoutAttr::get(context, subgroupCount, subgroupOrder,
-                               batchCount, batchOrder, outerCount, outerOrder,
-                               threadCount, threadOrder, elementCount,
-                               elementOrder, subgroupBasis, threadBasis);
+  return NestedLayoutAttr::get(
+      context, subgroupCount, subgroupOrder, batchCount, batchOrder, outerCount,
+      outerOrder, threadCount, threadOrder, elementCount, elementOrder,
+      subgroupBasis, SmallVector<bool>(subgroupBasis.size(), true), threadBasis,
+      SmallVector<bool>(threadBasis.size(), true));
 }
 
 std::optional<std::tuple<VectorExt::VectorLayoutInterface,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
index fe8b44c..6070887 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
@@ -294,7 +294,9 @@
 
     auto layout = IREE::VectorExt::NestedLayoutAttr::get(
         context, subgroupCounts, order, batchSizes, order, outerSizes, order,
-        threadCounts, order, elementSizes, order, subgroupBasis, threadBasis);
+        threadCounts, order, elementSizes, order, subgroupBasis,
+        SmallVector<bool>(subgroupBasis.size(), true), threadBasis,
+        SmallVector<bool>(threadBasis.size(), true));
     analysis.setAnchor(transfer.getResult(), layout);
     if (printLayout) {
       llvm::outs() << "transfer '" << transfer << "' vector layout: " << layout
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index bba41a0..7096682 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -282,7 +282,10 @@
     ArrayRefParameter<"int64_t", "element_order">:$elementOrder,
 
     ArrayRefParameter<"int64_t", "subgroup_basis">:$subgroupBasis,
-    ArrayRefParameter<"int64_t", "thread_basis">:$threadBasis
+    ArrayRefParameter<"bool", "subgroup_active_ids">:$subgroupActiveIds,
+
+    ArrayRefParameter<"int64_t", "thread_basis">:$threadBasis,
+    ArrayRefParameter<"bool", "thread_active_ids">:$threadActiveIds
   );
 
   // By default, identity orderings are elided when parsing/printing.
@@ -300,8 +303,8 @@
         custom<Permutation>("\"thread_order\"", ref($threadsPerOuter), "true", $threadOrder) ``
         custom<Permutation>("\"element_order\"", ref($elementsPerThread), "true", $elementOrder) ``
 
-        `subgroup_basis`          `=` `[` $subgroupBasis `]` `,`
-        `thread_basis`            `=` `[` $threadBasis `]`
+        custom<Basis>("\"subgroup_basis\"", "\"subgroup_active_ids\"", "true", $subgroupBasis, $subgroupActiveIds) ``
+        custom<Basis>("\"thread_basis\"", "\"thread_active_ids\"", "false", $threadBasis, $threadActiveIds) ``
     `>`
   }];
 
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
index 6d84b14..85adb10 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
@@ -33,7 +33,7 @@
       /*description=*/"Projects the given layout.",
       /*retTy=*/"VectorLayoutInterface",
       /*methodName=*/"project",
-      /*args=*/(ins "::llvm::ArrayRef<bool>":$projectedDims)
+      /*args=*/(ins "::llvm::ArrayRef<bool>":$droppedDims)
     >,
     InterfaceMethod<
       /*description=*/"Get the distributed shape for the given vector type.",
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index ef56793..63c52e1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -10,10 +10,13 @@
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -63,14 +66,14 @@
 
 // Project out the layout for the specified dimensions
 // resulting in the layout for a lower dimensional vector.
-VectorLayoutInterface LayoutAttr::project(ArrayRef<bool> projectedDims) const {
-  assert(projectedDims.size() == getLayouts().size() &&
-         "projectedDims size must match layout size");
+VectorLayoutInterface LayoutAttr::project(ArrayRef<bool> droppedDims) const {
+  assert(droppedDims.size() == getLayouts().size() &&
+         "droppedDims size must match layout size");
 
   ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
-  assert(projectedDims.size() == layouts.size());
+  assert(droppedDims.size() == layouts.size());
   SmallVector<PerDimLayoutAttr> newLayouts;
-  for (auto pair : llvm::zip(projectedDims, layouts)) {
+  for (auto pair : llvm::zip(droppedDims, layouts)) {
     if (!std::get<0>(pair))
       newLayouts.push_back(std::get<1>(pair));
   }
@@ -217,9 +220,86 @@
   return false;
 }
 
+// Project the nested layout. This take a mask on the dimensions of the vector
+// associated with this layout and projects out those dimensions. This reduces
+// the rank of the layout in the process.
 VectorLayoutInterface
-NestedLayoutAttr::project(ArrayRef<bool> projectedDims) const {
-  llvm_unreachable("Not yet implemented");
+NestedLayoutAttr::project(ArrayRef<bool> droppedDims) const {
+  assert(droppedDims.size() == getBatchesPerSubgroup().size() &&
+         "droppedDims size must match layout rank");
+
+  // Projection for this layout simply means the sizes along the projected
+  // are dropped.
+  SmallVector<int64_t> subgroupCount;
+  SmallVector<int64_t> batchCount;
+  SmallVector<int64_t> outerCount;
+  SmallVector<int64_t> threadCount;
+  SmallVector<int64_t> elementCount;
+  int64_t count = 0;
+  // Map to track pre-projection -> post-projection indices. Used to update
+  // the dimension orders.
+  llvm::DenseMap<int64_t, int64_t> indexToRankReducedIndexMap;
+  for (auto [idx, isProjected] : llvm::enumerate(droppedDims)) {
+    if (!isProjected) {
+      subgroupCount.push_back(getSubgroupsPerWorkgroup()[idx]);
+      batchCount.push_back(getBatchesPerSubgroup()[idx]);
+      outerCount.push_back(getOutersPerBatch()[idx]);
+      threadCount.push_back(getThreadsPerOuter()[idx]);
+      elementCount.push_back(getElementsPerThread()[idx]);
+      indexToRankReducedIndexMap[idx] = count++;
+    }
+  }
+  // This layout is invalid for rank-0 vectors.
+  assert(count >= 0 && "unimplemented rank-0 vector");
+
+  auto getRankReducedPermutation =
+      [&](ArrayRef<int64_t> perm) -> SmallVector<int64_t> {
+    SmallVector<int64_t> newPerm;
+    for (auto i : perm) {
+      if (indexToRankReducedIndexMap.contains(i)) {
+        newPerm.push_back(indexToRankReducedIndexMap[i]);
+      }
+    }
+    return newPerm;
+  };
+
+  SmallVector<int64_t> subgroupOrder =
+      getRankReducedPermutation(getSubgroupOrder());
+  SmallVector<int64_t> batchOrder = getRankReducedPermutation(getBatchOrder());
+  SmallVector<int64_t> outerOrder = getRankReducedPermutation(getOuterOrder());
+  SmallVector<int64_t> threadOrder =
+      getRankReducedPermutation(getThreadOrder());
+  SmallVector<int64_t> elementOrder =
+      getRankReducedPermutation(getElementOrder());
+
+  // Compose the projected dims with the basis mask to get the new active
+  // ids. Active ids indicates that we should use the ids marked as true, and
+  // projected dims drop the dims marked as true. So to get the new mask, we
+  // turn off all of the currently `true` ids marked as projected. For example:
+  //
+  // subgroup_active_ids = [true,  true,  false, true]
+  // projected_dims =      [false, true,         false]
+  //
+  // new_active_ids =      [true,  false, false, true]
+  auto composeMasks = [&](SmallVector<bool> &newMask, ArrayRef<bool> mask) {
+    int64_t rankReducedIdx = 0;
+    for (auto [i, active] : llvm::enumerate(newMask)) {
+      if (active) {
+        newMask[i] = !mask[rankReducedIdx];
+        rankReducedIdx++;
+      }
+    }
+  };
+  SmallVector<bool> subgroupMask(getSubgroupActiveIds());
+  SmallVector<bool> threadMask(getThreadActiveIds());
+  composeMasks(subgroupMask, droppedDims);
+  composeMasks(threadMask, droppedDims);
+
+  return NestedLayoutAttr::get(getContext(), subgroupCount, subgroupOrder,
+                               batchCount, batchOrder, outerCount, outerOrder,
+                               threadCount, threadOrder, elementCount,
+                               elementOrder, getSubgroupBasis(), subgroupMask,
+                               getThreadBasis(), threadMask);
 }
 
 VectorLayoutInterface
@@ -260,7 +340,8 @@
     ArrayRef<int64_t> outersPerBatch, ArrayRef<int64_t> outerOrder,
     ArrayRef<int64_t> threadsPerOuter, ArrayRef<int64_t> threadOrder,
     ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> elementOrder,
-    ArrayRef<int64_t> subgroupBasis, ArrayRef<int64_t> threadBasis) {
+    ArrayRef<int64_t> subgroupBasis, ArrayRef<bool> subgroupActiveIds,
+    ArrayRef<int64_t> threadBasis, ArrayRef<bool> threadActiveIds) {
 
   size_t rank = subgroupsPerWorkgroup.size();
   auto checkTile = [&](ArrayRef<int64_t> tileShape, ArrayRef<int64_t> order) {
@@ -279,9 +360,23 @@
       failed(checkTile(batchesPerSubgroup, batchOrder)) ||
       failed(checkTile(outersPerBatch, outerOrder)) ||
       failed(checkTile(threadsPerOuter, threadOrder)) ||
-      failed(checkTile(elementsPerThread, elementOrder)) ||
-      failed(checkTile(subgroupBasis, subgroupOrder)) ||
-      failed(checkTile(threadBasis, threadOrder))) {
+      failed(checkTile(elementsPerThread, elementOrder))) {
+    return failure();
+  }
+
+  auto checkBasis = [&](ArrayRef<int64_t> basis, ArrayRef<bool> mask) {
+    if (basis.size() != mask.size()) {
+      emitError() << "basis and active id mask must be the same length";
+      return failure();
+    }
+    if (llvm::count(mask, true) != rank) {
+      emitError()
+          << "number of active basis ids must be equal to the layout rank";
+    }
+    return success();
+  };
+  if (failed(checkBasis(subgroupBasis, subgroupActiveIds)) ||
+      failed(checkBasis(threadBasis, threadActiveIds))) {
     return failure();
   }
 
@@ -289,7 +384,11 @@
 }
 
 /// Given a single flat thread ID, compute the indices of the distributed
-/// dimensions (subgroup and thread ids).
+/// dimensions (subgroup and thread ids). The only difference between subgroup
+/// and thread dimensions is the order in which they are "divided out" of the
+/// underlying vector (i.e. vector_shape /= subgroups -> batches -> outers ->
+/// threads -> elements). There is no requirement that a subgroup id only
+/// spans subgroups.
 SmallVector<Value>
 NestedLayoutAttr::computeThreadIds(Value threadId,
                                    RewriterBase &rewriter) const {
@@ -324,11 +423,85 @@
   auto tileSizes = llvm::concat<const int64_t>(
       applyPermutation(getSubgroupsPerWorkgroup(), getSubgroupOrder()),
       applyPermutation(getThreadsPerOuter(), getThreadOrder()));
+  auto tileSizesIterator = tileSizes.begin();
 
-  // Modulo the delinearized subgroup/thread ids by the number of unique
-  // elements distributed to those ids.
-  for (auto [delinearized, basis, tile] :
-       llvm::zip(delinearized, basisSizes, tileSizes)) {
+  auto activeIdFilter =
+      llvm::concat<const bool>(getSubgroupActiveIds(), getThreadActiveIds());
+
+  // Modulo the active delinearized subgroup/thread ids by the number of unique
+  // elements distributed to those ids. The only difference between subgroup
+  // and thread dimensions is the order in which they are "divided out" of the
+  // underlying vector (i.e. vector_shape /= subgroups -> batches -> outers ->
+  // threads -> elements). There is no requirement that a subgroup id only
+  // spans subgroups.
+  //
+  // thread_basis = [8, 4, 2]
+  // active_thread_ids = [true, false, true]
+  // threads_per_outer = [4, 2]
+  //
+  // To obtain the thread ids, we just delinearize based on the basis.
+  //
+  // i0, i1, i2 = affine.delinearize_inds %threadId (8, 4, 2)
+  //
+  // And then to get the thread id for the layout, we only consider the active
+  // ids:
+  //
+  // layout_id0 = i0 % 4
+  // layout_id1 = i2 % 2
+  //
+  // The typical way this is used it to implicitly broadcast data across
+  // threads. For example, take a simpler case of the following:
+  //
+  // vector_shape = vector<2>
+  // thread_basis = [2, 2]
+  // active_thread_ids = [true, false]
+  // threads_per_outer = [2]
+  //
+  // If we give the two elements in the vector labels, say s0 and s1, we can
+  // see what this layout assigns as ids when doing a read of those two values
+  // across 4 threads.
+  //
+  // %id = gpu.flat_thread_id   // In range [0, 4)
+  // i0, i1 = affine.delinearize_index %id (2, 2)
+  // %id = 0, 1, 2, 3
+  // ----------------
+  // i0  = 0, 0, 1, 1
+  // i1  = 0, 1, 0, 1
+  //
+  // %0 = vector.load mem[i0]
+  //
+  // %id = 0, 1, 2, 3
+  // ----------------
+  // %0  = s0 s0 s1 s1
+  //
+  // If we instead had this layout:
+  //
+  // thread_basis = [4]
+  // active_thread_ids = [true]
+  // threads_per_outer = [2]
+  //
+  // With the modulus, we would get:
+  //
+  // %id = gpu.flat_thread_id   // In range [0, 4)
+  // i0 = %id = affine.delinearize_index %id (4)
+  // layout_i0 = i0 % 2
+  //
+  // %id        = 0, 1, 2, 3
+  // ----------------
+  // layout_i0  = 0, 1, 0, 1
+  //
+  // %0 = vector.load mem[layout_i0]
+  //
+  // %id = 0, 1, 2, 3
+  // ----------------
+  // %0  = s0 s1 s0 s1
+  for (auto [delinearized, basis, isActive] :
+       llvm::zip_equal(delinearized, basisSizes, activeIdFilter)) {
+    if (!isActive) {
+      continue;
+    }
+    int64_t tile = *tileSizesIterator;
+    tileSizesIterator++;
     if (basis == tile) {
       continue;
     }
@@ -399,6 +572,91 @@
   }
 }
 
+// Custom parser/printer for a basis (array of i64 values) and a mask (array
+// of boolean values).
+static ParseResult parseBasis(AsmParser &parser, StringRef basisName,
+                              StringRef maskName, bool parseComma,
+                              SmallVector<int64_t> &basis,
+                              SmallVector<bool> &mask) {
+  if (failed(parser.parseKeyword(basisName)) || failed(parser.parseEqual()) ||
+      failed(parser.parseLSquare())) {
+    return failure();
+  }
+  auto arrayParser = FieldParser<SmallVector<int64_t>>::parse(parser);
+  if (failed(arrayParser)) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "failed to parse basis parameter '")
+        << basisName << "' which is to be a `::llvm::ArrayRef<int64_t>`";
+  }
+  basis = *arrayParser;
+  if (parser.parseRSquare()) {
+    return failure();
+  }
+  // Optionally parse a comma between the basis and mask.
+  if (parser.parseOptionalComma()) {
+    // If we were supposed to find a comma, fail parsing.
+    if (parseComma) {
+      return failure();
+    }
+    // If it was fine not to find a comma, set the mask. If the comma was
+    // missing this will fail to parse the closing angle bracket.
+    mask = SmallVector<bool>(basis.size(), true);
+    return success();
+  }
+  // There is a comma, meaning we either must find the mask, or we shouldn't
+  // have expected a comma.
+  if (failed(parser.parseOptionalKeyword(maskName))) {
+    if (!parseComma) {
+      return failure();
+    }
+    mask = SmallVector<bool>(basis.size(), true);
+    return success();
+  }
+
+  if (failed(parser.parseEqual()) || failed(parser.parseLSquare())) {
+    return failure();
+  }
+  auto maskParser = FieldParser<SmallVector<bool>>::parse(parser);
+  if (failed(maskParser)) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "failed to parse mask parameter '")
+        << maskName << "' which is to be a `::llvm::ArrayRef<bool>`";
+  }
+  if (failed(parser.parseRSquare()) ||
+      (parseComma && failed(parser.parseComma()))) {
+    return failure();
+  }
+  mask = *maskParser;
+
+  return success();
+}
+
+static void printBasis(AsmPrinter &p, StringRef basisName, StringRef maskName,
+                       bool printComma, ArrayRef<int64_t> basis,
+                       ArrayRef<bool> mask) {
+  p << basisName;
+  // This is called without whitespace inserted by default for optionality.
+  // Insert it explicitly instead.
+  p << ' ';
+  p << '=';
+  p << ' ';
+  p << '[';
+  llvm::interleaveComma(basis, p);
+  p << ']';
+  if (llvm::any_of(mask, [](bool b) { return !b; })) {
+    p << ',' << ' ';
+    p << maskName;
+    p << '=';
+    p << ' ';
+    p << '[';
+    llvm::interleaveComma(mask, p);
+    p << ']';
+  }
+  if (printComma) {
+    p << ',' << ' ';
+  }
+}
+
 } // namespace mlir::iree_compiler::IREE::VectorExt
 
 using namespace mlir::iree_compiler::IREE::VectorExt;
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
index 552485c..d284b44 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir
@@ -43,3 +43,18 @@
   %simd = iree_vector_ext.to_simt %simt : vector<64xf32> -> vector<2x2xf16>
   func.return %simd : vector<2x2xf16>
 }
+
+// -----
+
+// expected-error @+1 {{number of active basis ids must be equal to the layout rank}}
+#layout = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1],
+  batches_per_subgroup = [1],
+  outers_per_batch = [1],
+  threads_per_outer = [1],
+  elements_per_thread = [1],
+
+  subgroup_basis = [2],
+  subgroup_active_ids = [false],
+  thread_basis   = [2]
+>
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 844000c..6e60022 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -54,19 +54,84 @@
   thread_basis   = [2, 4]
 >
 
+#nested_3 = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [4, 2],
+  outers_per_batch = [1, 4],
+  threads_per_outer = [2, 4],
+  elements_per_thread = [4, 1],
+
+  subgroup_order = [1, 0],
+  batch_order = [1, 0],
+  thread_order = [1, 0],
+  element_order = [1, 0],
+
+  subgroup_basis = [2, 4, 8],
+  subgroup_active_ids = [true, true, false],
+  thread_basis   = [2, 4]
+>
+
+#nested_4 = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [4, 2],
+  outers_per_batch = [1, 4],
+  threads_per_outer = [2, 4],
+  elements_per_thread = [4, 1],
+
+  subgroup_order = [1, 0],
+  batch_order = [1, 0],
+  thread_order = [1, 0],
+  element_order = [1, 0],
+
+  subgroup_basis = [2, 4, 8],
+  subgroup_active_ids = [true, true, false],
+  thread_basis   = [2, 4, 2],
+  thread_active_ids = [false, true, true]
+>
+
+#nested_5 = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [4, 2],
+  outers_per_batch = [1, 4],
+  threads_per_outer = [2, 4],
+  elements_per_thread = [4, 1],
+
+  subgroup_order = [1, 0],
+  batch_order = [1, 0],
+  thread_order = [1, 0],
+  element_order = [1, 0],
+
+  subgroup_basis = [2, 4],
+  subgroup_active_ids = [true, true],
+  thread_basis   = [4, 2],
+  thread_active_ids = [true, true]
+>
+
 func.func @specify_nested(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
   %cst_0 = arith.constant 0.0 : f16
   %c0 = arith.constant 0 : index
   %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
-  %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #nested_1, desiredLayout = #nested_2} : vector<32x32xf16> -> vector<32x32xf16>
+  %2 = iree_vector_ext.layout_conflict_resolution %result {
+    sourceLayout = #nested_1,
+    desiredLayout = #nested_2,
+    otherLayout0 = #nested_3,
+    otherLayout1 = #nested_4,
+    otherLayout2 = #nested_5
+  } : vector<32x32xf16> -> vector<32x32xf16>
   return %2 : vector<32x32xf16>
 }
 
 // CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [2, 4]>
 // CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], outer_order = [1, 0], subgroup_basis = [1, 1], thread_basis = [4, 2]>
+// CHECK-DAG: #[[LAYOUT2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4, 8], subgroup_active_ids= [true, true, false], thread_basis = [2, 4]>
+// CHECK-DAG: #[[LAYOUT3:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4, 8], subgroup_active_ids= [true, true, false], thread_basis = [2, 4, 2], thread_active_ids= [false, true, true]>
+// CHECK-DAG: #[[LAYOUT4:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], subgroup_basis = [2, 4], thread_basis = [4, 2]>
 // CHECK-LABEL: func.func @specify_nested
 // CHECK:      iree_vector_ext.layout_conflict_resolution
 // CHECK-SAME:         desiredLayout = #[[LAYOUT0]]
+// CHECK-SAME:         otherLayout0 = #[[LAYOUT2]]
+// CHECK-SAME:         otherLayout1 = #[[LAYOUT3]]
+// CHECK-SAME:         otherLayout2 = #[[LAYOUT4]]
 // CHECK-SAME:         sourceLayout = #[[LAYOUT1]]
 
 // -----