[Codegen][Common] Resolve `scf.forall` that are used for workgroup distribution (#18368)

The `scf.forall` ops used for workgroup distribution need to be resolved
after strategy lowering. Since this needs an update to the entry point
region, the `scf.forall` resolution is now added to
`ReconcileTranslationInfoPass`. If needed this could be moved out of
here.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
index e8e18f5..d7a81fc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
@@ -16,7 +16,10 @@
 
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 namespace mlir::iree_compiler {
 
 #define GEN_PASS_DEF_RECONCILETRANSLATIONINFOPASS
@@ -32,6 +35,270 @@
 };
 } // namespace
 
+//===---------------------------------------------------------------------===//
+// Resolve `scf.forall` operations
+//===---------------------------------------------------------------------===//
+
+/// Checks `forallOp` has a non-empty, valid workgroup mapping and returns it.
+static FailureOr<SmallVector<IREE::Codegen::WorkgroupMappingAttr>>
+verifyWorkgroupMappingAttrArray(scf::ForallOp forallOp) {
+  std::optional<ArrayAttr> mappingAttr = forallOp.getMapping();
+  if (!mappingAttr) {
+    return forallOp.emitOpError("expected mapped for all op");
+  }
+  if (mappingAttr.value().empty()) {
+    return forallOp.emitOpError("mapping attribute cannot be empty");
+  }
+  if (failed(IREE::Codegen::WorkgroupMappingAttr::verifyAttrList(
+          forallOp.getContext(), forallOp.getLoc(), mappingAttr->getValue()))) {
+    return failure();
+  }
+  SmallVector<IREE::Codegen::WorkgroupMappingAttr> workgroupMappingAttrs =
+      llvm::map_to_vector(mappingAttr.value(), [](Attribute attr) {
+        return cast<IREE::Codegen::WorkgroupMappingAttr>(attr);
+      });
+  return workgroupMappingAttrs;
+}
+
+/// Get the loop-to-process-dimension permutation for `mapping`: the inverse
+/// of the vector holding `mapping.size() - 1 - getMappingId()` per attribute.
+SmallVector<int64_t>
+getMappingPermutation(ArrayRef<IREE::Codegen::WorkgroupMappingAttr> mapping) {
+  return invertPermutationVector(llvm::map_to_vector(mapping, [&](auto a) {
+    return static_cast<int64_t>(mapping.size() - 1) - a.getMappingId();
+  }));
+}
+
+/// Return the procId and nprocs to use for each of the distributed loops,
+/// derived from `hal.interface.workgroup.id/count`s. All `z`-mapped loops
+/// share workgroup dimension 2 (delinearized if there is more than one).
+static FailureOr<
+    std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>>
+getProcIdsAndNprocs(
+    scf::ForallOp forallOp, RewriterBase &builder, Location loc,
+    SmallVector<IREE::Codegen::WorkgroupMappingAttr> workgroupMappings,
+    SmallVector<OpFoldResult> lowerBounds,
+    SmallVector<OpFoldResult> upperBounds, SmallVector<OpFoldResult> steps) {
+  if (workgroupMappings.size() != lowerBounds.size()) {
+    return forallOp.emitOpError(
+        "expected as many workgroup mapping attributes as number of loops");
+  }
+
+  auto permutation = getMappingPermutation(workgroupMappings);
+  applyPermutationToVector(workgroupMappings, permutation);
+  applyPermutationToVector(lowerBounds, permutation);
+  applyPermutationToVector(upperBounds, permutation);
+  applyPermutationToVector(steps, permutation);
+
+  SmallVector<OpFoldResult> procId(workgroupMappings.size(),
+                                   builder.getIndexAttr(0));
+  SmallVector<OpFoldResult> nprocs(workgroupMappings.size(),
+                                   builder.getIndexAttr(1));
+
+  AffineExpr s0, s1, s2;
+  bindSymbols(builder.getContext(), s0, s1, s2);
+  AffineExpr extentExpr = (s1 - s0).ceilDiv(s2);
+  IREE::Codegen::WorkgroupMappingAttr baseZDim =
+      IREE::Codegen::WorkgroupMappingAttr::get(builder.getContext(),
+                                               IREE::Codegen::WorkgroupId::IdZ);
+  SmallVector<OpFoldResult> loopExtents;
+  if (workgroupMappings.size() > baseZDim.getMappingId()) {
+    loopExtents.resize(workgroupMappings.size() - baseZDim.getMappingId());
+  }
+  for (int index = workgroupMappings.size() - 1; index >= 0; --index) {
+    auto workgroupMapping = workgroupMappings[index];
+    auto lowerBound = lowerBounds[index];
+    auto upperBound = upperBounds[index];
+    auto step = steps[index];
+    switch (workgroupMapping.getId()) {
+    case IREE::Codegen::WorkgroupId::IdX:
+      procId[index] =
+          builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 0).getResult();
+      nprocs[index] =
+          builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0)
+              .getResult();
+      break;
+    case IREE::Codegen::WorkgroupId::IdY:
+      procId[index] =
+          builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 1).getResult();
+      nprocs[index] =
+          builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1)
+              .getResult();
+      break;
+    case IREE::Codegen::WorkgroupId::IdZ: {
+      OpFoldResult extent = affine::makeComposedFoldedAffineApply(
+          builder, loc, extentExpr, {lowerBound, upperBound, step});
+      loopExtents[index] = extent;
+      break;
+    }
+    }
+  }
+
+  // Delinearize the z-dim based on the loop extents.
+  if (!loopExtents.empty()) {
+    Value zDimId =
+        builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, 2).getResult();
+    OpFoldResult zNprocs =
+        builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 2)
+            .getResult();
+
+    if (loopExtents.size() != 1) {
+      auto delinearizeOp = builder.create<affine::AffineDelinearizeIndexOp>(
+          loc, zDimId, loopExtents);
+      AffineMap minMap = AffineMap::get(0, 2, {s0, s1}, builder.getContext());
+      AffineExpr ceilDivExpr = s0.ceilDiv(s1);
+      // Assign processors from the last z-mapped loop to the first: each
+      // gets min(extent, remaining z count), and the remaining count is
+      // then ceil-divided by that amount.
+      for (int index = loopExtents.size() - 1; index >= 0; --index) {
+        auto extent = loopExtents[index];
+        procId[index] = delinearizeOp->getResult(index);
+        OpFoldResult currNprocs = affine::makeComposedFoldedAffineMin(
+            builder, loc, minMap, {extent, zNprocs});
+        nprocs[index] = currNprocs;
+        zNprocs = affine::makeComposedFoldedAffineApply(
+            builder, loc, ceilDivExpr, {zNprocs, currNprocs});
+      }
+    } else {
+      // If there is only one z-dim mapping, just use the ID directly.
+      procId[0] = zDimId;
+      nprocs[0] = zNprocs;
+    }
+  }
+
+  auto inversePermutation = invertPermutationVector(permutation);
+  applyPermutationToVector(procId, inversePermutation);
+  applyPermutationToVector(nprocs, inversePermutation);
+  return std::make_pair(procId, nprocs);
+}
+
+/// Resolve `forallOp` by inlining its body with workgroup IDs as loop indices.
+static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
+                                            scf::ForallOp forallOp) {
+  if (forallOp->getNumResults() != 0) {
+    return forallOp.emitOpError(
+        "cannot resolve for all ops with return values");
+  }
+  SmallVector<OpFoldResult> mixedLowerBound = forallOp.getMixedLowerBound();
+  SmallVector<OpFoldResult> mixedUpperBound = forallOp.getMixedUpperBound();
+  SmallVector<OpFoldResult> mixedStep = forallOp.getMixedStep();
+  FailureOr<SmallVector<IREE::Codegen::WorkgroupMappingAttr>> workgroupMapping =
+      verifyWorkgroupMappingAttrArray(forallOp);
+  if (failed(workgroupMapping)) {
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  SmallVector<OpFoldResult> procId;
+  {
+    FailureOr<std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>>
+        procInfo =
+            getProcIdsAndNprocs(forallOp, rewriter, forallOp.getLoc(),
+                                workgroupMapping.value(), mixedLowerBound,
+                                mixedUpperBound, mixedStep);
+    if (failed(procInfo)) {
+      return failure();
+    }
+    std::swap(procId, procInfo->first);
+  }
+  // Materialize the replacements before `splitBlock` moves `forallOp` (the
+  // insertion-point anchor set above) out of its current block.
+  auto argReplacements =
+      getValueOrCreateConstantIndexOp(rewriter, forallOp.getLoc(), procId);
+  /// For now this is assuming that number of workgroups is exactly equal to
+  /// the iterations for each loop dimension. Just inline the forall body into
+  /// the parent.
+  Block *parentBlock = forallOp->getBlock();
+  Block *remainingBlock =
+      rewriter.splitBlock(parentBlock, Block::iterator(forallOp));
+  Block *loopBody = forallOp.getBody();
+  rewriter.eraseOp(loopBody->getTerminator());
+  rewriter.mergeBlocks(loopBody, parentBlock, argReplacements);
+  rewriter.mergeBlocks(remainingBlock, parentBlock, ValueRange{});
+  rewriter.eraseOp(forallOp);
+  return success();
+}
+
+static LogicalResult resolveWorkgroupCount(RewriterBase &rewriter,
+                                           mlir::FunctionOpInterface funcOp,
+                                           scf::ForallOp forAllOp) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(forAllOp);
+  SmallVector<OpFoldResult> lowerBounds = forAllOp.getMixedLowerBound();
+  SmallVector<OpFoldResult> upperBounds = forAllOp.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = forAllOp.getMixedStep();
+  SmallVector<OpFoldResult> workgroupCount(lowerBounds.size());
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr countExpr = (s1 - s0).ceilDiv(s2); // symbols bind to {lb, ub, step} below
+  for (auto [index, lb, ub, step] :
+       llvm::enumerate(lowerBounds, upperBounds, steps)) {
+    workgroupCount[index] = affine::makeComposedFoldedAffineApply(
+        rewriter, forAllOp.getLoc(), countExpr, {lb, ub, step});
+  }
+  auto mappingAttr =
+      llvm::map_to_vector(forAllOp.getMapping().value(), [](auto a) {
+        return cast<IREE::Codegen::WorkgroupMappingAttr>(a);
+      });
+  auto permutation = getMappingPermutation(mappingAttr);
+  workgroupCount = applyPermutation(workgroupCount, permutation); // reorder the per-loop counts by the mapping permutation
+  return lowerWorkgroupCountFromSliceOp(rewriter, funcOp, workgroupCount);
+}
+
+/// Resolve the workgroup-mapped `scf.forall` in `funcOp`, if any: first the
+/// workgroup count derived from it, then the loop itself.
+static LogicalResult resolveWorkgroupForAll(RewriterBase &rewriter,
+                                            FunctionOpInterface funcOp) {
+  Region &body = funcOp.getFunctionBody();
+  if (body.empty()) {
+    return success();
+  }
+  if (!llvm::hasSingleElement(body)) {
+    return funcOp.emitOpError("unhandled function with multiple blocks");
+  }
+
+  auto forAllOps = body.getOps<scf::ForallOp>();
+  SmallVector<scf::ForallOp> workgroupForAllOps = llvm::to_vector(
+      llvm::make_filter_range(forAllOps, [&](scf::ForallOp forAllOp) {
+        auto mapping = forAllOp.getMapping();
+        if (!mapping) {
+          return false;
+        }
+        if (!llvm::all_of(mapping.value(), [](Attribute attr) {
+              return isa<IREE::Codegen::WorkgroupMappingAttr>(attr);
+            })) {
+          return false;
+        }
+        return true;
+      }));
+
+  if (workgroupForAllOps.empty()) {
+    // If there are no workgroup distribution loops, set the default
+    // number of workgroups to {1, 1, 1}. Note: that this only kicks
+    // in if the export op region has
+    // `flow.dispatch.workgroup_count_from_slice`.
+    return lowerWorkgroupCountFromSliceOp(rewriter, funcOp,
+                                          ArrayRef<OpFoldResult>{});
+  }
+  if (!llvm::hasSingleElement(workgroupForAllOps)) {
+    return funcOp.emitOpError("unhandled resolution of zero/multiple "
+                              "scf.forall ops withing the function");
+  }
+
+  // Use the filtered list; the first `scf.forall` may not be workgroup-mapped.
+  scf::ForallOp forallOp = workgroupForAllOps.front();
+  if (failed(resolveWorkgroupCount(rewriter, funcOp, forallOp))) {
+    return failure();
+  }
+  return resolveWorkgroupForAll(rewriter, forallOp);
+}
+
+//===---------------------------------------------------------------------===//
+// End Resolve `scf.forall` operations
+//===---------------------------------------------------------------------===//
+
 // Reconcile workgroup sizes across all translation infos.
 static FailureOr<SmallVector<int64_t>> reconcileWorkgroupSize(
     ArrayRef<IREE::Codegen::TranslationInfoAttr> translationInfos) {
@@ -89,25 +356,37 @@
     return signalPassFailure();
   }
   auto exportOp = *exportOps.begin();
-  Builder builder(&getContext());
+  IRRewriter rewriter(&getContext());
 
   SmallVector<IREE::Codegen::TranslationInfoAttr> translationInfos;
-  innerModuleOp->walk([&](FunctionOpInterface funcOp) {
-    auto translationInfo = getTranslationInfo(funcOp);
-    if (!translationInfo) {
-      return;
-    }
+  auto walkResult =
+      innerModuleOp->walk([&](FunctionOpInterface funcOp) -> WalkResult {
+        // Resolve workgroup distribution related `scf.forall` ops.
+        if (failed(resolveWorkgroupForAll(rewriter, funcOp))) {
+          return failure();
+        }
 
-    translationInfos.push_back(translationInfo);
-    // The following is moving the target-func-attrs specification from
-    // translation info into the func-like op. This is not the best
-    // place to do this, but the intent is after this pass all the
-    // lowering configs and translation infos will be deleted.
-    DictionaryAttr targetFuncAttrs = getTargetFuncAttrs(translationInfo);
-    if (targetFuncAttrs) {
-      funcOp->setAttr("llvm_func_attrs", targetFuncAttrs);
-    }
-  });
+        auto translationInfo = getTranslationInfo(funcOp);
+        if (!translationInfo) {
+          return WalkResult::advance();
+        }
+
+        translationInfos.push_back(translationInfo);
+        // The following is moving the target-func-attrs specification from
+        // translation info into the func-like op. This is not the best
+        // place to do this, but the intent is after this pass all the
+        // lowering configs and translation infos will be deleted.
+        DictionaryAttr targetFuncAttrs = getTargetFuncAttrs(translationInfo);
+        if (targetFuncAttrs) {
+          funcOp->setAttr("llvm_func_attrs", targetFuncAttrs);
+        }
+        return WalkResult::advance();
+      });
+  if (walkResult.wasInterrupted()) {
+    variantOp.emitOpError(
+        "failed in iree-codegen-reconcile-translation-info pass");
+    return signalPassFailure();
+  }
 
   // Reconcile workgroup sizes.
   FailureOr<SmallVector<int64_t>> reconciledWorkgroupSize =
