[LLVMGPU][VectorDist] Enable support to distribute vector.transfer_write with non-contiguous dims (#17895)
This commit adds support for distributing vector.transfer_write ops with
non-contiguous dims. It does so by detecting and handling "projected"
(unused) dimensions anywhere in the permutation map, rather than assuming
they are all leading dims.
This is especially useful for attention-transpose fusion, which introduces
a vector.transfer_write whose permutation_map and types look like:
```
permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d3)> : vector<32x64xf16>, memref<2x4096x10x64xf16>
```
Previously, this crashed with an assertion failure stating that the
permutation map was invalid; this commit fixes that.
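
For reference, here is a minimal sketch (not part of the patch) of what the
new getReducedPermutation logic computes for the map above. The standalone
main and the printing are purely illustrative, and it assumes a program built
against the MLIR libraries:

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  // permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
  AffineMap map = AffineMap::get(
      /*dimCount=*/4, /*symbolCount=*/0,
      {getAffineDimExpr(1, &ctx), getAffineDimExpr(3, &ctx)}, &ctx);

  // Mark the dims that the projection drops: here d0 and d2.
  llvm::SmallBitVector unusedDims(map.getNumDims(), /*t=*/true);
  for (AffineExpr res : map.getResults())
    unusedDims[cast<AffineDimExpr>(res).getPosition()] = false;

  // Compress away the unused dims; the surviving results form the reduced
  // permutation used to permute the vector layout: here [0, 1].
  AffineMap reduced = compressUnusedDims(map);
  for (AffineExpr res : reduced.getResults())
    llvm::outs() << cast<AffineDimExpr>(res).getPosition() << " ";
  llvm::outs() << "\n"; // prints: 0 1
  return 0;
}
```

With d0 and d2 marked unused, the per-lane SIMD offsets are added only to the
indices of the dims the vector actually writes (d1 and d3); the indices of the
projected dims are passed through unchanged, which is what the updated
getMemoryIndices does.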
---------
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 8837891..4b3ae90 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -163,31 +163,32 @@
/// Given a projected permutation, get a reduced permutation, i.e. without
/// the projected dimensions.
-static SmallVector<int64_t> getReducedPermutation(AffineMap permutationMap) {
+static SmallVector<int64_t>
+getReducedPermutation(AffineMap permutationMap,
+ llvm::SmallBitVector &unusedDims) {
assert(permutationMap.isProjectedPermutation() &&
"permutation map should be a projected permutation.");
// TODO: The permutation map may also have broadcasting. Currently, we do not
// handle it. This can be fixed by adding a "BROADCAST" dimension in the
// layout.
+ unusedDims.clear();
+ unusedDims.resize(permutationMap.getNumDims(), true);
+
+ for (AffineExpr dimExpr : permutationMap.getResults()) {
+ int64_t pos = cast<AffineDimExpr>(dimExpr).getPosition();
+ unusedDims[pos] = false;
+ }
+
SmallVector<int64_t> permutation;
permutation.reserve(permutationMap.getNumResults());
- unsigned leadingUnitDims =
- permutationMap.getNumDims() - permutationMap.getNumResults();
- for (AffineExpr dim : permutationMap.getResults()) {
- // Get this dim's position in the permutation map.
- auto dimExpr = dyn_cast<AffineDimExpr>(dim);
- if (!dimExpr) {
- llvm::report_fatal_error("permutation map is not a projected "
- "permutation.");
- }
-
- unsigned pos = dimExpr.getPosition();
- assert(pos >= leadingUnitDims && "invalid permutation map");
- pos -= leadingUnitDims;
+ AffineMap reducedMap = compressUnusedDims(permutationMap);
+ for (AffineExpr dimExpr : reducedMap.getResults()) {
+ int64_t pos = cast<AffineDimExpr>(dimExpr).getPosition();
permutation.push_back(pos);
}
+
return permutation;
}
@@ -208,8 +209,9 @@
// lowering. When accessing memory, we use the memoryLayout, because that
// is how the data is accessed in memory. The data is stored in the vector
// according to vectorLayout.
+ llvm::SmallBitVector unusedDims;
SmallVector<int64_t> permutation =
- getReducedPermutation(xferOp.getPermutationMap());
+ getReducedPermutation(xferOp.getPermutationMap(), unusedDims);
LayoutAttr memoryLayout =
cast<LayoutAttr>(vectorLayout.permute(permutation));
@@ -219,8 +221,8 @@
LayoutIterator iterator(vectorLayout, steps);
iterator.apply([&](const LayoutIterator::State &state) {
- SmallVector<Value> memoryIndices =
- getMemoryIndices(state, memoryLayout, xferOp.getIndices(), rewriter);
+ SmallVector<Value> memoryIndices = getMemoryIndices(
+ state, memoryLayout, xferOp.getIndices(), unusedDims, rewriter);
SmallVector<int64_t> accIndices = state.computeSIMTIndex();
accumulator = accessUnit(xferOp, memoryIndices, accIndices, accumulator,
vectorLayout, memoryLayout, rewriter);
@@ -232,17 +234,22 @@
SmallVector<Value> getMemoryIndices(const LayoutIterator::State &state,
LayoutAttr memoryLayout,
SmallVector<Value> indices,
+ llvm::SmallBitVector &projectedDims,
RewriterBase &rewriter) const {
SmallVector<Value> simdIndices =
computeSIMDIndex(state, memoryLayout, laneId, rewriter);
SmallVector<Value> memoryIndices(indices);
    // `indices` has an entry for every memref dim, but the SIMD offsets only
    // cover the dims that are not projected away, so skip the projected dims.
- int leadingProjectedDims = memoryIndices.size() - simdIndices.size();
- for (int i = leadingProjectedDims, e = memoryIndices.size(); i < e; ++i) {
+ int currSimd = 0;
+ for (int i = 0, e = memoryIndices.size(); i < e; ++i) {
+ if (projectedDims[i]) {
+ continue;
+ }
+
memoryIndices[i] = rewriter.create<arith::AddIOp>(
- rewriter.getUnknownLoc(), memoryIndices[i],
- simdIndices[i - leadingProjectedDims]);
+ rewriter.getUnknownLoc(), memoryIndices[i], simdIndices[currSimd]);
+ ++currSimd;
}
return memoryIndices;
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index ebeca29..e5ecf59 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -196,49 +196,47 @@
// TODO: Use affine min tricks based on the grid size to elide the mod.
// Note that this IR is invalid if subgroup size != 8.
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 8 + 8)>
-
-// CHECK-LABEL: @distribute_transfer_write_row_major
func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %root, %alloc[%c0, %c0]
{in_bounds = [true, true],
"__vector_layout_test_anchor_operand_0" = #layout_row_major}
: vector<16x16xf16>, memref<64x64xf16>
-
- // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
- // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
- // CHECK-DAG: %[[LANEID:.+]] = gpu.thread_id x
- // CHECK: %[[VEC_LANE_Y:.+]] = affine.apply #[[$MAP0]]()[%[[LANEID]]]
- // CHECK: %[[DIST_SRC_VEC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<2x2x8xf16>
- // CHECK: %[[BATCH_0_0:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 0] : vector<8xf16> from vector<2x2x8xf16>
- // CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
-
- // CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.apply #[[$MAP1]]()[%[[LANEID]]]
- // CHECK: %[[BATCH_1_0:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 0] : vector<8xf16> from vector<2x2x8xf16>
- // CHECK: vector.store %[[BATCH_1_0]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
-
- // CHECK: %[[BATCH_0_1:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 1] : vector<8xf16> from vector<2x2x8xf16>
- // CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
-
- // CHECK: %[[BATCH_1_1:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 1] : vector<8xf16> from vector<2x2x8xf16>
- // CHECK: vector.store %[[BATCH_1_1]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
func.return
}
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 8 + 8)>
-// CHECK-LABEL: @distribute_transfer_write_col_major
+// CHECK-LABEL: @distribute_transfer_write_row_major
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[LANEID:.+]] = gpu.thread_id x
+// CHECK: %[[VEC_LANE_Y:.+]] = affine.apply #[[$MAP0]]()[%[[LANEID]]]
+// CHECK: %[[DIST_SRC_VEC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<2x2x8xf16>
+// CHECK: %[[BATCH_0_0:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 0] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
+
+// CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.apply #[[$MAP1]]()[%[[LANEID]]]
+// CHECK: %[[BATCH_1_0:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 0] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_1_0]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
+
+// CHECK: %[[BATCH_0_1:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 1] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
+
+// CHECK: %[[BATCH_1_1:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 1] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_1_1]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
+
func.func @distribute_transfer_write_col_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %root, %alloc[%c0, %c0]
{in_bounds = [true, true],
"__vector_layout_test_anchor_operand_0" = #layout_col_major}
: vector<16x16xf16>, memref<64x64xf16>
- // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
func.return
}
+// CHECK-LABEL: @distribute_transfer_write_col_major
+// CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
-// CHECK-LABEL: @distribute_transfer_write_row_major_with_broadcast
func.func @distribute_transfer_write_row_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -246,11 +244,11 @@
permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
"__vector_layout_test_anchor_operand_0" = #layout_row_major}
: vector<16x16xf16>, memref<32x32x32x32xf16>
- // CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
func.return
}
+// CHECK-LABEL: @distribute_transfer_write_row_major_with_broadcast
+// CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
-// CHECK-LABEL: @distribute_transfer_write_col_major_with_broadcast
func.func @distribute_transfer_write_col_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -258,11 +256,11 @@
permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
"__vector_layout_test_anchor_operand_0" = #layout_col_major}
: vector<16x16xf16>, memref<32x32x32x32xf16>
- // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
func.return
}
+// CHECK-LABEL: @distribute_transfer_write_col_major_with_broadcast
+// CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
-// CHECK-LABEL: @distribute_transfer_write_row_major_transpose
func.func @distribute_transfer_write_row_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -270,11 +268,11 @@
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
"__vector_layout_test_anchor_operand_0" = #layout_row_major}
: vector<16x16xf16>, memref<32x32x32x32xf16>
- // CHECK-COUNT-32: vector.store {{.*}}, vector<1xf16>
func.return
}
+// CHECK-LABEL: @distribute_transfer_write_row_major_transpose
+// CHECK-COUNT-32: vector.store {{.*}}, vector<1xf16>
-// CHECK-LABEL: @distribute_transfer_write_col_major_transpose
func.func @distribute_transfer_write_col_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -282,10 +280,25 @@
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
"__vector_layout_test_anchor_operand_0" = #layout_col_major}
: vector<16x16xf16>, memref<32x32x32x32xf16>
-
- // CHECK-COUNT-2: vector.store {{.*}}, vector<4xf16>
func.return
}
+// CHECK-LABEL: @distribute_transfer_write_col_major_transpose
+// CHECK-COUNT-2: vector.store {{.*}}, vector<4xf16>
+
+
+func.func @distribute_transfer_write_with_non_contiguous_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %root, %alloc[%c0, %a, %c0, %b]
+ {in_bounds = [true, true],
+ permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+ : vector<16x16xf16>, memref<32x32x32x32xf16>
+ func.return
+}
+// CHECK-LABEL: func.func @distribute_transfer_write_with_non_contiguous_broadcast
+// CHECK-SAME: %[[ROOT:.+]]: vector<16x16xf16>, %[[A:.+]]: index, %[[B:.+]]: index, %[[ALLOC:.+]]: memref<32x32x32x32xf16>)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-COUNT-4: vector.store %{{.+}}, %[[ALLOC]][%[[C0]], {{.+}}, %[[C0]], %{{.+}}] : memref<32x32x32x32xf16>, vector<8xf16>
builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {