[LLVMGPUVectorDistribute] Fix vector step distribute (#19227)

Currently, the 'thread_strides' of NestedLayoutAttr are misinterpreted as
the access strides of the multi-dimensional vector.

However, it turns out they correspond to the tid -> vtid mapping, and the
undistributed vector is packed as:
subgroup x batch x outer x thread x element
where vtid is used to index the 'thread' dimension.

Therefore, this commit removes the usage of 'thread_strides' and
'subgroup_strides' when calculating the base constant offset and instead
obtains the strides from the packed undistributed vector shape.
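
To illustrate the new derivation, here is a minimal standalone C++ sketch;
the tile sizes and the local getStrides helper are hypothetical, mirroring
the helper added in this diff rather than calling IREE APIs or using the
tests' layouts:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Row-major strides of a shape, mirroring the getStrides helper below.
    static std::vector<int64_t> getStrides(const std::vector<int64_t> &shape) {
      int64_t stride = 1;
      for (int64_t dim : shape)
        stride *= dim;
      std::vector<int64_t> strides;
      for (int64_t dim : shape) {
        stride /= dim;
        strides.push_back(stride);
      }
      return strides;
    }

    int main() {
      // Packed undistributed shape: subgroup x batch x outer x thread x element,
      // here with hypothetical rank-1 tiles 2, 1, 1, 4 and 8.
      std::vector<int64_t> packed = {2, 1, 1, 4, 8};
      std::vector<int64_t> strides = getStrides(packed);
      // The base constant offsets are now built from the packed positions of
      // the subgroup (index 0) and thread (index 3) dimensions.
      std::cout << "subgroup stride: " << strides[0] << "\n"; // 32
      std::cout << "thread stride:   " << strides[3] << "\n"; // 8
      return 0;
    }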

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 31f4167..4991a35 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -990,6 +990,19 @@
     return lens;
   }
 
+  // This is a helper to extract the row-major strides of a given shape.
+  // E.g.: a shape of 2x3x4 will return strides [12, 4, 1].
+  SmallVector<int64_t> getStrides(ArrayRef<int64_t> shape) const {
+    int64_t elementCount = ShapedType::getNumElements(shape);
+    SmallVector<int64_t> strides;
+    int64_t currStride = elementCount;
+    for (int64_t len : shape) {
+      currStride = currStride / len;
+      strides.push_back(currStride);
+    }
+    return strides;
+  }
+
   // Once we are in the realm of remaining dimensions,
   // the strides are not packed. This is a helper to
   // obtain the packed strides of the remaining dimensions.
@@ -997,14 +1010,7 @@
   //  getRemainingDims)
   SmallVector<int64_t> getPackedStrides(ArrayRef<DimInfo> dims) const {
     SmallVector<int64_t> lens = getLens(dims);
-    int64_t elementCount = ShapedType::getNumElements(lens);
-    SmallVector<int64_t> packedStrides;
-    int64_t currStride = elementCount;
-    for (int64_t len : lens) {
-      currStride = currStride / len;
-      packedStrides.push_back(currStride);
-    }
-    return packedStrides;
+    return getStrides(lens);
   }
 
   // This function emulates the slicing of otherwise large constant
@@ -1091,9 +1097,14 @@
     SmallVector<Value> subgroupIndices, threadIndices;
     populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
                                  subgroupIndices, threadIndices);
-    ArrayRef<int64_t> subgroupStrides = resultLayout.getSubgroupStrides();
+
+    SmallVector<int64_t> undistributedShape =
+        resultLayout.getUndistributedPackedShape();
+    SmallVector<int64_t> undistributedStrides = getStrides(undistributedShape);
+    constexpr int64_t subgroupIdx = 0;
+    constexpr int64_t threadIdx = 3;
+
     ArrayRef<int64_t> subgroupLengths = resultLayout.getSubgroupTile();
-    ArrayRef<int64_t> threadStrides = resultLayout.getThreadStrides();
     ArrayRef<int64_t> threadLengths = resultLayout.getThreadTile();
     // Step op by definition should be single dimensional.
     SmallVector<int64_t> distributedShape =
@@ -1102,8 +1113,9 @@
     int64_t distributedElements = ShapedType::getNumElements(distributedShape);
     int64_t originalElements = result.getType().getNumElements();
     SmallVector<DimInfo, 2> distributedDims{
-        {subgroupIndices[0], subgroupLengths[0], subgroupStrides[0]},
-        {threadIndices[0], threadLengths[0], threadStrides[0]}};
+        {subgroupIndices[0], subgroupLengths[0],
+         undistributedStrides[subgroupIdx]},
+        {threadIndices[0], threadLengths[0], undistributedStrides[threadIdx]}};
     llvm::sort(distributedDims, [](const DimInfo &lhs, const DimInfo &rhs) {
       return lhs.dimStride > rhs.dimStride;
     });
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
index fbc5325..76c33e4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
@@ -26,11 +26,10 @@
 }
 
 // CHECK-LABEL: func @step_1
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 16) mod 4)>()[%thread_id_x]
-// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c16 : index
-// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<4xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<4xindex>
+// CHECK: %[[TIDB:.+]] = vector.broadcast %[[TID]] : index to vector<4xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDB]], %[[CST]] : vector<4xindex>
 
 // -----
 
@@ -94,10 +93,10 @@
 }
 
 // CHECK-LABEL: func @step_3
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 24, 25]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 8, 9]> : vector<4xindex>
 // CHECK: %[[WID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 512) mod 3)>()[%thread_id_x]
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 2) mod 4)>()[%thread_id_x]
-// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c8 : index
+// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c16 : index
 // CHECK: %[[WID_STRIDEV:.+]] = vector.broadcast %[[WID_STRIDE]] : index to vector<4xindex>
 // CHECK: %[[OFFSET0:.+]] = arith.addi %[[WID_STRIDEV]], %[[CST]] : vector<4xindex>
 // CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c2 : index
@@ -132,7 +131,8 @@
 }
 
 // CHECK-LABEL: func @step_4
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112]> : vector<8xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
-// CHECK: %[[TIDV:.+]] = vector.broadcast %[[TID]] : index to vector<8xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDV]], %[[CST]] : vector<8xindex>
+// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c8 : index
+// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<8xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<8xindex>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index f086ecb..92134cb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -311,6 +311,20 @@
   return shape;
 }
 
+/// Returns the shape of the undistributed vector as it is packed before
+/// distribution: <SUBGROUP x BATCH x OUTER x THREAD x ELEMENT>
+SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const {
+  SmallVector<int64_t> shape;
+  int64_t rank = getRank();
+  shape.reserve(rank * 5);
+  shape.append(getSubgroupTile().begin(), getSubgroupTile().end());
+  shape.append(getBatchTile().begin(), getBatchTile().end());
+  shape.append(getOuterTile().begin(), getOuterTile().end());
+  shape.append(getThreadTile().begin(), getThreadTile().end());
+  shape.append(getElementTile().begin(), getElementTile().end());
+  return shape;
+}
+
 // Gets the rank of the undistributed vector for this layout.
 int64_t NestedLayoutAttr::getRank() const {
   // The layout requires that all size lists are the same length and match
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index 446ff77..913fb9f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -292,6 +292,9 @@
     // Returns the subgroup/lane ids delinearized from a single linearized
     // thread ID.
     SmallVector<Value> computeThreadIds(Value threadId, int64_t subgroupSize, RewriterBase &rewriter) const;
+
+    // Returns the undistributed packed shape: subgroup x batch x outer x thread x element.
+    SmallVector<int64_t> getUndistributedPackedShape() const;
   }];
 
   let genVerifyDecl = 1;