[Common][TransformDialect] Fix the usage of num_threads for TileToForallAndWorkgroupCountRegionOp (#12438)

Prior to this patch, it was impossible to use
`TileToForallAndWorkgroupCountRegionOp` without specifying tile sizes:
the lowering of the workgroup_count op would simply be skipped, leading
to compiler errors down the line.

This patch fixes that by teaching `lowerWorkgroupCountComputingRegion`
(the function that lowers workgroup_count ops) how to derive the
workgroup counts from `num_threads`: when tiling by `num_threads`, the
count along each tiled dimension is the thread count itself rather
than `ceilDiv(workload, tile_size)`.
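
For illustration, a minimal sketch of the now-working usage (this
mirrors the test added below; the surrounding hal.executable dispatch
boilerplate is assumed):

  transform.structured.canonicalized_sequence failures(propagate) {
  ^bb1(%variant_op: !pdl.operation):
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op
      : (!pdl.operation) -> !pdl.operation
    // With tile_sizes, each workgroup count is ceilDiv(workload, tile_size);
    // with num_threads, the count is the value itself, so this export's
    // workgroup_count region lowers to `hal.return %c32, %c1, %c1`.
    %forall, %tiled =
      transform.iree.tile_to_forall_and_workgroup_count_region %matmul
        num_threads [32]
        ( mapping = [#gpu.block<x>] )
  }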
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 65e2bcc..b28cb17 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -741,8 +741,8 @@
 /// pdl::OperationType handles on the fly.
 static LogicalResult lowerWorkgroupCountComputingRegion(
     transform::TransformState &state, RewriterBase &rewriter, Location loc,
-    HAL::ExecutableExportOp exportOp, ArrayRef<OpFoldResult> tileSizes,
-    Optional<ArrayAttr> mapping) {
+    HAL::ExecutableExportOp exportOp, ArrayRef<OpFoldResult> numThreads,
+    ArrayRef<OpFoldResult> tileSizes, Optional<ArrayAttr> mapping) {
   Region &r = exportOp.getWorkgroupCount();
   if (!r.hasOneBlock()) {
     return rewriter.notifyMatchFailure(exportOp,
@@ -760,31 +760,37 @@
   auto workgroupCountOp = *workgroupCountOps.begin();
   auto workload = workgroupCountOp.getOperands();
 
-  SmallVector<OpFoldResult> unpackedTileSizes;
+  bool useNumThreads = !numThreads.empty();
+  ArrayRef<OpFoldResult> tileSizesOrNumThreads =
+      useNumThreads ? numThreads : tileSizes;
+  StringRef kindStr = useNumThreads ? "num thread" : "tile size";
+
+  SmallVector<OpFoldResult> unpackedTileSizesOrNumThreads;
   int64_t numTiledDims = 0;
-  for (auto ofr : tileSizes) {
+  for (auto ofr : tileSizesOrNumThreads) {
     if (ofr.is<Value>() &&
         ofr.get<Value>().getType().isa<pdl::OperationType>()) {
       for (Operation *sizeProducer : state.getPayloadOps(ofr.get<Value>())) {
         if (sizeProducer->getNumResults() != 1) {
-          auto diag =
-              mlir::emitDefiniteFailure(sizeProducer)
-              << "the operation producing tile size must have one result";
+          auto diag = mlir::emitDefiniteFailure(sizeProducer)
+                      << "the operation producing " << kindStr
+                      << " must have one result";
           diag.attachNote(loc) << "when applying this transform";
           return diag;
         }
-        unpackedTileSizes.push_back(sizeProducer->getResult(0));
+        unpackedTileSizesOrNumThreads.push_back(sizeProducer->getResult(0));
       }
     } else {
-      unpackedTileSizes.push_back(ofr);
+      unpackedTileSizesOrNumThreads.push_back(ofr);
     }
-    if (!isConstantIntValue(unpackedTileSizes.back(), 0)) ++numTiledDims;
+    if (!isConstantIntValue(unpackedTileSizesOrNumThreads.back(), 0))
+      ++numTiledDims;
   }
 
-  if (unpackedTileSizes.size() > workload.size()) {
+  if (unpackedTileSizesOrNumThreads.size() > workload.size()) {
     return rewriter.notifyMatchFailure(
         exportOp,
-        "number of tile sizes overflow the dimension from the workload");
+        "number of " + kindStr + "s overflow the dimension from the workload");
   }
 
   // Generate permutation of tiled dims based on the specified mapping.
@@ -793,7 +799,8 @@
     if (numTiledDims != mapping->size()) {
       return rewriter.notifyMatchFailure(exportOp,
                                          "number of mapping elements must "
-                                         "match number of non-zero tile sizes");
+                                         "match number of non-zero " +
+                                             kindStr + "s");
     }
     for (DeviceMappingAttrInterface map : mapping.value())
       mappingPermutation.push_back(map.getMappingId());
@@ -810,15 +817,20 @@
   int64_t nextTiledDim = 0;
   for (int64_t workgroupsDim : mappingPermutation) {
     // Skip dims with tile size 0. These are not tiled.
-    while (isConstantIntValue(unpackedTileSizes[nextTiledDim], 0))
+    while (isConstantIntValue(unpackedTileSizesOrNumThreads[nextTiledDim], 0))
       ++nextTiledDim;
-    AffineExpr s0, s1;
-    bindSymbols(rewriter.getContext(), s0, s1);
-    auto m = AffineMap::get(0, 2, s0.ceilDiv(s1));
-    workgroupCount[workgroupsDim] = makeComposedFoldedAffineApply(
-        rewriter, loc, m,
-        ArrayRef<OpFoldResult>{workload[nextTiledDim],
-                               unpackedTileSizes[nextTiledDim]});
+    if (useNumThreads) {
+      workgroupCount[workgroupsDim] =
+          unpackedTileSizesOrNumThreads[nextTiledDim];
+    } else {
+      AffineExpr s0, s1;
+      bindSymbols(rewriter.getContext(), s0, s1);
+      auto m = AffineMap::get(0, 2, s0.ceilDiv(s1));
+      workgroupCount[workgroupsDim] = makeComposedFoldedAffineApply(
+          rewriter, loc, m,
+          ArrayRef<OpFoldResult>{workload[nextTiledDim],
+                                 unpackedTileSizesOrNumThreads[nextTiledDim]});
+    }
     ++nextTiledDim;
   }
 
@@ -882,8 +894,8 @@
   /// regions are created by default in IREEs compilation flow.
   IRRewriter rewriter(getContext());
   if (failed(lowerWorkgroupCountComputingRegion(
-          state, rewriter, getLoc(), exportOp.value(), getMixedTileSizes(),
-          getMapping()))) {
+          state, rewriter, getLoc(), exportOp.value(), getMixedNumThreads(),
+          getMixedTileSizes(), getMapping()))) {
     return mlir::emitDefiniteFailure(exportOp.value(),
                                      "failed to lower workgroup count region");
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
index b52c7ec..19d07e9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD
@@ -46,6 +46,7 @@
             "synchronize_symbol_visibility.mlir",
             "test_config_mmt4d.mlir",
             "transform_dialect_bufferize.mlir",
+            "transform_dialect_iree_tile_to_forall.mlir",
             "transpose_avx2_lowering.mlir",
             "triple_tiling_expert_pipeline.mlir",
             "unfused_fma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index c8dfb63..29b3167 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -41,6 +41,7 @@
     "synchronize_symbol_visibility.mlir"
     "test_config_mmt4d.mlir"
     "transform_dialect_bufferize.mlir"
+    "transform_dialect_iree_tile_to_forall.mlir"
     "transpose_avx2_lowering.mlir"
     "triple_tiling_expert_pipeline.mlir"
     "unfused_fma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir
new file mode 100644
index 0000000..94e984e
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir
@@ -0,0 +1,59 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
+
+// Check that we can specify `num_threads` when lowering
+// `workgroup_count_from_dag_root` using
+// `transform.iree.tile_to_forall_and_workgroup_count_region`
+
+
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
+
+// Check that num_threads (32) is reflected in the map.
+// CHECK: #[[$NUM_THREADS_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+
+hal.executable private @matmul_static_dispatch_0 {
+  hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+
+    hal.executable.export public @matmul_static_dispatch_0_matmul_1024x4096x12345 ordinal(0) layout(#pipeline_layout) {
+    // Check that num_threads is reflected in the workgroup count.
+    // CHECK-LABEL: hal.executable.export public @matmul_static_dispatch_0_matmul_1024x4096x12345
+    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK: hal.return %[[C32]], %[[C1]], %[[C1]] : index, index, index
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
+      hal.return %x, %y, %z : index, index, index
+    }
+
+    builtin.module {
+      func.func @matmul_static_dispatch_0_matmul_1024x4096x12345() {
+        // Check that the tiling matches num_threads.
+        // CHECK-LABEL: func.func @matmul_static_dispatch_0_matmul_1024x4096x12345
+        // CHECK: = scf.forall (%[[IV:.*]]) in (32) shared_outs(%{{.*}}) -> (tensor<1024x4096xf32>) {
+        // CHECK: %[[OFFSET:.*]] = affine.apply #[[$NUM_THREADS_MAP]](%[[IV]])
+        // CHECK: %extracted_slice = tensor.extract_slice %{{.*}}[%[[OFFSET]], 0] [32, 12345] [1, 1] : tensor<1024x12345xf32> to tensor<32x12345xf32>
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1024x12345xf32>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<12345x4096xf32>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<1024x4096xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 12345], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1024x12345xf32>> -> tensor<1024x12345xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [12345, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<12345x4096xf32>> -> tensor<12345x4096xf32>
+        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1024x4096xf32>> -> tensor<1024x4096xf32>
+        %6 = linalg.matmul ins(%3, %4 : tensor<1024x12345xf32>, tensor<12345x4096xf32>) outs(%5 : tensor<1024x4096xf32>) -> tensor<1024x4096xf32>
+        flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : tensor<1024x4096xf32> -> !flow.dispatch.tensor<readwrite:tensor<1024x4096xf32>>
+        return
+      }
+    }
+  }
+}
+
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+  %original_matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op
+    : (!pdl.operation) -> !pdl.operation
+
+  %forall, %matmul =
+    transform.iree.tile_to_forall_and_workgroup_count_region %original_matmul
+      num_threads [32]
+      ( mapping = [#gpu.block<x>] )
+}