Fix bug in tile_to_foreach_thread mapping computation (#11312)

The permutation of tile sizes based on `mapping` was computed
incorrectly.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 174b330..2ecd1a2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -560,6 +560,7 @@
   auto workload = workgroupCountOp.getOperands();
 
   SmallVector<OpFoldResult> unpackedTileSizes;
+  int64_t numTiledDims = 0;
   for (auto ofr : tileSizes) {
     if (ofr.is<Value>() &&
         ofr.get<Value>().getType().isa<pdl::OperationType>()) {
@@ -576,6 +577,7 @@
     } else {
       unpackedTileSizes.push_back(ofr);
     }
+    if (!isConstantIntValue(unpackedTileSizes.back(), 0)) ++numTiledDims;
   }
 
   if (unpackedTileSizes.size() > workload.size()) {
@@ -584,33 +586,43 @@
         "number of tile sizes overflow the dimension from the workload");
   }
 
-  SmallVector<OpFoldResult> workgroupCount, permutedWorkgroupCount;
+  // Generate permutation of tiled dims based on the specified mapping.
+  SmallVector<int64_t> mappingPermutation;
+  if (mapping.has_value()) {
+    if (numTiledDims != mapping->size()) {
+      return rewriter.notifyMatchFailure(exportOp,
+                                         "number of mapping elements must "
+                                         "match number of non-zero tile sizes");
+    }
+    for (DeviceMappingAttrInterface map : mapping.value())
+      mappingPermutation.push_back(map.getMappingId());
+  } else {
+    // No mapping specified: No permutation.
+    for (int64_t i = 0; i < numTiledDims; ++i) mappingPermutation.push_back(i);
+  }
+
+  // Compute number of workgroups.
+  SmallVector<OpFoldResult> workgroupCount(3, rewriter.getIndexAttr(1));
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(workgroupCountOp);
   loc = workgroupCountOp.getLoc();
-  for (auto tileSize : llvm::enumerate(unpackedTileSizes)) {
-    if (isConstantIntValue(tileSize.value(), 0)) {
-      workgroupCount.push_back(workload[tileSize.index()]);
-      continue;
-    }
+  int64_t nextTiledDim = 0;
+  for (int64_t workgroupsDim : mappingPermutation) {
+    // Skip dims with tile size 0. These are not tiled.
+    while (isConstantIntValue(unpackedTileSizes[nextTiledDim], 0))
+      ++nextTiledDim;
     AffineExpr s0, s1;
     bindSymbols(rewriter.getContext(), s0, s1);
     auto m = AffineMap::get(0, 2, s0.ceilDiv(s1));
-    OpFoldResult count = makeComposedFoldedAffineApply(
+    workgroupCount[workgroupsDim] = makeComposedFoldedAffineApply(
         rewriter, loc, m,
-        ArrayRef<OpFoldResult>{workload[tileSize.index()], tileSize.value()});
-    workgroupCount.push_back(count);
+        ArrayRef<OpFoldResult>{workload[nextTiledDim],
+                               unpackedTileSizes[nextTiledDim]});
+    ++nextTiledDim;
   }
-  // Make sure to fill unused dimensions with 1
-  workgroupCount.resize(3, rewriter.getIndexAttr(1));
-  permutedWorkgroupCount.resize(3, rewriter.getIndexAttr(1));
-  int mappingId = 0;
-  for (DeviceMappingAttrInterface map : mapping->getValue()) {
-    permutedWorkgroupCount[map.getMappingId()] = workgroupCount[mappingId++];
-  }
-  rewriter.replaceOp(
-      workgroupCountOp,
-      getValueOrCreateConstantIndexOp(rewriter, loc, permutedWorkgroupCount));
+
+  rewriter.replaceOp(workgroupCountOp, getValueOrCreateConstantIndexOp(
+                                           rewriter, loc, workgroupCount));
   return success();
 }
 
diff --git a/tests/transform_dialect/cuda/vecadd2d.mlir b/tests/transform_dialect/cuda/vecadd2d.mlir
index 9ddd270..41b7c9a 100644
--- a/tests/transform_dialect/cuda/vecadd2d.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d.mlir
@@ -40,6 +40,15 @@
 // RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec.mlir | \
 // RUN: FileCheck %s --check-prefix=CHECK
 
+// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+// RUN:     --iree-abi-transformation-pipeline \
+// RUN:     --iree-flow-transformation-pipeline  \
+// RUN:     --iree-stream-transformation-pipeline \
+// RUN:     --iree-hal-configuration-pipeline | \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec_partial_tile.mlir | \
+// RUN: FileCheck %s --check-prefix=CHECK-PARTIAL-TILE
+
 // RUN: iree-compile %s --iree-hal-target-backends=cuda \
 // RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/vecadd2d_codegen_spec.mlir | \
 // RUN: iree-run-module --entry_function=vecadd2d --device=cuda |\
@@ -56,6 +65,11 @@
 //     CHECK:  %[[BLKX:.*]] = hal.interface.workgroup.id[0] : index
 //     CHECK:  memref.subview %0[%[[BLKZ:.*]], %[[BLKX:.*]]]
 
+//     CHECK-PARTIAL-TILE:  hal.executable.export 
+//     CHECK-PARTIAL-TILE:  bb0(%[[DEV:.*]]: !hal.device, %[[A1:.*]]: index, %[[A2:.*]]: index):
+//     CHECK-PARTIAL-TILE:  %[[c1:.*]] = arith.constant 1 : index
+//     CHECK-PARTIAL-TILE:  %[[dim:.*]] = affine.apply #map()[%[[A2]]]
+//     CHECK-PARTIAL-TILE:  hal.return %[[c1]], %[[c1]], %[[dim]] : index, index, index
 
 //      EXEC: EXEC @vecadd2d
 //      EXEC: result[0]: hal.buffer_view
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
new file mode 100644
index 0000000..75e2dad
--- /dev/null
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
@@ -0,0 +1,7 @@
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+  %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op
+  // Tile only one dimension, skip the other one.
+  transform.iree.tile_to_foreach_thread_and_workgroup_count_region %generics 
+                  tile_sizes [0, 3] ( mapping = [#gpu.block<z>])
+}