[LLVMGPUVectorDistribute] Fix vector step distribute (#19227)
Currently, the 'thread_strides' of NestedLayoutAttr are misinterpreted as
the access strides of the multi-dimensional vector.
However, it turns out they correspond to the tid -> vtid mapping, and the
undistributed vector is packed as:
subgroup x batch x outer x thread x element
where vtid is used to index the 'thread' dimension.
Therefore, this commit removes the use of 'thread_strides' and
'subgroup_strides' when calculating the base constant offset and instead
obtains the access strides from the packed undistributed vector shape.
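For illustration, the derivation can be sketched outside the compiler. The following is a minimal standalone sketch, not part of the patch; the rank-1 tile sizes (subgroup=1, batch=4, outer=1, thread=4, element=1) are hypothetical and chosen only to show that a thread's access stride now comes from the packed shape rather than from 'thread_strides'.

```cpp
// Standalone sketch (not IREE code): derive access strides by packing the
// undistributed shape as subgroup x batch x outer x thread x element and
// taking row-major strides of that packed shape.
#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major strides of a shape, e.g. {2, 3, 4} -> {12, 4, 1}.
static std::vector<int64_t> getStrides(const std::vector<int64_t> &shape) {
  int64_t currStride = 1;
  for (int64_t len : shape)
    currStride *= len;
  std::vector<int64_t> strides;
  for (int64_t len : shape) {
    currStride /= len;
    strides.push_back(currStride);
  }
  return strides;
}

int main() {
  // Hypothetical rank-1 layout: subgroup=1, batch=4, outer=1, thread=4, element=1.
  std::vector<int64_t> packedShape = {1, 4, 1, 4, 1};
  std::vector<int64_t> strides = getStrides(packedShape);
  // For a rank-1 layout the subgroup dim sits at packed index 0 and the thread
  // dim at packed index 3 (hence subgroupIdx = 0, threadIdx = 3 in the patch below).
  std::printf("subgroup access stride = %lld\n", (long long)strides[0]); // 16
  std::printf("thread access stride   = %lld\n", (long long)strides[3]); // 1
  return 0;
}
```

With these assumed tiles, each thread holds four elements at packed offsets [0, 4, 8, 12] and the thread id contributes with stride 1, which matches the shape of the offsets the updated @step_1 checks expect.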
Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 31f4167..4991a35 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -990,6 +990,19 @@
return lens;
}
+ // This is a helper to extract strides from a given shape.
+ // E.g.: a shape of 2x3x4 will return strides [12, 4, 1].
+ SmallVector<int64_t> getStrides(ArrayRef<int64_t> shape) const {
+ int64_t elementCount = ShapedType::getNumElements(shape);
+ SmallVector<int64_t> strides;
+ int64_t currStride = elementCount;
+ for (int64_t len : shape) {
+ currStride = currStride / len;
+ strides.push_back(currStride);
+ }
+ return strides;
+ }
+
// Once we are in the realm of remaining dimensions,
// the strides are not packed. This is a helper to
// obtain the packed strides of the remaining dimensions.
@@ -997,14 +1010,7 @@
// getRemainingDims)
SmallVector<int64_t> getPackedStrides(ArrayRef<DimInfo> dims) const {
SmallVector<int64_t> lens = getLens(dims);
- int64_t elementCount = ShapedType::getNumElements(lens);
- SmallVector<int64_t> packedStrides;
- int64_t currStride = elementCount;
- for (int64_t len : lens) {
- currStride = currStride / len;
- packedStrides.push_back(currStride);
- }
- return packedStrides;
+ return getStrides(lens);
}
// This function emulates the slicing of otherwise large constant
@@ -1091,9 +1097,14 @@
SmallVector<Value> subgroupIndices, threadIndices;
populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
subgroupIndices, threadIndices);
- ArrayRef<int64_t> subgroupStrides = resultLayout.getSubgroupStrides();
+
+ SmallVector<int64_t> undistributedShape =
+ resultLayout.getUndistributedPackedShape();
+ SmallVector<int64_t> undistributedStrides = getStrides(undistributedShape);
+ constexpr int64_t subgroupIdx = 0;
+ constexpr int64_t threadIdx = 3;
+
ArrayRef<int64_t> subgroupLengths = resultLayout.getSubgroupTile();
- ArrayRef<int64_t> threadStrides = resultLayout.getThreadStrides();
ArrayRef<int64_t> threadLengths = resultLayout.getThreadTile();
// Step op by definition should be single dimensional.
SmallVector<int64_t> distributedShape =
@@ -1102,8 +1113,9 @@
int64_t distributedElements = ShapedType::getNumElements(distributedShape);
int64_t originalElements = result.getType().getNumElements();
SmallVector<DimInfo, 2> distributedDims{
- {subgroupIndices[0], subgroupLengths[0], subgroupStrides[0]},
- {threadIndices[0], threadLengths[0], threadStrides[0]}};
+ {subgroupIndices[0], subgroupLengths[0],
+ undistributedStrides[subgroupIdx]},
+ {threadIndices[0], threadLengths[0], undistributedStrides[threadIdx]}};
llvm::sort(distributedDims, [](const DimInfo &lhs, const DimInfo &rhs) {
return lhs.dimStride > rhs.dimStride;
});
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
index fbc5325..76c33e4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_step.mlir
@@ -26,11 +26,10 @@
}
// CHECK-LABEL: func @step_1
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 16) mod 4)>()[%thread_id_x]
-// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c16 : index
-// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<4xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<4xindex>
+// CHECK: %[[TIDB:.+]] = vector.broadcast %[[TID]] : index to vector<4xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDB]], %[[CST]] : vector<4xindex>
// -----
@@ -94,10 +93,10 @@
}
// CHECK-LABEL: func @step_3
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 24, 25]> : vector<4xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 8, 9]> : vector<4xindex>
// CHECK: %[[WID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 512) mod 3)>()[%thread_id_x]
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 2) mod 4)>()[%thread_id_x]
-// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c8 : index
+// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c16 : index
// CHECK: %[[WID_STRIDEV:.+]] = vector.broadcast %[[WID_STRIDE]] : index to vector<4xindex>
// CHECK: %[[OFFSET0:.+]] = arith.addi %[[WID_STRIDEV]], %[[CST]] : vector<4xindex>
// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c2 : index
@@ -132,7 +131,8 @@
}
// CHECK-LABEL: func @step_4
-// CHECK: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112]> : vector<8xindex>
+// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
-// CHECK: %[[TIDV:.+]] = vector.broadcast %[[TID]] : index to vector<8xindex>
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDV]], %[[CST]] : vector<8xindex>
+// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c8 : index
+// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<8xindex>
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<8xindex>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index f086ecb..92134cb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -311,6 +311,20 @@
return shape;
}
+/// Before we distribute, we would like to see this as:
+/// <SUBGROUP x BATCH x OUTER x THREAD x ELEMENT>
+SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const {
+ SmallVector<int64_t> shape;
+ int64_t rank = getRank();
+ shape.reserve(rank * 5);
+ shape.append(getSubgroupTile().begin(), getSubgroupTile().end());
+ shape.append(getBatchTile().begin(), getBatchTile().end());
+ shape.append(getOuterTile().begin(), getOuterTile().end());
+ shape.append(getThreadTile().begin(), getThreadTile().end());
+ shape.append(getElementTile().begin(), getElementTile().end());
+ return shape;
+}
+
// Gets the rank of the undistributed vector for this layout.
int64_t NestedLayoutAttr::getRank() const {
// The layout requires that all size lists are the same length and match
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index 446ff77..913fb9f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -292,6 +292,9 @@
// Returns the subgroup/lane ids delinearized from a single linearized
// thread ID.
SmallVector<Value> computeThreadIds(Value threadId, int64_t subgroupSize, RewriterBase &rewriter) const;
+
+ // Returns the undistributed shape packed as subgroup x batch x outer x thread x element.
+ SmallVector<int64_t> getUndistributedPackedShape() const;
}];
let genVerifyDecl = 1;