[VectorExt] Add LayoutV2 supporting warp distribution and rank > 2 (#16368)

diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 8dfa939..82ea1fb 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -733,6 +733,7 @@
         ":IREEVectorExtIncGen",
         ":IREEVectorExtInterfacesIncGen",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
     ],
 )
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index 285f862..05e331c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -99,5 +99,202 @@
   }];
 }
 
-#endif // IREE_DIALECT_VECTOREXT_ATTRS
+def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
+      [ DeclareAttrInterfaceMethods<VectorLayoutInterface> ]> {
+  let mnemonic = "nested_layout";
+  let summary = [{A layout representing a mapping from GPU thread hierarchy to a shape}];
+  let description = [{
+    This layout explicitly defines how a shape is mapped to a compute
+    hierarchy. We consider the following levels of hierarchy, inspired by GPUs:
 
+    1. Subgroups per Workgroup
+    2. Threads per Subgroup
+    3. Elements per Thread
+
+    Conceptually, each higher level of hierarchy can be viewed as multiple
+    tiles of the lower level of hierarchy; each lower level of hierarchy is
+    nested in the higher level of hierarchy. The last level represents the
+    final elements in memory.
+
+    The conceptual mapping is leveraged during compilation for tiling and
+    distributing to hardware for parallel computation. Concretely, the mapping
+    is done on each dimension of the original vector shape. For example, for
+    vector shape 16x16x16, we have 3 dimensions, so at each level of the
+    hierarchy we would have 3 tile sizes. Similarly for vector shape 32x32, we
+    would have 2-D tile sizes per compute hierarchy level.
+
+    We now describe each level of tiling. Each level of tiling represents a
+    count of tiles over the next level (rather than a list of tile sizes) and
+    an ordering over the tiles:
+
+    1. Subgroups per Workgroup
+
+    This level of tiling is also known as "subgroup/warp distribution". It
+    represents how subgroups are distributed in a workgroup. 
+
+    The subgroups are placed contiguously with their shape and ordering
+    determined by:
+      - `subgroups_per_workgroup`: Sizes of this level of tiling
+      - `subgroup_order`: Ordering of dimensions, from outermost to innermost
+
+    For example, subgroups_per_workgroup=[4, 2], subgroup_order=[1, 0] will
+    arrange the subgroups in the order:
+
+    0 4
+    1 5
+    2 6
+    3 7
+
+    The total number of subgroups used (computed by multiplying each dim in
+    subgroups_per_workgroup) should be a multiple of the number of subgroups
+    in the hardware. If the total number of subgroups used exceeds the number of
+    subgroups of the hardware, then the subgroup used (say x) is 
+    x mod num_subgroups:
+
+    num_subgroups = 4
+
+    0 4               0 0
+    1 5    x mod 4    1 1
+    2 6    ------->   2 2
+    3 7               3 3
+
+    2. Threads per Subgroup:
+
+    Threads in a subgroup are distributed in three levels.
+
+    The first level, batches, are a way to represent instruction unrolling. For
+    example, an intrinsic which can only take a 4x4 shape at a time uses batches
+    to unroll a 16x16 shape to the native intrinsic shape.
+
+    Batches can be thought of as loops around the original layout:
+
+    for b_0 in range(batch_0):
+      for b_1 in range(batch_1):
+        ...
+
+    Batches are represented using two attributes:
+      - batches_per_subgroup: Ranges of each loop
+      - batch_order: Ordering of each loop, from outermost to innermost
+
+    The second level, outers, is a way to represent thread layout duplication
+    required by a particular intrinsic. For example, some AMDGPU matrix
+    multiplication variants require threads to be distributed
+    like:
+
+    0 1 2 3 4
+    5 6 7 8 9
+    --------- --> Thread Layout of shape 2x5 duplicated 2 times, to get a layout of shape 4x5 
+    0 1 2 3 4     outers_per_batch=[2, 1]
+    5 6 7 8 9     threads_per_outer=[2, 5]
+
+    Outers is represented using two attributes:
+      - outers_per_batch: Number of outers in a batch
+      - outer_order: Ordering of outers, from outermost to innermost
+
+    Finally, threads are distributed in a single outer. The thread
+    distribution is represented by:
+
+      - threads_per_outer: Sizes of this level of tiling
+      - thread_order: Ordering of dimensions, from outermost to innermost
+
+    Examples of thread distribution over a 8x4 shape:
+
+    {
+      batches_per_subgroup = [2, 1]
+      outers_per_batch = [2, 2]
+      threads_per_outer = [2, 2]
+
+      batch_order = [0, 1]
+      outer_order = [0, 1]
+      thread_order = [1, 0]
+    }
+
+    Distributed tile:
+
+    {
+      [0 2]|[0 2]      0,1,2,3 --> thread ids
+      [1 3]|[1 3]      
+      ------------     [x z]   --> a single outer tile
+      [0 2]|[0 2]      [y w]
+      [1 3]|[1 3]
+    }{
+      [0 2]|[0 2]      { ... } --> a single batch tile
+      [1 3]|[1 3]
+      ------------
+      [0 2]|[0 2]
+      [1 3]|[1 3]
+    }
+
+    So, the thread distribution looks like:
+
+    [0 2 0 2]
+    [1 3 1 3]
+    [0 2 0 2]
+    [1 3 1 3]
+    [0 2 0 2]
+    [1 3 1 3]
+    [0 2 0 2]
+    [1 3 1 3]
+
+    The total number of threads used (computed by multiplying each dim in
+    threads_per_outer) should be a multiple of the subgroup size of the
+    hardware. If the total number of threads used exceeds the subgroup size of
+    the hardware, then the threads used (say tid) is tid mod subgroup_size:
+
+    subgroup_size = 4
+
+    0 1                0 0
+    2 3    tid mod 4   1 1
+    4 5    -------->   2 2
+    6 7                3 3
+
+    3. Elements per Thread
+
+    The final level of tiling, representing the minimum shape of vector that
+    is treated as an atom.
+
+    The elements are placed contiguously with their shape and ordering
+    determined by:
+      - `elements_per_thread`: Sizes of this level of tiling
+      - `element_order`: Ordering of dimensions, from outermost to innermost
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t", "subgroups_per_workgroup">:$subgroupsPerWorkgroup,
+    ArrayRefParameter<"int64_t", "subgroup_order">:$subgroupOrder,
+
+    ArrayRefParameter<"int64_t", "batches_per_subgroup">:$batchesPerSubgroup,
+    ArrayRefParameter<"int64_t", "batch_order">:$batchOrder,
+
+    ArrayRefParameter<"int64_t", "outers_per_batch">:$outersPerBatch,
+    ArrayRefParameter<"int64_t", "outer_order">:$outerOrder,
+
+    ArrayRefParameter<"int64_t", "threads_per_outer">:$threadsPerOuter,
+    ArrayRefParameter<"int64_t", "thread_order">:$threadOrder,
+
+    ArrayRefParameter<"int64_t", "elements_per_thread">:$elementsPerThread,
+    ArrayRefParameter<"int64_t", "element_order">:$elementOrder
+  );
+
+  // TODO: add custom parser/printer and builder to elide default value array
+  // refs.
+  let assemblyFormat = [{
+    `<` `subgroups_per_workgroup` `=` `[` $subgroupsPerWorkgroup `]` `,`
+        `batches_per_subgroup`    `=` `[` $batchesPerSubgroup `]` `,`
+        `outers_per_batch`        `=` `[` $outersPerBatch `]` `,`
+        `threads_per_outer`       `=` `[` $threadsPerOuter `]` `,`
+        `elements_per_thread`     `=` `[` $elementsPerThread `]` `,`
+
+        `subgroup_order`          `=` `[` $subgroupOrder `]` `,`
+        `batch_order`             `=` `[` $batchOrder `]` `,`
+        `outer_order`             `=` `[` $outerOrder `]` `,`
+        `thread_order`            `=` `[` $threadOrder `]` `,`
+        `element_order`           `=` `[` $elementOrder `]` 
+    `>`
+  }];
+
+  let genVerifyDecl = 1;
+}
+
+
+#endif // IREE_DIALECT_VECTOREXT_ATTRS
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index ea5969c..3f9b0c8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -6,6 +6,7 @@
 
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <numeric>
@@ -195,6 +196,68 @@
   return offset;
 }
 
