[VectorExt] Add LayoutV2 supporting warp distribution and rank > 2 (#16368)
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 8dfa939..82ea1fb 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -733,6 +733,7 @@
         ":IREEVectorExtIncGen",
         ":IREEVectorExtInterfacesIncGen",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
     ],
 )
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
index 285f862..05e331c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -99,5 +99,202 @@
 }];
 }
 
-#endif // IREE_DIALECT_VECTOREXT_ATTRS
+def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
+    [ DeclareAttrInterfaceMethods<VectorLayoutInterface> ]> {
+  let mnemonic = "nested_layout";
+  let summary = [{A layout representing a mapping from the GPU thread hierarchy to a shape}];
+  let description = [{
+    This layout explicitly defines how a shape is mapped to a compute
+    hierarchy. We consider the following levels of hierarchy, inspired by
+    GPUs:
+
+    1. Subgroups per Workgroup
+    2. Threads per Subgroup
+    3. Elements per Thread
+
+    Conceptually, each higher level of the hierarchy can be viewed as
+    multiple tiles of the next lower level; each lower level is nested inside
+    the higher one. The last level represents the final elements in memory.
+
+    The conceptual mapping is leveraged during compilation for tiling and
+    distributing to hardware for parallel computation. Concretely, the
+    mapping is done on each dimension of the original vector shape. For
+    example, for vector shape 16x16x16, we have 3 dimensions, so at each
+    level of the hierarchy we would have 3 tile sizes. Similarly, for vector
+    shape 32x32, we would have 2-D tile sizes at each compute hierarchy
+    level.
+
+    We now describe each level of tiling. Each level specifies a count of
+    tiles over the next level (rather than a list of tile sizes) and an
+    ordering over those tiles:
+
+    1. Subgroups per Workgroup
+
+    This level of tiling is also known as "subgroup/warp distribution". It
+    represents how subgroups are distributed in a workgroup.
+
+    The subgroups are placed contiguously with their shape and ordering
+    determined by:
+      - `subgroups_per_workgroup`: Sizes of this level of tiling
+      - `subgroup_order`: Ordering of dimensions, from outermost to innermost
+
+    For example, subgroups_per_workgroup=[4, 2], subgroup_order=[1, 0] will
+    arrange the subgroups in the order:
+
+    0 4
+    1 5
+    2 6
+    3 7
+
+    The total number of subgroups used (computed by multiplying each dim in
+    subgroups_per_workgroup) should be a multiple of the number of subgroups
+    on the hardware. If the total number of subgroups used exceeds the number
+    of subgroups of the hardware, then the subgroup used (say x) is
+    x mod num_subgroups:
+
+    num_subgroups = 4
+
+    0 4              0 0
+    1 5  x mod 4     1 1
+    2 6  ------->    2 2
+    3 7              3 3
+
+    2. Threads per Subgroup
+
+    Threads in a subgroup are distributed across three levels.
+
+    The first level, batches, is a way to represent instruction unrolling.
+    For example, an intrinsic that can only take a 4x4 shape at a time uses
+    batches to unroll a 16x16 shape to the native intrinsic shape.
+
+    Batches can be thought of as loops around the original layout:
+
+    for b_0 in range(batch_0):
+      for b_1 in range(batch_1):
+        ...
+
+    Batches are represented using two attributes:
+      - batches_per_subgroup: Ranges of each loop
+      - batch_order: Ordering of each loop, from outermost to innermost
+
+    The second level, outers, is a way to represent the thread layout
+    duplication required by a particular intrinsic. For example, some AMDGPU
+    matrix multiplication variants require threads to be distributed like:
+
+    0 1 2 3 4
+    5 6 7 8 9
+    ---------  --> Thread layout of shape 2x5, duplicated 2 times to get a
+    0 1 2 3 4      layout of shape 4x5:
+    5 6 7 8 9        outers_per_batch=[2, 1]
+                     threads_per_outer=[2, 5]
+
+    Outers are represented using two attributes:
+      - outers_per_batch: Number of outers in a batch
+      - outer_order: Ordering of outers, from outermost to innermost
+
+    Finally, threads are distributed in a single outer. The thread
+    distribution is represented by:
+      - threads_per_outer: Sizes of this level of tiling
+      - thread_order: Ordering of dimensions, from outermost to innermost
+
+    An example of thread distribution over an 8x4 shape:
+
+    {
+      batches_per_subgroup = [2, 1]
+      outers_per_batch = [2, 2]
+      threads_per_outer = [2, 2]
+
+      batch_order = [0, 1]
+      outer_order = [0, 1]
+      thread_order = [1, 0]
+    }
+
+    Distributed tile:
+
+    {
+      [0 2]|[0 2]     0,1,2,3 --> thread ids
+      [1 3]|[1 3]
+      -----------     [x z] --> a single outer tile
+      [0 2]|[0 2]     [y w]
+      [1 3]|[1 3]
+    }{
+      [0 2]|[0 2]     { ... } --> a single batch tile
+      [1 3]|[1 3]
+      -----------
+      [0 2]|[0 2]
+      [1 3]|[1 3]
+    }
+
+    So, the thread distribution looks like:
+
+    [0 2 0 2]
+    [1 3 1 3]
+    [0 2 0 2]
+    [1 3 1 3]
+    [0 2 0 2]
+    [1 3 1 3]
+    [0 2 0 2]
+    [1 3 1 3]
+
+    The total number of threads used (computed by multiplying each dim in
+    threads_per_outer) should be a multiple of the subgroup size of the
+    hardware. If the total number of threads used exceeds the subgroup size
+    of the hardware, then the thread used (say tid) is tid mod subgroup_size:
+
+    subgroup_size = 4
+
+    0 1               0 1
+    2 3  tid mod 4    2 3
+    4 5  -------->    0 1
+    6 7               2 3
+
+    3. Elements per Thread
+
+    The final level of tiling represents the minimum shape of vector that is
+    treated as an atom.
+
+    The elements are placed contiguously with their shape and ordering
+    determined by:
+      - `elements_per_thread`: Sizes of this level of tiling
+      - `element_order`: Ordering of dimensions, from outermost to innermost
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t", "subgroups_per_workgroup">:$subgroupsPerWorkgroup,
+    ArrayRefParameter<"int64_t", "subgroup_order">:$subgroupOrder,
+
+    ArrayRefParameter<"int64_t", "batches_per_subgroup">:$batchesPerSubgroup,
+    ArrayRefParameter<"int64_t", "batch_order">:$batchOrder,
+
+    ArrayRefParameter<"int64_t", "outers_per_batch">:$outersPerBatch,
+    ArrayRefParameter<"int64_t", "outer_order">:$outerOrder,
+
+    ArrayRefParameter<"int64_t", "threads_per_outer">:$threadsPerOuter,
+    ArrayRefParameter<"int64_t", "thread_order">:$threadOrder,
+
+    ArrayRefParameter<"int64_t", "elements_per_thread">:$elementsPerThread,
+    ArrayRefParameter<"int64_t", "element_order">:$elementOrder
+  );
+
+  // TODO: add a custom parser/printer and builder to elide default-valued
+  // array refs.
+  let assemblyFormat = [{
+    `<` `subgroups_per_workgroup` `=` `[` $subgroupsPerWorkgroup `]` `,`
+    `batches_per_subgroup` `=` `[` $batchesPerSubgroup `]` `,`
+    `outers_per_batch` `=` `[` $outersPerBatch `]` `,`
+    `threads_per_outer` `=` `[` $threadsPerOuter `]` `,`
+    `elements_per_thread` `=` `[` $elementsPerThread `]` `,`
+
+    `subgroup_order` `=` `[` $subgroupOrder `]` `,`
+    `batch_order` `=` `[` $batchOrder `]` `,`
+    `outer_order` `=` `[` $outerOrder `]` `,`
+    `thread_order` `=` `[` $threadOrder `]` `,`
+    `element_order` `=` `[` $elementOrder `]`
+    `>`
+  }];
+
+  let genVerifyDecl = 1;
+}
+
+#endif // IREE_DIALECT_VECTOREXT_ATTRS
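To make the tiling relation above concrete: each dimension of the undistributed vector shape is the product of that dimension's tile counts across all five levels. A minimal standalone C++ sketch (hypothetical, not part of this patch; `undistributedDimSize` is an illustrative name, and the counts are copied from the `#nested_1` layout in the roundtrip test below):

    #include <cassert>
    #include <cstdint>

    // Undistributed size of one dimension: the product of the tile counts at
    // every level of the NestedLayout hierarchy.
    int64_t undistributedDimSize(int64_t subgroups, int64_t batches,
                                 int64_t outers, int64_t threads,
                                 int64_t elements) {
      return subgroups * batches * outers * threads * elements;
    }

    int main() {
      // #nested_1: subgroups_per_workgroup=[1, 1], batches_per_subgroup=[2, 4],
      // outers_per_batch=[4, 1], threads_per_outer=[4, 2],
      // elements_per_thread=[1, 4], applied to vector<32x32xf16>.
      assert(undistributedDimSize(1, 2, 4, 4, 1) == 32); // dim 0
      assert(undistributedDimSize(1, 4, 1, 2, 4) == 32); // dim 1
      return 0;
    }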
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index ea5969c..3f9b0c8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -6,6 +6,7 @@
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <numeric>
@@ -195,6 +196,68 @@
   return offset;
 }
 
+VectorLayoutInterface
+NestedLayoutAttr::project(ArrayRef<bool> projectedDims) const {
+  llvm_unreachable("Not yet implemented");
+}
+
+VectorLayoutInterface
+NestedLayoutAttr::permute(ArrayRef<int64_t> permutation) const {
+  llvm_unreachable("Not yet implemented");
+}
+
+SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
+  llvm_unreachable("Not yet implemented");
+}
+
+bool NestedLayoutAttr::isValidLayout(ArrayRef<int64_t> shape) const {
+  // Each dimension of the shape must equal the product of the tile counts
+  // at every level of the hierarchy for that dimension.
+  for (int i = 0, e = shape.size(); i < e; ++i) {
+    int64_t expectedShape = getSubgroupsPerWorkgroup()[i] *
+                            getBatchesPerSubgroup()[i] *
+                            getOutersPerBatch()[i] * getThreadsPerOuter()[i] *
+                            getElementsPerThread()[i];
+    if (expectedShape != shape[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// TODO: These checks should ideally move into the parser once we have a
+// custom parser.
+LogicalResult NestedLayoutAttr::verify(
+    llvm::function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<int64_t> subgroupsPerWorkgroup, ArrayRef<int64_t> subgroupOrder,
+    ArrayRef<int64_t> batchesPerSubgroup, ArrayRef<int64_t> batchOrder,
+    ArrayRef<int64_t> outersPerBatch, ArrayRef<int64_t> outerOrder,
+    ArrayRef<int64_t> threadsPerOuter, ArrayRef<int64_t> threadOrder,
+    ArrayRef<int64_t> elementsPerThread, ArrayRef<int64_t> elementOrder) {
+
+  size_t rank = subgroupsPerWorkgroup.size();
+  auto checkTile = [&](ArrayRef<int64_t> tileShape, ArrayRef<int64_t> order) {
+    if (tileShape.size() != rank || order.size() != rank) {
+      emitError() << "all tiles must have the same rank as the layout";
+      return failure();
+    }
+    if (!mlir::isPermutationVector(order)) {
+      emitError() << "all orderings must be permutation vectors";
+      return failure();
+    }
+    return success();
+  };
+
+  if (failed(checkTile(subgroupsPerWorkgroup, subgroupOrder)) ||
+      failed(checkTile(batchesPerSubgroup, batchOrder)) ||
+      failed(checkTile(outersPerBatch, outerOrder)) ||
+      failed(checkTile(threadsPerOuter, threadOrder)) ||
+      failed(checkTile(elementsPerThread, elementOrder))) {
+    return failure();
+  }
+
+  return success();
+}
+
 } // namespace mlir::iree_compiler::IREE::VectorExt
 
 using namespace mlir::iree_compiler::IREE::VectorExt;
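The verifier above only checks shape ranks and that every `*_order` array is a permutation; the semantics of an order are as described in the attribute documentation. As a sketch of that interpretation (hypothetical standalone C++, not part of this patch; the row-major-over-permuted-dims linearization is inferred from the subgroup example in the description):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t sizes[2] = {4, 2}; // subgroups_per_workgroup
      const int64_t order[2] = {1, 0}; // subgroup_order, outermost to innermost
      // Linearize each (row, col) with dimension 1 outermost, per `order`.
      // Printing the grid reproduces the example from the description:
      //   0 4
      //   1 5
      //   2 6
      //   3 7
      for (int64_t row = 0; row < sizes[0]; ++row) {
        for (int64_t col = 0; col < sizes[1]; ++col) {
          const int64_t coords[2] = {row, col};
          int64_t linear = 0;
          for (int64_t d : order)
            linear = linear * sizes[d] + coords[d];
          printf("%lld ", (long long)linear);
        }
        printf("\n");
      }
      return 0;
    }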
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
index 904d35a..213f57b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtDialect.cpp
@@ -27,6 +27,10 @@
       os << "layout";
       return AliasResult::OverridableAlias;
     }
+    if (llvm::isa<NestedLayoutAttr>(attr)) {
+      os << "nested";
+      return AliasResult::OverridableAlias;
+    }
     return AliasResult::NoAlias;
   }
 };
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
index 495286f..e4ee93a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir
@@ -21,6 +21,51 @@
 
 // -----
 
+#nested_1 = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [2, 4],
+  outers_per_batch = [4, 1],
+  threads_per_outer = [4, 2],
+  elements_per_thread = [1, 4],
+
+  subgroup_order = [0, 1],
+  batch_order = [0, 1],
+  outer_order = [1, 0],
+  thread_order = [0, 1],
+  element_order = [0, 1]
+>
+
+#nested_2 = #iree_vector_ext.nested_layout<
+  subgroups_per_workgroup = [1, 1],
+  batches_per_subgroup = [4, 2],
+  outers_per_batch = [1, 4],
+  threads_per_outer = [2, 4],
+  elements_per_thread = [4, 1],
+
+  subgroup_order = [1, 0],
+  batch_order = [1, 0],
+  outer_order = [0, 1],
+  thread_order = [1, 0],
+  element_order = [1, 0]
+>
+
+func.func @specify_nested(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
+  %cst_0 = arith.constant 0.0 : f16
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
+  %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #nested_1, desiredLayout = #nested_2} : vector<32x32xf16> -> vector<32x32xf16>
+  return %2 : vector<32x32xf16>
+}
+
+// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 2], outers_per_batch = [1, 4], threads_per_outer = [2, 4], elements_per_thread = [4, 1], subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [0, 1], thread_order = [1, 0], element_order = [1, 0]>
+// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [4, 1], threads_per_outer = [4, 2], elements_per_thread = [1, 4], subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [1, 0], thread_order = [0, 1], element_order = [0, 1]>
+// CHECK-LABEL: func.func @specify_nested
+// CHECK: iree_vector_ext.layout_conflict_resolution
+// CHECK-SAME: desiredLayout = #[[LAYOUT0]]
+// CHECK-SAME: sourceLayout = #[[LAYOUT1]]
+
+// -----
+
 func.func @to_simd_op(%simt: vector<4x4x4xf16>) -> vector<64x64xf16> {
   %simd = iree_vector_ext.to_simd %simt : vector<4x4x4xf16> -> vector<64x64xf16>
   func.return %simd : vector<64x64xf16>
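For intuition on the test layouts: since NestedLayoutAttr::getDistributedShape is still unimplemented, the following is only an assumption based on the attribute description, namely that every level below threads (batches, outers, elements) stays resident in a single thread. Under that reading, a hypothetical consistency check for `#nested_1`:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Per-dimension counts for #nested_1 in the test above.
      const int64_t batches[2] = {2, 4}, outers[2] = {4, 1};
      const int64_t threads[2] = {4, 2}, elements[2] = {1, 4};
      int64_t perThread[2]; // one thread's resident tile, per dimension
      for (int i = 0; i < 2; ++i)
        perThread[i] = batches[i] * outers[i] * elements[i];
      assert(perThread[0] == 8 && perThread[1] == 16);
      // With subgroups_per_workgroup = [1, 1], the subgroup's 4x2 threads,
      // each holding an 8x16 tile, account for the full 32x32 vector.
      assert(threads[0] * perThread[0] == 32 && threads[1] * perThread[1] == 32);
      return 0;
    }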