@@ -125,7 +404,7 @@
   for (auto [index, size] : llvm::enumerate(reconciledWorkgroupSize.value())) {
     workgroupSize[index] = size;
   }
-  auto workgroupSizeArrayAttr = builder.getIndexArrayAttr(workgroupSize);
+  auto workgroupSizeArrayAttr = rewriter.getIndexArrayAttr(workgroupSize);
   exportOp.setWorkgroupSizeAttr(workgroupSizeArrayAttr);
 
   // Reconcile subgroup sizes.
@@ -137,7 +416,7 @@
   }
   if (reconciledSubgroupSize.value() != int64_t()) {
     exportOp.setSubgroupSizeAttr(
-        builder.getIndexAttr(reconciledSubgroupSize.value()));
+        rewriter.getIndexAttr(reconciledSubgroupSize.value()));
   }
 
   // Erase all the lowering configs and translation infos.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index 66cfa40..3f19b22 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/IR/StorageUniquerSupport.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir::iree_compiler {
@@ -279,8 +278,7 @@
   SmallVector<Attribute> deviceMappingAttribute =
       getMapping(context, tilingInfo->tileSizes);
   if (failed(IREE::Codegen::WorkgroupMappingAttr::verifyAttrList(
-          context, ::mlir::detail::getDefaultDiagnosticEmitFn(funcOp.getLoc()),
-          deviceMappingAttribute))) {
+          context, funcOp.getLoc(), deviceMappingAttribute))) {
     return signalPassFailure();
   }
   tilingOptions.setMapping(deviceMappingAttribute);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