+VectorLayoutInterface
+NestedLayoutAttr::project(ArrayRef<bool> projectedDims) const {
+  llvm_unreachable("Not yet implemented");
+}
+
+VectorLayoutInterface
+NestedLayoutAttr::permute(ArrayRef<int64_t> permutation) const {
+  llvm_unreachable("Not yet implemented");
+}
+
+SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
+  llvm_unreachable("Not yet implemented");
+}
+
+bool NestedLayoutAttr::isValidLayout(ArrayRef<int64_t> shape) const {
+  // Multiply all shapes in the layout.
+  for (int i = 0, e = shape.size(); i < e; ++i) {
+    int64_t expectedShape = getSubgroupsPerWorkgroup()[i] *
+                            getBatchesPerSubgroup()[i] *
+                            getOutersPerBatch()[i] * getThreadsPerOuter()[i] *
+                            getElementsPerThread()[i];
+    if (expectedShape != shape[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// TODO: These things should ideally go into the parser when we have a custom
+// parser.
+LogicalResult NestedLayoutAttr::verify(
+    llvm::function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<int64_t> subgroupsPerWorkgroup, ArrayRef<int64_t> subgroupOrder,
+    ArrayRef<int64_t> batchesPerSubgroup, ArrayRef<int64_t> batchOrder,
+    ArrayRef<int64_t> outersPerBatch, ArrayRef<int64_t> outerOrder,
+    ArrayRef<int64_t> threadsPerOuter, ArrayRef<int64_t> threadOrder,
+    ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> elementOrder) {
+
+  size_t rank = subgroupsPerWorkgroup.size();
+  auto checkTile = [&](ArrayRef<int64_t> tileShape, ArrayRef<int64_t> order) {
+    if (tileShape.size() != rank || order.size() != rank) {
+      emitError() << "all tiles must have the same rank as the layout";
+      return failure();
+    }
+    if (!mlir::isPermutationVector(order)) {
+      emitError() << "all orderings must be permutation vectors";
+      return failure();
+    }
+    return success();
+  };
+
+  if (failed(checkTile(subgroupsPerWorkgroup, subgroupOrder)) ||
+      failed(checkTile(batchesPerSubgroup, batchOrder)) ||
+      failed(checkTile(outersPerBatch, outerOrder)) ||
+      failed(checkTile(threadsPerOuter, threadOrder)) ||
+      failed(checkTile(elementsPerThread, elementOrder))) {
+    return failure();
+  }
+
+  return success();
+}
+
 } // namespace mlir::iree_compiler::IREE::VectorExt
 
 using namespace mlir::iree_compiler::IREE::VectorExt;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
index 904d35a..213f57b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
@@ -27,6 +27,10 @@
       os << "layout";
       return AliasResult::OverridableAlias;
     }
+    if (llvm::isa<NestedLayoutAttr>(attr)) {
+      os << "nested";
+      return AliasResult::OverridableAlias;
+    }
     return AliasResult::NoAlias;
   }
 };
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 495286f..e4ee93a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -21,6 +21,51 @@
 
 // -----
 
