[VectorExt] Add LayoutV2 supporting warp distribution and rank > 2 (#16368)
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 8dfa939..82ea1fb 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -733,6 +733,7 @@
":IREEVectorExtIncGen",
":IREEVectorExtInterfacesIncGen",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index 285f862..05e331c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -99,5 +99,202 @@
}];
}
-#endif // IREE_DIALECT_VECTOREXT_ATTRS
+def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
+ [ DeclareAttrInterfaceMethods<VectorLayoutInterface> ]> {
+ let mnemonic = "nested_layout";
+ let summary = [{A layout representing a mapping from GPU thread hierarchy to a shape}];
+ let description = [{
+ This layout explicitly defines how a shape is mapped to a compute
+ hierarchy. We consider the following levels of hierarchy, inspired by GPUs:
+ 1. Subgroups per Workgroup
+ 2. Threads per Subgroup
+ 3. Elements per Thread
+
+ Conceptually, each higher level of the hierarchy can be viewed as multiple
+ tiles of the level below it; each lower level is nested inside the level
+ above it. The last level represents the final elements in memory.
+
+ The conceptual mapping is leveraged during compilation for tiling and
+ distributing to hardware for parallel computation. Concretely, the mapping
+ is done on each dimension of the original vector shape. For example, for
+ vector shape 16x16x16, we have 3 dimensions, so at each level of the
+ hierarchy we would have 3 tile sizes. Similarly for vector shape 32x32, we
+ would have 2-D tile sizes per compute hierarchy level.
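+
+ For instance, a 32x32 vector could be tiled with per-dimension tile counts
+ of [1, 1] subgroups, [2, 4] batches, [4, 1] outers, [4, 2] threads, and
+ [1, 4] elements, since the per-dimension products 1*2*4*4*1 = 32 and
+ 1*4*1*2*4 = 32 both match the vector shape.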
+
+ We now describe each level of tiling. Each level of tiling represents a
+ count of tiles over the next level (rather than a list of tile sizes) and
+ an ordering over the tiles:
+
+ 1. Subgroups per Workgroup
+
+ This level of tiling is also known as "subgroup/warp distribution". It
+ represents how subgroups are distributed in a workgroup.
+
+ The subgroups are placed contiguously with their shape and ordering
+ determined by:
+ - `subgroups_per_workgroup`: Sizes of this level of tiling
+ - `subgroup_order`: Ordering of dimensions, from outermost to innermost
+
+ For example, subgroups_per_workgroup=[4, 2], subgroup_order=[1, 0] will
+ arrange the subgroups in the order:
+
+ 0 4
+ 1 5
+ 2 6
+ 3 7
+
+ The total number of subgroups used (computed by multiplying each dim in
+ subgroups_per_workgroup) should be a multiple of the number of subgroups
+ on the hardware. If the total number of subgroups used exceeds the number
+ of subgroups on the hardware, then the subgroup used (say x) is
+ x mod num_subgroups:
+
+ num_subgroups = 4
+
+ 0 4 0 0
+ 1 5 x mod 4 1 1
+ 2 6 -------> 2 2
+ 3 7 3 3
+
+ 2. Threads per Subgroup
+
+ Threads in a subgroup are distributed in three levels.
+
+ The first level, batches, is a way to represent instruction unrolling. For
+ example, an intrinsic which can only take a 4x4 shape at a time uses
+ batches to unroll a 16x16 shape to the native intrinsic shape.
+
+ Batches can be thought of as loops around the original layout:
+
+ for b_0 in range(batch_0):
+ for b_1 in range(batch_1):
+ ...
+
+ Batches are represented using two attributes:
+ - batches_per_subgroup: Ranges of each loop
+ - batch_order: Ordering of each loop, from outermost to innermost
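+
+ For example, batches_per_subgroup = [2, 4] with batch_order = [1, 0] would
+ iterate dimension 1 outermost:
+
+ for b_1 in range(4):
+   for b_0 in range(2):
+     ...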
+
+ The second level, outers, is a way to represent thread layout duplication
+ required by a particular intrinsic. For example, some AMDGPU matrix
+ multiplication variants require threads to be distributed
+ like:
+
+ 0 1 2 3 4
+ 5 6 7 8 9
+ --------- --> Thread Layout of shape 2x5 duplicated 2 times, to get a layout of shape 4x5
+ 0 1 2 3 4 outers_per_batch=[2, 1]
+ 5 6 7 8 9 threads_per_outer=[2, 5]
+
+ Outers are represented using two attributes:
+ - outers_per_batch: Number of outers in a batch
+ - outer_order: Ordering of outers, from outermost to innermost
+
+ Finally, threads are distributed in a single outer. The thread
+ distribution is represented by:
+
+ - threads_per_outer: Sizes of this level of tiling
+ - thread_order: Ordering of dimensions, from outermost to innermost
+
+ Example of thread distribution over an 8x4 shape:
+
+ {
+ batches_per_subgroup = [2, 1]
+ outers_per_batch = [2, 2]
+ threads_per_outer = [2, 2]
+
+ batch_order = [0, 1]
+ outer_order = [0, 1]
+ thread_order = [1, 0]
+ }
+
+ Distributed tile:
+
+ {
+ [0 2]|[0 2] 0,1,2,3 --> thread ids
+ [1 3]|[1 3]
+ ------------ [x z] --> a single outer tile
+ [0 2]|[0 2] [y w]
+ [1 3]|[1 3]
+ }{
+ [0 2]|[0 2] { ... } --> a single batch tile
+ [1 3]|[1 3]
+ ------------
+ [0 2]|[0 2]
+ [1 3]|[1 3]
+ }
+
+ So, the thread distribution looks like:
+
+ [0 2 0 2]
+ [1 3 1 3]
+ [0 2 0 2]
+ [1 3 1 3]
+ [0 2 0 2]
+ [1 3 1 3]
+ [0 2 0 2]
+ [1 3 1 3]
+
+ The total number of threads used (computed by multiplying each dim in
+ threads_per_outer) should be a multiple of the subgroup size of the
+ hardware. If the total number of threads used exceeds the subgroup size
+ of the hardware, then the thread used (say tid) is tid mod subgroup_size:
+
+ subgroup_size = 4
+
+ 0 1 0 0
+ 2 3 tid mod 4 1 1
+ 4 5 --------> 2 2
+ 6 7 3 3
+
+ 3. Elements per Thread
+
+ The final level of tiling, representing the smallest vector shape that is
+ treated as an atom.
+
+ The elements are placed contiguously with their shape and ordering
+ determined by:
+ - `elements_per_thread`: Sizes of this level of tiling
+ - `element_order`: Ordering of dimensions, from outermost to innermost
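+
+ For example, elements_per_thread = [1, 4] with element_order = [0, 1] would
+ give each thread a contiguous 1x4 slice of elements, with the second
+ dimension varying fastest.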
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t", "subgroups_per_workgroup">:$subgroupsPerWorkgroup,
+ ArrayRefParameter<"int64_t", "subgroup_order">:$subgroupOrder,
+
+ ArrayRefParameter<"int64_t", "batches_per_subgroup">:$batchesPerSubgroup,
+ ArrayRefParameter<"int64_t", "batch_order">:$batchOrder,
+
+ ArrayRefParameter<"int64_t", "outers_per_batch">:$outersPerBatch,
+ ArrayRefParameter<"int64_t", "outer_order">:$outerOrder,
+
+ ArrayRefParameter<"int64_t", "threads_per_outer">:$threadsPerOuter,
+ ArrayRefParameter<"int64_t", "thread_order">:$threadOrder,
+
+ ArrayRefParameter<"int64_t", "elements_per_thread">:$elementsPerThread,
+ ArrayRefParameter<"int64_t", "element_order">:$elementOrder
+ );
+
+ // TODO: add custom parser/printer and builder to elide default value array
+ // refs.
+ let assemblyFormat = [{
+ `<` `subgroups_per_workgroup` `=` `[` $subgroupsPerWorkgroup `]` `,`
+ `batches_per_subgroup` `=` `[` $batchesPerSubgroup `]` `,`
+ `outers_per_batch` `=` `[` $outersPerBatch `]` `,`
+ `threads_per_outer` `=` `[` $threadsPerOuter `]` `,`
+ `elements_per_thread` `=` `[` $elementsPerThread `]` `,`
+
+ `subgroup_order` `=` `[` $subgroupOrder `]` `,`
+ `batch_order` `=` `[` $batchOrder `]` `,`
+ `outer_order` `=` `[` $outerOrder `]` `,`
+ `thread_order` `=` `[` $threadOrder `]` `,`
+ `element_order` `=` `[` $elementOrder `]`
+ `>`
+ }];
+
+ let genVerifyDecl = 1;
+}
+
+
+#endif // IREE_DIALECT_VECTOREXT_ATTRS
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index ea5969c..3f9b0c8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -6,6 +6,7 @@
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
@@ -195,6 +196,68 @@
return offset;
}
+VectorLayoutInterface
+NestedLayoutAttr::project(ArrayRef<bool> projectedDims) const {
+ llvm_unreachable("Not yet implemented");
+}
+
+VectorLayoutInterface
+NestedLayoutAttr::permute(ArrayRef<int64_t> permutation) const {
+ llvm_unreachable("Not yet implemented");
+}
+
+SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
+ llvm_unreachable("Not yet implemented");
+}
+
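+// A shape is compatible with this layout when, along every dimension, the
+// product of the tile counts of all levels equals that dimension of the
+// shape. For example, a layout with subgroups_per_workgroup = [1, 1],
+// batches_per_subgroup = [2, 4], outers_per_batch = [4, 1],
+// threads_per_outer = [4, 2], and elements_per_thread = [1, 4] is compatible
+// with a 32x32 shape, since 1*2*4*4*1 == 32 and 1*4*1*2*4 == 32.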
+bool NestedLayoutAttr::isValidLayout(ArrayRef<int64_t> shape) const {
+ // For each dimension, multiply the tile counts of all levels and compare
+ // against the shape.
+ for (int i = 0, e = shape.size(); i < e; ++i) {
+ int64_t expectedShape = getSubgroupsPerWorkgroup()[i] *
+ getBatchesPerSubgroup()[i] *
+ getOutersPerBatch()[i] * getThreadsPerOuter()[i] *
+ getElementsPerThread()[i];
+ if (expectedShape != shape[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+// TODO: These things should ideally go into the parser when we have a custom
+// parser.
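+// Every tile-size array and order array must have the same rank as the
+// layout, and every order must be a permutation of [0, rank). For example,
+// thread_order = [1, 1] on a rank-2 layout would be rejected.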
+LogicalResult NestedLayoutAttr::verify(
+ llvm::function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> subgroupsPerWorkgroup, ArrayRef<int64_t> subgroupOrder,
+ ArrayRef<int64_t> batchesPerSubgroup, ArrayRef<int64_t> batchOrder,
+ ArrayRef<int64_t> outersPerBatch, ArrayRef<int64_t> outerOrder,
+ ArrayRef<int64_t> threadsPerOuter, ArrayRef<int64_t> threadOrder,
+ ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> elementOrder) {
+
+ size_t rank = subgroupsPerWorkgroup.size();
+ auto checkTile = [&](ArrayRef<int64_t> tileShape, ArrayRef<int64_t> order) {
+ if (tileShape.size() != rank || order.size() != rank) {
+ emitError() << "all tiles must have the same rank as the layout";
+ return failure();
+ }
+ if (!mlir::isPermutationVector(order)) {
+ emitError() << "all orderings must be permutation vectors";
+ return failure();
+ }
+ return success();
+ };
+
+ if (failed(checkTile(subgroupsPerWorkgroup, subgroupOrder)) ||
+ failed(checkTile(batchesPerSubgroup, batchOrder)) ||
+ failed(checkTile(outersPerBatch, outerOrder)) ||
+ failed(checkTile(threadsPerOuter, threadOrder)) ||
+ failed(checkTile(elementsPerThread, elementOrder))) {
+ return failure();
+ }
+
+ return success();
+}
+
} // namespace mlir::iree_compiler::IREE::VectorExt
using namespace mlir::iree_compiler::IREE::VectorExt;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
index 904d35a..213f57b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
@@ -27,6 +27,10 @@
os << "layout";
return AliasResult::OverridableAlias;
}
+ if (llvm::isa<NestedLayoutAttr>(attr)) {
+ os << "nested";
+ return AliasResult::OverridableAlias;
+ }
return AliasResult::NoAlias;
}
};
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 495286f..e4ee93a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -21,6 +21,51 @@
// -----
+#nested_1 = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [2, 4],
+ outers_per_batch = [4, 1],
+ threads_per_outer = [4, 2],
+ elements_per_thread = [1, 4],
+
+ subgroup_order = [0, 1],
+ batch_order = [0, 1],
+ outer_order = [1, 0],
+ thread_order = [0, 1],
+ element_order = [0, 1]
+>
+
+#nested_2 = #iree_vector_ext.nested_layout<
+ subgroups_per_workgroup = [1, 1],
+ batches_per_subgroup = [4, 2],
+ outers_per_batch = [1, 4],
+ threads_per_outer = [2, 4],
+ elements_per_thread = [4, 1],
+
+ subgroup_order = [1, 0],
+ batch_order = [1, 0],
+ outer_order = [0, 1],
+ thread_order = [1, 0],
+ element_order = [1, 0]
+>
+
+func.func @specify_nested(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
+ %cst_0 = arith.constant 0.0 : f16
+ %c0 = arith.constant 0 : index
+ %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
+ %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #nested_1, desiredLayout = #nested_2} : vector<32x32xf16> -> vector<32x32xf16>
+ return %2 : vector<32x32xf16>
+}
+
+// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [0, 1], thread_order = [1, 0], element_order = [1, 0]>
+// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [1, 0], thread_order = [0, 1], element_order = [0, 1]>
+// CHECK-LABEL: func.func @specify_nested
+// CHECK: iree_vector_ext.layout_conflict_resolution
+// CHECK-SAME: desiredLayout = #[[LAYOUT0]]
+// CHECK-SAME: sourceLayout = #[[LAYOUT1]]
+
+// -----
+
func.func @to_simd_op(%simt: vector<4x4x4xf16>) -> vector<64x64xf16> {
%simd = iree_vector_ext.to_simd %simt : vector<4x4x4xf16> -> vector<64x64xf16>
func.return %simd : vector<64x64xf16>