index 18c1677..cc8aa9d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-reconcile-translation-info)))" %s --verify-diagnostics | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-reconcile-translation-info, canonicalize)))" %s --verify-diagnostics --allow-unregistered-dialect | FileCheck %s
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>
@@ -160,3 +160,272 @@
 // CHECK-LABEL: hal.executable private @llvm_func_attrs
 //       CHECK:   func.func @fn1() attributes {llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}}
 //       CHECK:   func.func @fn2() attributes {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
+    #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_2D {
+  hal.executable.variant public @scf_forall_2D target(#hal.executable.target<"", "", {}>) {
+    hal.executable.export public @scf_forall_2D layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @scf_forall_2D() {
+        %cst0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+        %cst1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+        %cst2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
+        %0 = flow.dispatch.workload.ordinal %cst0, 0 : index
+        %1 = flow.dispatch.workload.ordinal %cst1, 1 : index
+        %2 = flow.dispatch.workload.ordinal %cst2, 2 : index
+        scf.forall (%arg0, %arg1) = (0, 0) to (%0, %1) step(64, 32) {
+          "use"(%arg0, %arg1) : (index, index) -> ()
+          scf.forall.in_parallel {}
+        } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
+//      CHECK: hal.executable.export public @scf_forall_2D layout
+// CHECK-NEXT:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-DAG:   %[[WG_Y:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
+//  CHECK-DAG:   %[[WG_X:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]]]
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   hal.return %[[WG_X]], %[[WG_Y]], %[[C1]]
+//      CHECK: func @scf_forall_2D()
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-NOT:   scf.forall
+//      CHECK:   "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
+    #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_2D_dynamic_tile_size {
+  hal.executable.variant public @scf_forall_2D_dynamic_tile_size target(#hal.executable.target<"", "", {}>) {
+    hal.executable.export public @scf_forall_2D_dynamic_tile_size layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index, %arg4 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3, %arg4
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @scf_forall_2D_dynamic_tile_size() {
+        %cst0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+        %cst1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+        %cst2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
+        %cst3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
+        %0 = flow.dispatch.workload.ordinal %cst0, 0 : index
+        %1 = flow.dispatch.workload.ordinal %cst1, 1 : index
+        %2 = flow.dispatch.workload.ordinal %cst2, 2 : index
+        %3 = flow.dispatch.workload.ordinal %cst3, 3 : index
+        scf.forall (%arg0, %arg1) = (0, 0) to (%0, %1) step(%2, %3) {
+          "use"(%arg0, %arg1) : (index, index) -> ()
+          scf.forall.in_parallel {}
+        } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)
+//      CHECK: hal.executable.export public @scf_forall_2D_dynamic_tile_size layout
+// CHECK-NEXT:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-DAG:   %[[WG_Y:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG3]]]
+//  CHECK-DAG:   %[[WG_X:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG4]]]
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   hal.return %[[WG_X]], %[[WG_Y]], %[[C1]]
+//      CHECK: func @scf_forall_2D_dynamic_tile_size()
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-NOT:   scf.forall
+//      CHECK:   "use"(%[[WG_ID_Y]], %[[WG_ID_X]])
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 12, bindings = [
+    #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_4D {
+  hal.executable.variant public @scf_forall_4D target(#hal.executable.target<"", "", {}>) {
+    hal.executable.export public @scf_forall_4D layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index, %arg4 : index,
+         %arg5 : index, %arg6 : index, %arg7 : index, %arg8 : index,
+         %arg9 : index, %arg10 : index, %arg11 : index, %arg12 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3,
+          %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @scf_forall_4D() {
+        %cst0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+        %cst1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+        %cst2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
+        %cst3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
+        %cst4 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : index
+        %cst5 = hal.interface.constant.load layout(#pipeline_layout) ordinal(5) : index
+        %cst6 = hal.interface.constant.load layout(#pipeline_layout) ordinal(6) : index
+        %cst7 = hal.interface.constant.load layout(#pipeline_layout) ordinal(7) : index
+        %cst8 = hal.interface.constant.load layout(#pipeline_layout) ordinal(8) : index
+        %cst9 = hal.interface.constant.load layout(#pipeline_layout) ordinal(9) : index
+        %cst10 = hal.interface.constant.load layout(#pipeline_layout) ordinal(10) : index
+        %cst11 = hal.interface.constant.load layout(#pipeline_layout) ordinal(11) : index
+        %0 = flow.dispatch.workload.ordinal %cst0, 0 : index
+        %1 = flow.dispatch.workload.ordinal %cst1, 1 : index
+        %2 = flow.dispatch.workload.ordinal %cst2, 2 : index
+        %3 = flow.dispatch.workload.ordinal %cst3, 3 : index
+        %4 = flow.dispatch.workload.ordinal %cst4, 4 : index
+        %5 = flow.dispatch.workload.ordinal %cst5, 5 : index
+        %6 = flow.dispatch.workload.ordinal %cst6, 6 : index
+        %7 = flow.dispatch.workload.ordinal %cst7, 7 : index
+        %8 = flow.dispatch.workload.ordinal %cst8, 8 : index
+        %9 = flow.dispatch.workload.ordinal %cst9, 9 : index
+        %10 = flow.dispatch.workload.ordinal %cst10, 10 : index
+        %11 = flow.dispatch.workload.ordinal %cst11, 11 : index
+        scf.forall (%arg0, %arg1, %arg2, %arg3) = (%0, %1, %2, %3) to (%4, %5, %6, %7) step(%8, %9, %10, %11) {
+          "use"(%arg0, %arg1, %arg2, %arg3) : (index, index, index, index) -> ()
+          scf.forall.in_parallel {}
+        } {mapping = [#iree_codegen.workgroup_mapping<z:1>,
+                      #iree_codegen.workgroup_mapping<z:0>,
+                      #iree_codegen.workgroup_mapping<y>,
+                      #iree_codegen.workgroup_mapping<x>]}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s0 + s1) ceildiv s2) * ((-s3 + s4) ceildiv s5))>
+//      CHECK: hal.executable.export public @scf_forall_4D layout
+// CHECK-NEXT:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG11:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG12:[a-zA-Z0-9]+]]: index
+//  CHECK-DAG:   %[[WG_Y:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[ARG7]], %[[ARG11]]]
+//  CHECK-DAG:   %[[WG_X:.+]] = affine.apply #[[MAP0]]()[%[[ARG4]], %[[ARG8]], %[[ARG12]]]
+//  CHECK-DAG:   %[[WG_Z:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]], %[[ARG10]], %[[ARG1]], %[[ARG5]], %[[ARG9]]]
+//      CHECK:   hal.return %[[WG_X]], %[[WG_Y]], %[[WG_Z]]
+//      CHECK: func @scf_forall_4D()
+//  CHECK-DAG:   %[[LB0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(0)
+//  CHECK-DAG:   %[[LB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(1)
+//  CHECK-DAG:   %[[UB0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(4)
+//  CHECK-DAG:   %[[UB1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(5)
+//  CHECK-DAG:   %[[STEP0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(8)
+//  CHECK-DAG:   %[[STEP1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(9)
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[NITERS1:.+]] = affine.apply #[[MAP0]]()[%[[LB1]], %[[UB1]], %[[STEP1]]]
+//  CHECK-DAG:   %[[NITERS0:.+]] = affine.apply #[[MAP0]]()[%[[LB0]], %[[UB0]], %[[STEP0]]]
+//  CHECK-DAG:   %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
+//  CHECK-NOT:   scf.forall
+//      CHECK:   %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[WG_ID_Z]] into (%[[NITERS0]], %[[NITERS1]])
+//      CHECK:   "use"(%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[WG_ID_Y]], %[[WG_ID_X]])
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+    #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @scf_forall_4D_static_interchange {
+  hal.executable.variant public @scf_forall_4D_static_interchange target(#hal.executable.target<"", "", {}>) {
+    hal.executable.export public @scf_forall_4D_static_interchange layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @scf_forall_4D_static_interchange() {
+        scf.forall (%arg0, %arg1, %arg2, %arg3, %arg4) = (0, 1, 2, 3, 4) to (4, 10, 19, 29, 44) step(1, 2, 3, 4, 5) {
+          "use"(%arg0, %arg1, %arg2, %arg3, %arg4) : (index, index, index, index, index) -> ()
+          scf.forall.in_parallel {}
+        } {mapping = [#iree_codegen.workgroup_mapping<z:0>,
+                      #iree_codegen.workgroup_mapping<z:2>,
+                      #iree_codegen.workgroup_mapping<x>,
+                      #iree_codegen.workgroup_mapping<y>,
+                      #iree_codegen.workgroup_mapping<z:1>]}
+        return
+      }
+    }
+  }
+}
+//      CHECK: hal.executable.export public @scf_forall_4D_static_interchange layout
+//  CHECK-DAG:   %[[C6:.+]] = arith.constant 6 : index
+//  CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
+//  CHECK-DAG:   %[[C160:.+]] = arith.constant 160 : index
+//      CHECK:   hal.return %[[C6]], %[[C7]], %[[C160]]
+//      CHECK: func @scf_forall_4D_static_interchange()
+//  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//  CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
+//  CHECK-NOT:   scf.forall
+//      CHECK:   %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (%[[C5]], %[[C8]], %[[C4]])
+//      CHECK:   "use"(%[[DELINEARIZE]]#2, %[[DELINEARIZE]]#0, %[[WG_ID_X]], %[[WG_ID_Y]], %[[DELINEARIZE]]#1)
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+    #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @no_loop_do_nothing {
+  hal.executable.variant public @no_loop_do_nothing target(#hal.executable.target<"", "", {}>) {
+    hal.executable.export public @no_loop_do_nothing layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %c1 = arith.constant 1: index
+      %c2 = arith.constant 2: index
+      hal.return %c1, %c2, %c1 : index, index, index
+    }
+    builtin.module {
+      func.func @no_loop_do_nothing() {
+        return
+      }
+    }
+  }
+}
+//      CHECK: hal.executable.export public @no_loop_do_nothing layout
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   hal.return %[[C1]], %[[C2]], %[[C1]]
+//      CHECK: func @no_loop_do_nothing()
+// CHECK-NEXT:   return
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 0, bindings = [
+    #hal.pipeline.binding<storage_buffer>]>
+hal.executable private @no_loop_default_workgroup_count {
+  hal.executable.variant public @no_loop_default_workgroup_count target(#hal.executable.target<"", "", {}>) {
+    hal.executable.export public @no_loop_default_workgroup_count layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %0:3 = flow.dispatch.workgroup_count_from_slice %arg1, %arg2
+      hal.return %0#1, %0#2, %0#0 : index, index, index
+    }
+    builtin.module {
+      func.func @no_loop_default_workgroup_count() {
+        return
+      }
+    }
+  }
+}
+//      CHECK: hal.executable.export public @no_loop_default_workgroup_count layout
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   hal.return %[[C1]], %[[C1]], %[[C1]]
+//      CHECK: func @no_loop_default_workgroup_count()
+// CHECK-NEXT:   return
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index 5f4be34..3ba545e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StorageUniquerSupport.h"
 
 #define GET_ATTRDEF_CLASSES
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp.inc"
@@ -395,14 +396,15 @@
   return getMappingId() < rhs.getMappingId();
 }
 
-LogicalResult WorkgroupMappingAttr::verifyAttrList(
-    MLIRContext *context, function_ref<InFlightDiagnostic()> emitError,
-    ArrayRef<Attribute> attrs) {
+LogicalResult WorkgroupMappingAttr::verifyAttrList(MLIRContext *context,
+                                                   Location loc,
+                                                   ArrayRef<Attribute> attrs) {
   if (attrs.empty()) {
     return success();
   }
   SmallVector<IREE::Codegen::WorkgroupMappingAttr> mappingAttrs;
   llvm::SmallDenseSet<IREE::Codegen::WorkgroupMappingAttr, 4> attrSet;
+  auto emitError = mlir::detail::getDefaultDiagnosticEmitFn(loc);
   for (auto attr : attrs) {
     auto typedAttr =
         ::mlir::dyn_cast_or_null<IREE::Codegen::WorkgroupMappingAttr>(attr);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 6190b6c..48c3917 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -188,8 +188,7 @@
   let extraClassDeclaration = [{
     // Checks that a list of attributes is well-defined.
     static LogicalResult verifyAttrList(::mlir::MLIRContext *context,
-      ::llvm::function_ref<::mlir::InFlightDiagnostic ()> emitError,
-      ArrayRef<Attribute> attrs);
+      Location loc, ArrayRef<Attribute> attrs);
 
     // Convert from mapping ID to attribute.
     static ::mlir::iree_compiler::IREE::Codegen::WorkgroupMappingAttr
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f1c7ac1..c95c484 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1137,10 +1137,12 @@
 
 void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager,
                                      bool useROCM) {
-  OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
-  modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
-  FunctionLikeNest(modulePassManager)
-      .addPass(createLLVMGPULowerExecutableTargetPass);
+  {
+    OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
+    modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
+    FunctionLikeNest(modulePassManager)
+        .addPass(createLLVMGPULowerExecutableTargetPass);
+  }
   variantPassManager.addPass(createReconcileTranslationInfoPass());
   //===--------------------------------------------------------------------===//
   // Convert Linalg ops to LLVM+NVVM/ROCDL ops.
@@ -1149,7 +1151,7 @@
   //   - All Linalg/Loops/GPU/Affine/Standard ops are converted away.
   //   - The module contains the final llvm.module ready to be serialized.
   //===--------------------------------------------------------------------===//
-  addLowerToLLVMGPUPasses(modulePassManager, useROCM);
+  addLowerToLLVMGPUPasses(variantPassManager.nest<ModuleOp>(), useROCM);
 
   LLVM_DEBUG({
     llvm::dbgs() << "Using LLVMGPU pass pipeline:\n";
@@ -1181,12 +1183,16 @@
 }
 
 void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
-  OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
-  modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
-  FunctionLikeNest(modulePassManager)
-      .addPass(createROCDLLowerExecutableTargetPass);
+  {
+    OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
+    modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass());
+    FunctionLikeNest(modulePassManager)
+        .addPass(createROCDLLowerExecutableTargetPass);
+  }
   variantPassManager.addPass(createReconcileTranslationInfoPass());
-  addLowerToLLVMGPUPasses(modulePassManager, /*forROCDL=*/true);
+
+  addLowerToLLVMGPUPasses(variantPassManager.nest<ModuleOp>(),
+                          /*forROCDL=*/true);
 
   LLVM_DEBUG({
     llvm::dbgs() << "Using ROCDL pass pipeline:\n";
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 74763d6..89aed57 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -454,12 +454,11 @@
   std::optional<IREE::HAL::ExecutableExportOp> exportOp =
       getEntryPoint(entryPointFn);
   if (!exportOp) {
-    return entryPointFn.emitOpError(
-        "expected function to be entry point function");
+    return success();
   }
   Block *body = exportOp->getWorkgroupCountBody();
   if (!body) {
-    return exportOp->emitOpError("unexpected empty workgroup count region");
+    return success();
   }
   auto countOps = body->getOps<IREE::Flow::DispatchWorkgroupCountFromSliceOp>();
   if (countOps.empty()) {