[DOCS] Update VectorExt::NestedLayoutAttr docs (#19246)
This commit updates the NestedLayoutAttr docs
to reflect what the attribute is today and adds
a few examples to make it easier to understand.
Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index c401e67..01cf309 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -18,45 +18,62 @@
let mnemonic = "nested_layout";
let summary = [{A layout representing a mapping from GPU thread hierarchy to a shape}];
let description = [{
- This layout explicitly defines how a shape is mapped to a compute
- hierarchy. We consider the following levels of hierarchy, inspired by GPUs:
+ This layout explicitly defines how the shape of the associated vector
+ is mapped to a compute hierarchy.
+ We consider the following levels of hierarchy, inspired by GPUs:
- 1. Subgroups per Workgroup
- 2. Threads per Subgroup
- 3. Elements per Thread
+ 1. Subgroups per workgroup
+ 2. Threads per subgroup
+ 3. Elements per thread
- Conceptually, each higher level of hierarchy can be viewed as multiple
- tiles of the lower level of hierarchy; each lower level of hierarchy is
- nested in the higher level of hierarchy. The last level represents the
- final elements in memory.
+ Note that the elements per thread are themselves conceptually viewed as
+ three dimensions, i.e. elements per thread = batch x outer x element.
+ However, the final order of sub-dimensions does not follow that
+ hierarchy exactly. For example, a one-dimensional vector such as
+ `vector<n x f16>` is viewed as a 5-dimensional
+ `vector<subgroup x batch x outer x thread x element>`.
+ For a two-dimensional vector, each of the above sub-dimensions is
+ doubled, i.e. `vector<n1 x n2 x f16>` is viewed as a
+ `vector<subgroup1 x subgroup2 x batch1 x batch2 x ... x element1 x element2>`
- The conceptual mapping is leveraged during compilation for tiling and
- distributing to hardware for parallel computation. Concretely, the mapping
- is done on each dimension of the original vector shape. For example, for
- vector shape 16x16x16, we have 3 dimensions, so at each level of the
- hierarchy we would have 3 tile sizes. Similarly for vector shape 32x32, we
- would have 2-D tile sizes per compute hierarchy level.
+ Now, when the `vector<subgroup x batch x outer x thread x element>` is
+ indexed, the indices of `subgroup` and `thread` do not directly refer
+ to the subgroup_id and thread_id in the GPU context. Let us define them
+ as virtual_subgroup_id and virtual_thread_id; they are given by the
+ following definition:
+ ```
+ virtual_subgroup_id[i] = (subgroup_id / subgroup_stride[i]) % subgroup_tile_size[i]
+ virtual_thread_id[i] = (thread_id / thread_stride[i]) % thread_tile_size[i]
+ ```
+
+ The inverse mapping would be:
+ ```
+ subgroup_id = sum_i(subgroup_stride[i] * virtual_subgroup_id[i]) % mul_i(subgroup_tile_size[i])
+ thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
+ for i = [0 : rank(undistributed_vector)]
+ ```
+
+ NOTE: a stride of zero means that the corresponding dimension is not
+ distributed on that level of the hierarchy.
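+
+ As an illustration only (this is not compiler code), the forward and
+ inverse mappings above can be written as a small Python sketch; the
+ helper names are placeholders, and a zero stride is treated as
+ "not distributed" per the note above:
+ ```
+ # Illustrative sketch of the index mapping above (not compiler code).
+ def virtual_ids(flat_id, strides, tile_sizes):
+     # A zero stride means the dimension is not distributed on this
+     # level, so every flat_id sees virtual id 0 for it.
+     return [0 if s == 0 else (flat_id // s) % t
+             for s, t in zip(strides, tile_sizes)]
+
+ def flat_from_virtual(virtual, strides, tile_sizes):
+     # Inverse mapping, modulo the total number of ids in the tile.
+     total = 1
+     for t in tile_sizes:
+         total *= t
+     return sum(s * v for s, v in zip(strides, virtual)) % total
+ ```
+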
We now describe each level of tiling. Each level of tiling represents a
count of tiles over the next level (rather than a list of tile sizes).
- 1. Subgroups per Workgroup
+ #### Subgroups per Workgroup
This level of tiling is also known as "subgroup/warp distribution". It
- represents how subgroups are distributed in a workgroup.
+ represents how the vector is distributed into subgroups.
- The subgroups are placed contiguously with their shape and ordering
- determined by:
- - `subgroup_tile`: Sizes of this level of tiling
- - `subgroup_order`: Ordering of dimensions, from outermost to innermost
-
- For example, subgroup_tile=[4, 2], subgroup_order=[1, 0] will
+ For example, distributing `vector<4x2xf16>` with
+ `subgroup_tile=[4, 2], subgroup_stride=[1, 4]` will
arrange the subgroups in the order:
- 0 4
- 1 5
- 2 6
- 3 7
+ ```
+ virtual_subgroup_ids:
+ [0][0] , [0][1] , [1][0], [1][1], [2][0], [2][1], [3][0], [3][1]
+ subgroup_ids:
+ 0, 4, 1, 5, 2, 6, 3, 7
+ ```
The total number of subgroups used (computed by multiplying each dim in
- subgroup_tile) should be a multiple of number of subgroups in the
+ subgroup_tile) should be a multiple of the number of subgroups in the
@@ -64,103 +81,111 @@
subgroups of the hardware, then the subgroup used (say x) is
x mod num_subgroups:
+ ```
num_subgroups = 4
- 0 4 0 0
- 1 5 x mod 4 1 1
- 2 6 -------> 2 2
- 3 7 3 3
+ 0, 4, 1, 5, 2, 6, 3, 7
+ | mod 4
+ V
+ 0, 0, 1, 1, 2, 2, 3, 3
+ ```
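+
+ Using the illustrative helpers sketched earlier (virtual_ids /
+ flat_from_virtual, placeholders rather than compiler APIs), the
+ ordering and the wrap-around above can be reproduced as:
+ ```
+ tile, strides = [4, 2], [1, 4]
+ # virtual subgroup ids for subgroup_id = 0..7
+ print([virtual_ids(sg, strides, tile) for sg in range(8)])
+ # -> [[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 1], [3, 1]]
+ # flat subgroup ids in virtual (row-major) order
+ print([flat_from_virtual([i, j], strides, tile)
+        for i in range(4) for j in range(2)])
+ # -> [0, 4, 1, 5, 2, 6, 3, 7]
+ # wrap-around when the hardware only has 4 subgroups
+ print([sg % 4 for sg in [0, 4, 1, 5, 2, 6, 3, 7]])
+ # -> [0, 0, 1, 1, 2, 2, 3, 3]
+ ```
+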
- 2. Threads per Subgroup:
+ #### Threads per Subgroup
- Threads in a subgroup are distributed in three levels.
+ This level of tiling is also known as "thread distribution" within a subgroup.
+ The logic is quite similar to subgroup distribution, using the thread
+ tile sizes and the `thread_strides`.
+
+ #### Element distribution on a thread
+
+ After the vector is distributed to the threads of
+ a subgroup, each thread's portion is viewed as [batch] x [outer] x [element],
+ where each group of sub-dimensions has as many dimensions as
+ the rank of the original undistributed vector.
- The first level, batches, are a way to represent instruction unrolling. For
- example, an intrinsic which can only take 4x4 shape at a time, uses batches
- to unroll a 16x16 shape to the native intrinsice shape.
+ The first level, batches, is a way to represent instruction unrolling. For
+ example, an intrinsic which can only take a 4x4 shape at a time uses batches
+ to unroll a 16x16 shape to the native intrinsic shape.
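+
+ For instance (an illustrative calculation, not compiler code), unrolling
+ a 16x16 shape over an intrinsic that covers a 4x4 tile per issue needs a
+ 4x4 batch tile:
+ ```
+ shape, native = [16, 16], [4, 4]   # native = outer*thread*element per dim
+ batch_tile = [s // n for s, n in zip(shape, native)]
+ assert batch_tile == [4, 4]
+ ```
+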
- Batches can be thought of as loops around the original layout:
-
- for b_0 in range(batch_0):
- for b_1 in range(batch_1):
- ...
-
- `batch_tile` represents the range of each loop.
-
The second level, outers, is a way to represent thread layout duplication
required by a particular intrinsic. For example, some AMDGPU matrix
multiplication variants require threads to be distributed
like:
- 0 1 2 3 4
- 5 6 7 8 9
- --------- --> Thread Layout of shape 2x5 duplicated 2 times, to get a layout of shape 4x5
- 0 1 2 3 4 outer_tile=[2, 1]
- 5 6 7 8 9 thread_tile=[2, 5]
+ E.g. with `outer_tile=[2, 1], thread_tile=[2, 5]`,
+ the thread layout of shape 2x5 is duplicated 2 times to get a layout of shape 4x5:
+
+ ```
+ outer = 0,0 :
+ [0 1 2 3 4]
+ [5 6 7 8 9]
+
+ outer = 1,0 :
+ [0 1 2 3 4]
+ [5 6 7 8 9]
+ ```
`outer_tile` represents the number of outers in a batch.
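+
+ As a rough illustration (not compiler code), the duplicated 4x5 layout
+ above can be reproduced as follows, assuming the flat thread ids are
+ laid out row-major within the 2x5 thread tile as drawn:
+ ```
+ outer_tile, thread_tile = [2, 1], [2, 5]
+ # Rows of the combined (outer x thread) layout; the thread ids repeat
+ # for every outer, which is the duplication described above.
+ for r in range(outer_tile[0] * thread_tile[0]):
+     print([(r % thread_tile[0]) * thread_tile[1] + c
+            for c in range(thread_tile[1])])
+ # -> [0, 1, 2, 3, 4]
+ #    [5, 6, 7, 8, 9]
+ #    [0, 1, 2, 3, 4]
+ #    [5, 6, 7, 8, 9]
+ ```
+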
- Finally, threads are distributed in a single outer. The thread
- distribution is represented by:
-
- - thread_tile: Sizes of this level of tiling
- - thread_order: Ordering of dimensions, from outermost to innermost
-
- Examples of thread distribution over a 8x4 shape:
-
- {
- batch_tile = [2, 1]
- outer_tile = [2, 2]
- thread_tile = [2, 2]
-
- thread_order = [1, 0]
- }
-
- Distributed tile:
-
- {
- [0 2]|[0 2] 0,1,2,3 --> thread ids
- [1 3]|[1 3]
- ------------ [x z] --> a single outer tile
- [0 2]|[0 2] [y w]
- [1 3]|[1 3]
- }{
- [0 2]|[0 2] { ... } --> a single batch tile
- [1 3]|[1 3]
- ------------
- [0 2]|[0 2]
- [1 3]|[1 3]
- }
-
- So, the thread distribution looks like:
-
- [0 2 0 2]
- [1 3 1 3]
- [0 2 0 2]
- [1 3 1 3]
- [0 2 0 2]
- [1 3 1 3]
- [0 2 0 2]
- [1 3 1 3]
-
- The total number of threads used (computed by multiplying each dim in
- thread_tile) should be a multiple of subgroup size of the
- harware. If the total number of threads used exceeds the subgroup size of
- the hardware, then the threads used (say tid) is tid mod subgroup_size:
-
- subgroup_size = 4
-
- 0 1 0 0
- 2 3 tid mod 4 1 1
- 4 5 --------> 2 2
- 6 7 3 3
-
- 3. Elements per Thread
-
- The final level of tiling, representing the minimum shape of vector that
- is treated as an atom.
+ The final level of tiling represents the minimum shape of a vector that
+ is treated as an atom.
`element_tile` represents the native size of the vector.
+
+ #### A full example
+
+ Vector to be distributed: `vector<64x64xf16>`
+ ```
+ NestedLayout : <
+ subgroup_tile = [2, 1],
+ batch_tile = [2, 4],
+ outer_tile = [1, 1],
+ thread_tile = [16, 4],
+ element_tile = [1, 4],
+ subgroup_strides = [1, 0],
+ thread_strides = [1, 16]
+ >
+ ```
+
+ This is conceptually viewed as a `vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>`,
+ where the first group of sub-dimensions
+ represents the distribution into subgroups.
+ The subgroup_strides being [1, 0] means
+ each subgroup is going to get a vector
+ as follows:
+
+ ```
+ subgroup0 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+ from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,0,:,:,:,:,:,:,:,:]
+ subgroup1 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+ from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,0,:,:,:,:,:,:,:,:]
+ subgroup2 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+ from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,0,:,:,:,:,:,:,:,:]
+ subgroup3 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+ from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,0,:,:,:,:,:,:,:,:]
+ ```
+
+ Then each `vector<[2x4]x[1x1]x[16x4]x[1x4]>`
+ is distributed across the threads in a subgroup using
+ thread_strides = [1, 16].
+
+ Recall: `thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])`
+
+ ```
+ thread0 : vector<[2x4]x[1x1]x[1x4]>
+ from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,0,:,:]
+ thread1 : vector<[2x4]x[1x1]x[1x4]>
+ from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,1,0,:,:]
+ ...
+ ...
+ thread16 : vector<[2x4]x[1x1]x[1x4]>
+ from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,1,:,:]
+ ```
+
+ Finally, we are left with a distributed vector
+ with the conceptual view `vector<[2x4]x[1x1]x[1x4]>`,
+ whose actual shape is `vector<2x16xf16>`.
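+
+ As a final illustration (again using the placeholder helpers from the
+ sketch above, not compiler APIs), the numbers of this example can be
+ checked as follows:
+ ```
+ subgroup_tile, subgroup_strides = [2, 1], [1, 0]
+ batch_tile, outer_tile = [2, 4], [1, 1]
+ thread_tile, thread_strides = [16, 4], [1, 16]
+ element_tile = [1, 4]
+
+ # Each dimension of vector<64x64xf16> is the product of its five tiles.
+ for d, size in enumerate([64, 64]):
+     assert (subgroup_tile[d] * batch_tile[d] * outer_tile[d] *
+             thread_tile[d] * element_tile[d]) == size
+
+ # Which slice of the conceptual vector each subgroup / thread works on.
+ print(virtual_ids(1, subgroup_strides, subgroup_tile))   # -> [1, 0]
+ print(virtual_ids(2, subgroup_strides, subgroup_tile))   # -> [0, 0]
+ print(virtual_ids(16, thread_strides, thread_tile))      # -> [0, 1]
+
+ # Shape held by a single thread: batch x outer x element per dimension.
+ assert [batch_tile[d] * outer_tile[d] * element_tile[d]
+         for d in range(2)] == [2, 16]    # i.e. vector<2x16xf16>
+ ```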
+
}];
let parameters = (ins