+#nested_1 = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [2, 4],
+  outers_per_batch = [4, 1],
+  threads_per_outer = [4, 2],
+  elements_per_thread = [1, 4],
+
+  subgroup_order = [0, 1],
+  batch_order = [0, 1],
+  outer_order = [1, 0],
+  thread_order = [0, 1],
+  element_order = [0, 1]
+>
+
+#nested_2 = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [4, 2],
+  outers_per_batch = [1, 4],
+  threads_per_outer = [2, 4],
+  elements_per_thread = [4, 1],
+
+  subgroup_order = [1, 0],
+  batch_order = [1, 0],
+  outer_order = [0, 1],
+  thread_order = [1, 0],
+  element_order = [1, 0]
+>
+
+func.func @specify_nested(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
+  %cst_0 = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
+  %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #nested_1, desiredLayout = #nested_2} : vector<32x32xf16> -> vector<32x32xf16>
+  return %2 : vector<32x32xf16>
+}
+
+// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [0, 1], thread_order = [1, 0], element_order = [1, 0]>
+// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [1, 0], thread_order = [0, 1], element_order = [0, 1]>
+// CHECK-LABEL: func.func @specify_nested
+// CHECK:      iree_vector_ext.layout_conflict_resolution
+// CHECK-SAME:         desiredLayout = #[[LAYOUT0]]
+// CHECK-SAME:         sourceLayout = #[[LAYOUT1]]
+
+// -----
+
 func.func @to_simd_op(%simt: vector<4x4x4xf16>) -> vector<64x64xf16> {
   %simd = iree_vector_ext.to_simd %simt : vector<4x4x4xf16> -> vector<64x64xf16>
   func.return %simd : vector<64x64xf16>