[DOCS] Update VectorExt::NestedLayoutAttr docs (#19246)

This commit updates the NestedLayoutAttr docs
to reflect what the attribute is today and adds a few
examples to make it more understandable.

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index c401e67..01cf309 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -18,45 +18,62 @@
   let mnemonic = "nested_layout";
   let summary = [{A layout representing a mapping from GPU thread hierarchy to a shape}];
   let description = [{
-    This layout explicitly defines how a shape is mapped to a compute
-    hierarchy. We consider the following levels of hierarchy, inspired by GPUs:
+    This layout explicitly defines how the shape of the associated vector
+    is mapped to a compute hierarchy.
+    We consider the following levels of hierarchy, inspired by GPUs:
 
-    1. Subgroups per Workgroup
-    2. Threads per Subgroup
-    3. Elements per Thread
+    1. Subgroups per workgroup
+    2. Threads per subgroup
+    3. Elements per thread
 
-    Conceptually, each higher level of hierarchy can be viewed as multiple
-    tiles of the lower level of hierarchy; each lower level of hierarchy is
-    nested in the higher level of hierarchy. The last level represents the
-    final elements in memory.
+    Note that the elements per thread are themselves conceptually viewed
+    as three dimensions, i.e. elements per thread = batch x outer x element.
+    However, the final order of sub-dimensions does not follow that
+    hierarchy exactly. For example, a one-dimensional vector such as
+    `vector<n x f16>` is viewed as a five-dimensional
+    `vector<subgroup x batch x outer x thread x element>`. For a
+    two-dimensional vector, each of the above sub-dimensions is doubled,
+    i.e. `vector<n1 x n2 x f16>` is viewed as
+    `vector<subgroup1 x subgroup2 x batch1 x batch2 x ... x element1 x element2>`.
 
-    The conceptual mapping is leveraged during compilation for tiling and
-    distributing to hardware for parallel computation. Concretely, the mapping
-    is done on each dimension of the original vector shape. For example, for
-    vector shape 16x16x16, we have 3 dimensions, so at each level of the
-    hierarchy we would have 3 tile sizes. Similarly for vector shape 32x32, we
-    would have 2-D tile sizes per compute hierarchy level.
+    Now, when the `vector<subgroup x batch x outer x thread x element>` is
+    indexed, the `subgroup` and `thread` indices do not directly refer to
+    the subgroup_id and thread_id in the GPU context. Let us define them
+    as virtual_subgroup_id and virtual_thread_id, with the following
+    definitions:
+    ```
+    virtual_subgroup_id[i] = (subgroup_id / subgroup_stride[i]) % subgroup_tile_size[i]
+    virtual_thread_id[i]   = (thread_id   / thread_stride[i]) % thread_tile_size[i]
+    ```
+
+    The inverse mapping is:
+    ```
+    subgroup_id = sum_i(subgroup_stride[i] * virtual_subgroup_id[i]) % mul_i(subgroup_tile_size[i])
+    thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
+        for i = [0 : rank(undistributed_vector)]
+    ```
+
+    NOTE: a stride of zero means that the dimension is not distributed
+    at that level of the hierarchy.
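+
+    As a minimal illustration in Python (not part of the attribute
+    definition; `virtual_ids` is a hypothetical helper), the forward
+    mapping can be sketched as:
+
+    ```
+    def virtual_ids(linear_id, strides, tile_sizes):
+        # Delinearize a hardware id dimension by dimension. A stride of
+        # zero marks a non-distributed dimension, which always maps to 0.
+        return [
+            0 if stride == 0 else (linear_id // stride) % size
+            for stride, size in zip(strides, tile_sizes)
+        ]
+
+    # E.g. thread 17 with thread_strides = [1, 16], thread_tile = [16, 4]:
+    assert virtual_ids(17, [1, 16], [16, 4]) == [1, 1]
+    ```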
 
     We now describe each level of tiling. Each level of tiling represents a
     count of tiles over the next level (rather than a list of tile sizes).
 
-    1. Subgroups per Workgroup
+    #### Subgroups per Workgroup
 
     This level of tiling is also known as "subgroup/warp distribution". It
-    represents how subgroups are distributed in a workgroup.
+    represents how the vector is distributed into subgroups.
 
-    The subgroups are placed contiguously with their shape and ordering
-    determined by:
-      - `subgroup_tile`: Sizes of this level of tiling
-      - `subgroup_order`: Ordering of dimensions, from outermost to innermost
-
-    For example, subgroup_tile=[4, 2], subgroup_order=[1, 0] will
+    For example, distributing `vector<4x2xf16>` with
+    `subgroup_tile = [4, 2]` and `subgroup_strides = [1, 4]` will
     arrange the subgroups in the order:
 
-    0 4
-    1 5
-    2 6
-    3 7
+    ```
+    virtual_subgroup_ids:
+    [0][0], [0][1], [1][0], [1][1], [2][0], [2][1], [3][0], [3][1]
+    subgroup_ids:
+    0, 4, 1, 5, 2, 6, 3, 7
+    ```
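+
+    The same order follows from the inverse mapping above; a small Python
+    sketch (illustration only, `subgroup_id` is a hypothetical helper):
+
+    ```
+    from itertools import product
+
+    # subgroup_tile = [4, 2], subgroup_strides = [1, 4]
+    def subgroup_id(virtual_id, strides=(1, 4), tile=(4, 2)):
+        return sum(s * v for s, v in zip(strides, virtual_id)) % (tile[0] * tile[1])
+
+    ids = [subgroup_id(v) for v in product(range(4), range(2))]
+    assert ids == [0, 4, 1, 5, 2, 6, 3, 7]
+    ```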
 
     The total number of subgroups used (computed by multiplying each dim in
     subgroup_tile) should be a multiple of the number of subgroups in the
@@ -64,103 +81,111 @@
     subgroups of the hardware, then the subgroup used (say x) is
     x mod num_subgroups:
 
+    ```
     num_subgroups = 4
 
-    0 4               0 0
-    1 5    x mod 4    1 1
-    2 6    ------->   2 2
-    3 7               3 3
+    0, 4, 1, 5, 2, 6, 3, 7
+    | mod 4
+    V
+    0, 0, 1, 1, 2, 2, 3, 3
+    ```
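+
+    The same wrap-around, checked in Python (illustration only):
+
+    ```
+    assert [x % 4 for x in [0, 4, 1, 5, 2, 6, 3, 7]] == [0, 0, 1, 1, 2, 2, 3, 3]
+    ```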
 
-    2. Threads per Subgroup:
+    #### Threads per Subgroup
 
-    Threads in a subgroup are distributed in three levels.
+    This level of tiling is also known as "thread distribution" within a
+    subgroup. The logic is quite similar to subgroup distribution, using
+    the tile sizes and the `thread_strides`, as the sketch below shows.
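+
+    For instance, a sketch in Python (illustration only) of a 2x5 thread
+    layout, assuming `thread_tile = [2, 5]` and `thread_strides = [5, 1]`,
+    which places threads 0..9 row-major:
+
+    ```
+    def virtual_thread_id(tid, strides=(5, 1), tile=(2, 5)):
+        return [(tid // s) % t for s, t in zip(strides, tile)]
+
+    assert virtual_thread_id(7) == [1, 2]  # row 1, column 2 of the 2x5 grid
+    ```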
+
+    #### Element Distribution on a Thread
+
+    After the vector is distributed to the threads of a subgroup, the
+    per-thread fragment is viewed as [batch] x [outer] x [element], where
+    each group of sub-dimensions has as many dimensions as the rank of
+    the undistributed vector.
 
     The first level, batches, is a way to represent instruction unrolling. For
     example, an intrinsic which can only take a 4x4 shape at a time uses
     batches to unroll a 16x16 shape to the native intrinsic shape.
 
-    Batches can be thought of as loops around the original layout:
-
-    for b_0 in range(batch_0):
-      for b_1 in range(batch_1):
-        ...
-
-    `batch_tile` represents the range of each loop.
-
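+    Batches can still be thought of as loops around the native shape; a
+    Python sketch (illustration only), assuming a hypothetical 4x4
+    intrinsic and `batch_tile = [4, 4]`:
+
+    ```
+    batch_tile = [4, 4]  # unrolls a 16x16 shape into 4x4 native tiles
+    for b0 in range(batch_tile[0]):
+        for b1 in range(batch_tile[1]):
+            origin = (b0 * 4, b1 * 4)  # offset of one native 4x4 tile
+    ```
+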
     The second level, outers, is a way to represent thread layout duplication
     required by a particular intrinsic. For example, some AMDGPU matrix
     multiplication variants require threads to be distributed
     like:
 
-    0 1 2 3 4
-    5 6 7 8 9
-    --------- --> Thread Layout of shape 2x5 duplicated 2 times, to get a layout of shape 4x5
-    0 1 2 3 4     outer_tile=[2, 1]
-    5 6 7 8 9     thread_tile=[2, 5]
+    For example, with `outer_tile = [2, 1]` and `thread_tile = [2, 5]`,
+    the thread layout of shape 2x5 is duplicated 2 times to get a layout
+    of shape 4x5:
+
+    ```
+    outer = 0,0 :
+    [0 1 2 3 4]
+    [5 6 7 8 9]
+
+    outer = 1,0 :
+    [0 1 2 3 4]
+    [5 6 7 8 9]
+    ```
 
     `outer_tile` represents the number of outers in a batch.
 
-    Finally, threads are distributed in a single outer. The thread
-    distribution is represented by:
-
-      - thread_tile: Sizes of this level of tiling
-      - thread_order: Ordering of dimensions, from outermost to innermost
-
-    Examples of thread distribution over a 8x4 shape:
-
-    {
-      batch_tile = [2, 1]
-      outer_tile = [2, 2]
-      thread_tile = [2, 2]
-
-      thread_order = [1, 0]
-    }
-
-    Distributed tile:
-
-    {
-      [0 2]|[0 2]      0,1,2,3 --> thread ids
-      [1 3]|[1 3]
-      ------------     [x z]   --> a single outer tile
-      [0 2]|[0 2]      [y w]
-      [1 3]|[1 3]
-    }{
-      [0 2]|[0 2]      { ... } --> a single batch tile
-      [1 3]|[1 3]
-      ------------
-      [0 2]|[0 2]
-      [1 3]|[1 3]
-    }
-
-    So, the thread distribution looks like:
-
-    [0 2 0 2]
-    [1 3 1 3]
-    [0 2 0 2]
-    [1 3 1 3]
-    [0 2 0 2]
-    [1 3 1 3]
-    [0 2 0 2]
-    [1 3 1 3]
-
-    The total number of threads used (computed by multiplying each dim in
-    thread_tile) should be a multiple of subgroup size of the
-    harware. If the total number of threads used exceeds the subgroup size of
-    the hardware, then the threads used (say tid) is tid mod subgroup_size:
-
-    subgroup_size = 4
-
-    0 1                0 0
-    2 3    tid mod 4   1 1
-    4 5    -------->   2 2
-    6 7                3 3
-
-    3. Elements per Thread
-
     The final level of tiling represents the minimum shape of vector that
     is treated as an atom.
 
     `element_tile` represents the native size of the vector.
+
+    #### A Full Example
+
+    Vector to be distributed: `vector<64x64xf16>`
+    ```
+    NestedLayout : <
+        subgroup_tile = [2, 1],
+        batch_tile = [2, 4],
+        outer_tile = [1, 1],
+        thread_tile = [16, 4],
+        element_tile = [1, 4],
+        subgroup_strides = [1, 0],
+        thread_strides = [1, 16]
+    >
+    ```
+
+    This is conceptually viewed as `vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>`,
+    where the first group of sub-dimensions represents the distribution
+    into subgroups. With `subgroup_strides = [1, 0]`, each subgroup gets
+    a vector as follows:
+
+    ```
+    subgroup0 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+    from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,0,:,:,:,:,:,:,:,:]
+    subgroup1 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+    from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,0,:,:,:,:,:,:,:,:]
+    subgroup2 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+    from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,0,:,:,:,:,:,:,:,:]
+    subgroup3 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
+    from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,0,:,:,:,:,:,:,:,:]
+    ```
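+
+    This matches the mapping above: with `subgroup_strides = [1, 0]` the
+    second dimension is not distributed, so subgroups 0 and 2 (and 1
+    and 3) receive the same slice. In Python (illustration only):
+
+    ```
+    def virtual_subgroup_id(sg, strides=(1, 0), tile=(2, 1)):
+        return [0 if s == 0 else (sg // s) % t for s, t in zip(strides, tile)]
+
+    assert [virtual_subgroup_id(s) for s in range(4)] == [[0, 0], [1, 0], [0, 0], [1, 0]]
+    ```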
+
+    Then each `vector<[2x4]x[1x1]x[16x4]x[1x4]>` is distributed to the
+    threads in a subgroup using `thread_strides = [1, 16]`.
+
+    Recall:
+    `thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])`
+
+    ```
+    thread0 : vector<[2x4]x[1x1]x[1x4]>
+    from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,0,:,:]
+    thread1 : vector<[2x4]x[1x1]x[1x4]>
+    from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,1,0,:,:]
+    ...
+    ...
+    thread16 : vector<[2x4]x[1x1]x[1x4]>
+    from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,1,:,:]
+    ```
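+
+    The listed threads can be verified with the same mapping; in Python
+    (illustration only):
+
+    ```
+    def virtual_thread_id(tid, strides=(1, 16), tile=(16, 4)):
+        return [(tid // s) % t for s, t in zip(strides, tile)]
+
+    assert virtual_thread_id(0) == [0, 0]
+    assert virtual_thread_id(1) == [1, 0]
+    assert virtual_thread_id(16) == [0, 1]
+    ```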
+
+    Finally, we are left with a distributed vector with the conceptual
+    view `vector<[2x4]x[1x1]x[1x4]>`, where the actual shape is
+    `vector<2x16>`.
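+
+    As a final sanity check in Python (illustration only): the tile sizes
+    multiply back to the undistributed shape, and the per-thread shape is
+    the per-dimension product of the batch, outer and element tiles:
+
+    ```
+    subgroup, batch = [2, 1], [2, 4]
+    outer, thread, element = [1, 1], [16, 4], [1, 4]
+
+    shape = [1, 1]
+    for tile in (subgroup, batch, outer, thread, element):
+        shape = [a * b for a, b in zip(shape, tile)]
+    assert shape == [64, 64]  # the undistributed vector<64x64xf16>
+
+    per_thread = [b * o * e for b, o, e in zip(batch, outer, element)]
+    assert per_thread == [2, 16]  # matches vector<2x16>
+    ```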
+
   }];
 
   let parameters = (ins