[LLVMGPU][VectorDist] Enable support to distribute vector.transfer_write with non-contiguous dims (#17895)

This commit adds support for distributing vector.transfer_write ops with
non-contiguous dims. It does so by improving how "projected"/unused
dimensions are detected and handled.

This is especially useful for attention-transpose fusion, which introduces
a vector.transfer_write whose permutation map and types look like:
```
permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>} : vector<32x64xf16>, memref<2x4096x10x64xf16>
```

Previously, this crashed with an assertion failure reporting an invalid
permutation map. This commit fixes that.
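
For illustration, here is a minimal standalone sketch of the reduced-permutation
computation introduced below (plain C++, not the MLIR code itself; the function
and variable names are hypothetical stand-ins for `getReducedPermutation` and
`compressUnusedDims`):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Given the result dims of a projected permutation map over `numDims`
// iteration dims, mark the unused (projected) dims and build the permutation
// over the remaining dims only.
static std::vector<int64_t>
getReducedPermutation(int64_t numDims, const std::vector<int64_t> &resultDims,
                      std::vector<bool> &unusedDims) {
  unusedDims.assign(numDims, true);
  for (int64_t pos : resultDims)
    unusedDims[pos] = false;

  // Renumber each used dim by how many used dims precede it
  // (this mirrors what compressing the unused dims does to the affine map).
  std::vector<int64_t> compressed(numDims, -1);
  int64_t next = 0;
  for (int64_t d = 0; d < numDims; ++d)
    if (!unusedDims[d])
      compressed[d] = next++;

  std::vector<int64_t> permutation;
  permutation.reserve(resultDims.size());
  for (int64_t pos : resultDims)
    permutation.push_back(compressed[pos]);
  return permutation;
}

int main() {
  // permutation_map = (d0, d1, d2, d3) -> (d1, d3): d0 and d2 are projected.
  std::vector<bool> unusedDims;
  std::vector<int64_t> perm = getReducedPermutation(4, {1, 3}, unusedDims);
  for (int64_t p : perm)
    std::cout << p << " "; // prints "0 1": identity over the kept dims
  std::cout << "\n";
  return 0;
}
```

The key difference from the old logic is that the projected dims no longer have
to be leading; they can sit anywhere in the iteration space.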

---------

Co-authored-by: Kunwar Grover <groverkss@gmail.com>
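
The second half of the change updates how memory indices are assembled: SIMD
offsets exist only for the dims that appear in the permutation map, so projected
dims keep their original start index instead of being assumed to lead. A minimal
standalone sketch of that loop (plain C++, names are illustrative, not the
actual `getMemoryIndices` signature):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// For each memory dim: projected dims keep their start index; the others get
// the next SIMD offset added, in order.
static std::vector<int64_t>
getMemoryIndices(std::vector<int64_t> startIndices,
                 const std::vector<int64_t> &simdOffsets,
                 const std::vector<bool> &projectedDims) {
  int64_t currSimd = 0;
  for (size_t i = 0; i < startIndices.size(); ++i) {
    if (projectedDims[i])
      continue; // projected dim: leave the start index untouched
    startIndices[i] += simdOffsets[currSimd++];
  }
  return startIndices;
}

int main() {
  // memref<32x32x32x32>, permutation_map (d0, d1, d2, d3) -> (d1, d3):
  // d0 and d2 are projected; SIMD offsets apply to d1 and d3 only.
  auto indices = getMemoryIndices({0, 5, 0, 7}, {2, 3},
                                  {true, false, true, false});
  for (int64_t idx : indices)
    std::cout << idx << " "; // prints "0 7 0 10"
  std::cout << "\n";
  return 0;
}
```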
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 8837891..4b3ae90 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -163,31 +163,32 @@
 
 /// Given a projected permutation, get a reduced permutation, i.e. without
 /// the projected dimensions.
-static SmallVector<int64_t> getReducedPermutation(AffineMap permutationMap) {
+static SmallVector<int64_t>
+getReducedPermutation(AffineMap permutationMap,
+                      llvm::SmallBitVector &unusedDims) {
   assert(permutationMap.isProjectedPermutation() &&
          "permutation map should be a projected permutation.");
   // TODO: The permutation map may also have broadcasting. Currently, we do not
   // handle it. This can be fixed by adding a "BROADCAST" dimension in the
   // layout.
 
+  unusedDims.clear();
+  unusedDims.resize(permutationMap.getNumDims(), true);
+
+  for (AffineExpr dimExpr : permutationMap.getResults()) {
+    int64_t pos = cast<AffineDimExpr>(dimExpr).getPosition();
+    unusedDims[pos] = false;
+  }
+
   SmallVector<int64_t> permutation;
   permutation.reserve(permutationMap.getNumResults());
 
-  unsigned leadingUnitDims =
-      permutationMap.getNumDims() - permutationMap.getNumResults();
-  for (AffineExpr dim : permutationMap.getResults()) {
-    // Get this dim's position in the permutation map.
-    auto dimExpr = dyn_cast<AffineDimExpr>(dim);
-    if (!dimExpr) {
-      llvm::report_fatal_error("permutation map is not a projected "
-                               "permutation.");
-    }
-
-    unsigned pos = dimExpr.getPosition();
-    assert(pos >= leadingUnitDims && "invalid permutation map");
-    pos -= leadingUnitDims;
+  AffineMap reducedMap = compressUnusedDims(permutationMap);
+  for (AffineExpr dimExpr : reducedMap.getResults()) {
+    int64_t pos = cast<AffineDimExpr>(dimExpr).getPosition();
     permutation.push_back(pos);
   }
+
   return permutation;
 }
 
@@ -208,8 +209,9 @@
     // lowering. When accessing memory, we use the memoryLayout, because that
     // is how the data is accessed in memory. The data is stored in the vector
     // according to vectorLayout.
+    llvm::SmallBitVector unusedDims;
     SmallVector<int64_t> permutation =
-        getReducedPermutation(xferOp.getPermutationMap());
+        getReducedPermutation(xferOp.getPermutationMap(), unusedDims);
     LayoutAttr memoryLayout =
         cast<LayoutAttr>(vectorLayout.permute(permutation));
 
@@ -219,8 +221,8 @@
     LayoutIterator iterator(vectorLayout, steps);
 
     iterator.apply([&](const LayoutIterator::State &state) {
-      SmallVector<Value> memoryIndices =
-          getMemoryIndices(state, memoryLayout, xferOp.getIndices(), rewriter);
+      SmallVector<Value> memoryIndices = getMemoryIndices(
+          state, memoryLayout, xferOp.getIndices(), unusedDims, rewriter);
       SmallVector<int64_t> accIndices = state.computeSIMTIndex();
       accumulator = accessUnit(xferOp, memoryIndices, accIndices, accumulator,
                                vectorLayout, memoryLayout, rewriter);
@@ -232,17 +234,22 @@
   SmallVector<Value> getMemoryIndices(const LayoutIterator::State &state,
                                       LayoutAttr memoryLayout,
                                       SmallVector<Value> indices,
+                                      llvm::SmallBitVector &projectedDims,
                                       RewriterBase &rewriter) const {
     SmallVector<Value> simdIndices =
         computeSIMDIndex(state, memoryLayout, laneId, rewriter);
     SmallVector<Value> memoryIndices(indices);
 
     // The memory layout has some projected leading dims that indices doesn't.
-    int leadingProjectedDims = memoryIndices.size() - simdIndices.size();
-    for (int i = leadingProjectedDims, e = memoryIndices.size(); i < e; ++i) {
+    int currSimd = 0;
+    for (int i = 0, e = memoryIndices.size(); i < e; ++i) {
+      if (projectedDims[i]) {
+        continue;
+      }
+
       memoryIndices[i] = rewriter.create<arith::AddIOp>(
-          rewriter.getUnknownLoc(), memoryIndices[i],
-          simdIndices[i - leadingProjectedDims]);
+          rewriter.getUnknownLoc(), memoryIndices[i], simdIndices[currSimd]);
+      ++currSimd;
     }
 
     return memoryIndices;
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
index ebeca29..e5ecf59 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir
@@ -196,49 +196,47 @@
 // TODO: Use affine min tricks based on the grid size to elide the mod.
 // Note that this IR is invalid if subgroup size != 8.
 
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 8 + 8)>
-
-// CHECK-LABEL: @distribute_transfer_write_row_major
 func.func @distribute_transfer_write_row_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %root, %alloc[%c0, %c0]
           {in_bounds = [true, true],
            "__vector_layout_test_anchor_operand_0" = #layout_row_major}
                   : vector<16x16xf16>, memref<64x64xf16>
-
-  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
-  // CHECK-DAG: %[[LANEID:.+]] = gpu.thread_id  x
-  // CHECK: %[[VEC_LANE_Y:.+]] = affine.apply #[[$MAP0]]()[%[[LANEID]]]
-  // CHECK: %[[DIST_SRC_VEC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<2x2x8xf16>
-  // CHECK: %[[BATCH_0_0:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 0] : vector<8xf16> from vector<2x2x8xf16>
-  // CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
-
-  // CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.apply #[[$MAP1]]()[%[[LANEID]]]
-  // CHECK: %[[BATCH_1_0:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 0] : vector<8xf16> from vector<2x2x8xf16>
-  // CHECK: vector.store %[[BATCH_1_0]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
-
-  // CHECK: %[[BATCH_0_1:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 1] : vector<8xf16> from vector<2x2x8xf16>
-  // CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
-
-  // CHECK: %[[BATCH_1_1:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 1] : vector<8xf16> from vector<2x2x8xf16>
-  // CHECK: vector.store %[[BATCH_1_1]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
   func.return
 }
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 mod 8 + 8)>
 
-// CHECK-LABEL: @distribute_transfer_write_col_major
+// CHECK-LABEL: @distribute_transfer_write_row_major
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[LANEID:.+]] = gpu.thread_id  x
+// CHECK: %[[VEC_LANE_Y:.+]] = affine.apply #[[$MAP0]]()[%[[LANEID]]]
+// CHECK: %[[DIST_SRC_VEC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<2x2x8xf16>
+// CHECK: %[[BATCH_0_0:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 0] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_0_0]], %{{.*}}[%[[VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
+
+// CHECK: %[[NEXT_VEC_LANE_Y:.+]] = affine.apply #[[$MAP1]]()[%[[LANEID]]]
+// CHECK: %[[BATCH_1_0:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 0] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_1_0]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C0]]] : memref<64x64xf16>, vector<8xf16>
+
+// CHECK: %[[BATCH_0_1:.+]] = vector.extract %[[DIST_SRC_VEC]][0, 1] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_0_1]], %{{.*}}[%[[VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
+
+// CHECK: %[[BATCH_1_1:.+]] = vector.extract %[[DIST_SRC_VEC]][1, 1] : vector<8xf16> from vector<2x2x8xf16>
+// CHECK: vector.store %[[BATCH_1_1]], %{{.*}}[%[[NEXT_VEC_LANE_Y]], %[[C8]]] : memref<64x64xf16>, vector<8xf16>
+
 func.func @distribute_transfer_write_col_major(%root: vector<16x16xf16>, %alloc: memref<64x64xf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %root, %alloc[%c0, %c0]
           {in_bounds = [true, true],
            "__vector_layout_test_anchor_operand_0" = #layout_col_major}
                   : vector<16x16xf16>, memref<64x64xf16>
-  // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
   func.return
 }
+// CHECK-LABEL: @distribute_transfer_write_col_major
+// CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
 
-// CHECK-LABEL: @distribute_transfer_write_row_major_with_broadcast
 func.func @distribute_transfer_write_row_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -246,11 +244,11 @@
            permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
            "__vector_layout_test_anchor_operand_0" = #layout_row_major}
                   : vector<16x16xf16>, memref<32x32x32x32xf16>
-  // CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
   func.return
 }
+// CHECK-LABEL: @distribute_transfer_write_row_major_with_broadcast
+// CHECK-COUNT-4: vector.store {{.*}}, vector<8xf16>
 
-// CHECK-LABEL: @distribute_transfer_write_col_major_with_broadcast
 func.func @distribute_transfer_write_col_major_with_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -258,11 +256,11 @@
            permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
            "__vector_layout_test_anchor_operand_0" = #layout_col_major}
                   : vector<16x16xf16>, memref<32x32x32x32xf16>
-  // CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
   func.return
 }
+// CHECK-LABEL: @distribute_transfer_write_col_major_with_broadcast
+// CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
 
-// CHECK-LABEL: @distribute_transfer_write_row_major_transpose
 func.func @distribute_transfer_write_row_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -270,11 +268,11 @@
            permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
            "__vector_layout_test_anchor_operand_0" = #layout_row_major}
                   : vector<16x16xf16>, memref<32x32x32x32xf16>
-  // CHECK-COUNT-32: vector.store {{.*}}, vector<1xf16>
   func.return
 }
+// CHECK-LABEL: @distribute_transfer_write_row_major_transpose
+// CHECK-COUNT-32: vector.store {{.*}}, vector<1xf16>
 
-// CHECK-LABEL: @distribute_transfer_write_col_major_transpose
 func.func @distribute_transfer_write_col_major_transpose(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %root, %alloc[%c0, %c0, %a, %b]
@@ -282,10 +280,25 @@
            permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
            "__vector_layout_test_anchor_operand_0" = #layout_col_major}
                   : vector<16x16xf16>, memref<32x32x32x32xf16>
-
-  // CHECK-COUNT-2: vector.store {{.*}}, vector<4xf16>
   func.return
 }
+// CHECK-LABEL: @distribute_transfer_write_col_major_transpose
+// CHECK-COUNT-2: vector.store {{.*}}, vector<4xf16>
+
+
+func.func @distribute_transfer_write_with_non_contiguous_broadcast(%root: vector<16x16xf16>, %a: index, %b: index, %alloc: memref<32x32x32x32xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %root, %alloc[%c0, %a, %c0, %b]
+          {in_bounds = [true, true],
+           permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+           "__vector_layout_test_anchor_operand_0" = #layout_row_major}
+                  : vector<16x16xf16>, memref<32x32x32x32xf16>
+  func.return
+}
+// CHECK-LABEL: func.func @distribute_transfer_write_with_non_contiguous_broadcast
+// CHECK-SAME: %[[ROOT:.+]]: vector<16x16xf16>, %[[A:.+]]: index, %[[B:.+]]: index, %[[ALLOC:.+]]: memref<32x32x32x32xf16>)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-COUNT-4: vector.store %{{.+}}, %[[ALLOC]][%[[C0]], {{.+}}, %[[C0]], %{{.+}}] : memref<32x32x32x32xf16>, vector<8xf16>
 
 builtin.module attributes { transform.with_named_sequence } {
   transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {