[Codegen][PCF] Add foldForallIntoPCFLoop for split-k loop handling (#23452)

Adds a generic `foldForallIntoPCFLoop` function that folds an
`scf.forall` containing a `pcf.loop` into a single `pcf.generic`
operation. This enables incorporating user expressed extra levels of
parallelism into a scope. The immediate use case this enables is split-k
by folding the split-k loop into the workgroup loop.

Key changes:
- Add `foldForallIntoPCFLoop` API in PCF/Transforms/Transforms.h
- Implement structural matching helpers (matchFoldTerminator,
matchFoldPCFLoop, matchFoldWriteSlices) that validate requirements
without scope-specific logic
- Add TestFoldForallIntoPCFLoopPass for unit testing with local_mapping
+ sequential scope
- Add FoldSplitKWorkgroupLoop pattern in ConvertWorkgroupForallToPCF.cpp
that matches split_reduction_mapping + workgroup_scope and calls the
generic fold
- Run fold pattern as second pass in ConvertWorkgroupForallToPCFPass

The fold operation:
1. Creates pcf.generic with same scope as inner pcf.loop
2. Linearizes forall iteration space and delinearizes inside generic
3. Converts pcf.loop to a nested scf.forall loop to handle spillover
4. Composes tensor.parallel_insert_slice with pcf.write_slice ops

The reason this pattern has to generate a pcf.generic instead of
pcf.loop is to avoid extra execution of code that was inside
the scf.forall but not inside the pcf.loop.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
index 80b3830..651109e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
@@ -9,6 +9,11 @@
 #include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
 #include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
@@ -25,6 +30,17 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Folds an scf.forall with split-reduction mapping containing a pcf.loop
+/// with workgroup scope into a single pcf.generic. This handles the "split-k"
+/// pattern where the outer forall represents additional parallelism from
+/// K-dimension splitting and the inner pcf.loop represents the original
+/// workgroup-level iteration.
+struct FoldSplitKWorkgroupLoop : OpRewritePattern<scf::ForallOp> {
+  using Base::Base;
+  LogicalResult matchAndRewrite(scf::ForallOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 struct ConvertWorkgroupForallToPCFPass final
     : impl::ConvertWorkgroupForallToPCFPassBase<
           ConvertWorkgroupForallToPCFPass> {
@@ -71,10 +87,111 @@
   return success();
 }
 
+LogicalResult
+FoldSplitKWorkgroupLoop::matchAndRewrite(scf::ForallOp op,
+                                         PatternRewriter &rewriter) const {
+  // Match scf.forall with split-reduction mapping.
+  if (!forallOpHasMappingType<IREE::LinalgExt::SplitReductionMappingAttr>(op)) {
+    return failure();
+  }
+
+  // Find pcf.loop with workgroup scope in the forall body.
+  IREE::PCF::LoopOp loopOp = nullptr;
+  for (Operation &bodyOp : op.getBody()->without_terminator()) {
+    if (auto loop = dyn_cast<IREE::PCF::LoopOp>(&bodyOp)) {
+      if (isa<IREE::Codegen::WorkgroupScopeAttr>(loop.getScope())) {
+        loopOp = loop;
+        break;
+      }
+    }
+  }
+  if (!loopOp) {
+    return failure();
+  }
+
+  // Capture values needed for workgroup count computation before folding.
+  // The fold erases the forall op, so any mixed bounds/steps must be
+  // materialized up front.
+  Location loc = op.getLoc();
+  SmallVector<OpFoldResult> lowerBounds = op.getMixedLowerBound();
+  SmallVector<OpFoldResult> upperBounds = op.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = op.getMixedStep();
+  SmallVector<Value> loopCounts(loopOp.getCount());
+
+  // Fold forall + pcf.loop into pcf.generic.
+  FailureOr<IREE::PCF::GenericOp> result =
+      IREE::PCF::foldForallIntoPCFLoop(rewriter, op);
+  if (failed(result)) {
+    return failure();
+  }
+
+  // Compute total workgroup count after folding (forall iterations * loop
+  // count). Generate IR before the pcf.generic.
+  rewriter.setInsertionPoint(*result);
+
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr numItersExpr = (s0 - s1).ceilDiv(s2);
+
+  Value forallCount = nullptr;
+  for (int64_t i = 0, e = upperBounds.size(); i < e; ++i) {
+    OpFoldResult lb = i < (int64_t)lowerBounds.size()
+                          ? lowerBounds[i]
+                          : rewriter.getIndexAttr(0);
+    OpFoldResult ub = upperBounds[i];
+    OpFoldResult step =
+        i < (int64_t)steps.size() ? steps[i] : rewriter.getIndexAttr(1);
+
+    Value iterCount = getValueOrCreateConstantIndexOp(
+        rewriter, loc,
+        affine::makeComposedFoldedAffineApply(rewriter, loc, numItersExpr,
+                                              {ub, lb, step}));
+    if (!forallCount) {
+      forallCount = iterCount;
+    } else {
+      forallCount =
+          arith::MulIOp::create(rewriter, loc, forallCount, iterCount);
+    }
+  }
+
+  Value totalLoopCount = nullptr;
+  for (Value count : loopCounts) {
+    if (!totalLoopCount) {
+      totalLoopCount = count;
+    } else {
+      totalLoopCount =
+          arith::MulIOp::create(rewriter, loc, totalLoopCount, count);
+    }
+  }
+
+  Value totalCount =
+      arith::MulIOp::create(rewriter, loc, forallCount, totalLoopCount);
+
+  // Create workgroup count hint for the folded generic.
+  SmallVector<OpFoldResult> counts = {OpFoldResult(totalCount)};
+  [[maybe_unused]] LogicalResult hintRes = createWorkgroupCountHint(
+      rewriter, loc, counts, /*maxWorkgroupParallelDims=*/1,
+      /*reverse=*/false);
+  assert(succeeded(hintRes) &&
+         "Unexpected failure to construct workgroup count hint");
+
+  return success();
+}
+
 void ConvertWorkgroupForallToPCFPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  patterns.add<ConvertWorkgroupForall>(&getContext());
-  walkAndApplyPatterns(getOperation(), std::move(patterns));
+  // First pass: Convert workgroup foralls to pcf.loop.
+  {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<ConvertWorkgroupForall>(&getContext());
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+
+  // Second pass: Fold split-k foralls containing pcf.loop into pcf.generic.
+  {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<FoldSplitKWorkgroupLoop>(&getContext());
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
 }
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir
index 7773e4f..b827cfe 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir
@@ -116,3 +116,80 @@
 // CHECK-LABEL: @forall_gpu_thread_mapping_not_converted
 // CHECK: scf.forall
 // CHECK-NOT: pcf.loop
+
+// -----
+
+// Test folding split-reduction forall containing workgroup pcf.loop into pcf.generic.
+func.func @fold_split_reduction_into_pcf_generic(%init: tensor<16xf32>, %slice: tensor<1xf32>) -> tensor<16xf32> {
+  %c4 = arith.constant 4 : index
+  %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<16xf32>) {
+    %tile_init = tensor.extract_slice %iter[%id] [4] [1]
+        : tensor<16xf32> to tensor<4xf32>
+    %loop_result = pcf.loop scope(#iree_codegen.workgroup_scope<linearize>) count(%c4)
+        execute(%ref = %tile_init)[%loop_id: index]
+            : (!pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>)
+           -> (tensor<4xf32>) {
+      pcf.write_slice %slice into %ref[%loop_id] [1] [1]
+          : tensor<1xf32> into !pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>
+      pcf.return
+    }
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %loop_result into %iter[%id] [4] [1]
+          : tensor<4xf32> into tensor<16xf32>
+    }
+  } {mapping = [#iree_linalg_ext.split_reduction_mapping<0>]}
+  return %0 : tensor<16xf32>
+}
+
+// Total workgroup count = forall iterations (4) * loop count (4).
+// The arith.muli computes the product of the forall upper bound and the loop count.
+// CHECK-LABEL: @fold_split_reduction_into_pcf_generic
+//  CHECK-SAME:   %[[INIT:[A-Za-z0-9_]+]]: tensor<16xf32>
+//  CHECK-SAME:   %[[SLICE:[A-Za-z0-9_]+]]: tensor<1xf32>
+//   CHECK-DAG:   %[[C4_A:[a-zA-Z0-9_]+]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[C4_B:[a-zA-Z0-9_]+]] = arith.constant 4 : index
+//       CHECK:   %[[TOTAL:.+]] = arith.muli %[[C4_B]], %[[C4_A]] : index
+//       CHECK:   iree_codegen.workgroup_count_hint(%[[TOTAL]])
+//       CHECK:   %[[GENERIC:.+]] = pcf.generic
+//       CHECK:     scope(#iree_codegen.workgroup_scope<linearize>)
+//       CHECK:     execute(%[[REF:[A-Za-z0-9_]+]] = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %{{.*}}: index]
+//       CHECK:          : (!pcf.sref<16xf32, sync(#iree_codegen.workgroup_scope<linearize>)>)
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into
+//       CHECK:     %[[FORALL_LIN:.+]] = affine.linearize_index disjoint
+//       CHECK:     scf.forall (%{{.+}}) = (%[[FORALL_LIN]])
+//       CHECK:       scf.forall (%{{.+}}) = (%[[DELIN]]#1)
+//       CHECK:         pcf.write_slice %[[SLICE]] into %[[REF]]{{.*}} [1] [1]
+//  CHECK-SAME:           into !pcf.sref<16xf32, sync(#iree_codegen.workgroup_scope<linearize>)>
+//       CHECK:     pcf.return
+//       CHECK:   return %[[GENERIC]]
+
+// -----
+
+// Non-split-reduction mapping should not be folded by the split-k pattern.
+// The workgroup forall is converted to pcf.loop, but the outer forall (with
+// local mapping) should remain unconverted.
+func.func @non_split_reduction_not_folded(%init: tensor<16xf32>, %slice: tensor<1xf32>) -> tensor<16xf32> {
+  %c4 = arith.constant 4 : index
+  %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<16xf32>) {
+    %tile_init = tensor.extract_slice %iter[%id] [4] [1]
+        : tensor<16xf32> to tensor<4xf32>
+    %loop_result = pcf.loop scope(#iree_codegen.workgroup_scope<linearize>) count(%c4)
+        execute(%ref = %tile_init)[%loop_id: index]
+            : (!pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>)
+           -> (tensor<4xf32>) {
+      pcf.write_slice %slice into %ref[%loop_id] [1] [1]
+          : tensor<1xf32> into !pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>
+      pcf.return
+    }
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %loop_result into %iter[%id] [4] [1]
+          : tensor<4xf32> into tensor<16xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<0>]}
+  return %0 : tensor<16xf32>
+}
+
+// CHECK-LABEL: @non_split_reduction_not_folded
+//       CHECK:   scf.forall
+//       CHECK:     pcf.loop
+//   CHECK-NOT:   pcf.generic
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel
index 225a744..580ad2e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel
@@ -37,6 +37,7 @@
     srcs = [
         "ConvertForallToPCF.cpp",
         "ConvertSRefToMemRef.cpp",
+        "FoldForallIntoPCFLoop.cpp",
         "FuseConsumers.cpp",
         "FusePCFWrites.cpp",
         "FuseProducers.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt
index 4a9e11e..64b852c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt
@@ -30,6 +30,7 @@
   SRCS
     "ConvertForallToPCF.cpp"
     "ConvertSRefToMemRef.cpp"
+    "FoldForallIntoPCFLoop.cpp"
     "FuseConsumers.cpp"
     "FusePCFWrites.cpp"
     "FuseProducers.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FoldForallIntoPCFLoop.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FoldForallIntoPCFLoop.cpp
new file mode 100644
index 0000000..5e937eb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FoldForallIntoPCFLoop.cpp
@@ -0,0 +1,559 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFAttrs.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFOps.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.h"
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir::iree_compiler::IREE::PCF {
+
+#define GEN_PASS_DEF_TESTFOLDFORALLINTOPCFLOOPPASS
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// scf.forall + pcf.loop -> pcf.generic fold helpers
+//===----------------------------------------------------------------------===//
+
+/// Validates that the forall terminator has the expected structure for folding:
+/// - All ops are tensor.parallel_insert_slice.
+/// - All insert sources come from the same pcf.loop result.
+/// - All insert destinations are forall shared_outs.
+/// Returns the found pcf.loop on success.
+static FailureOr<PCF::LoopOp> matchFoldTerminator(scf::ForallOp forallOp) {
+  auto terminator =
+      cast<scf::InParallelOp>(forallOp.getRegion().front().getTerminator());
+
+  PCF::LoopOp foundLoop = nullptr;
+
+  for (Operation &op : terminator.getYieldingOps()) {
+    // All ops must be tensor.parallel_insert_slice.
+    auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
+    if (!insertSliceOp) {
+      return failure();
+    }
+
+    // Source must be from a pcf.loop result.
+    auto loopResult = dyn_cast<OpResult>(insertSliceOp.getSource());
+    if (!loopResult) {
+      return failure();
+    }
+
+    auto currentLoop = dyn_cast<PCF::LoopOp>(loopResult.getOwner());
+    if (!currentLoop) {
+      return failure();
+    }
+
+    // All inserts must reference the same pcf.loop.
+    if (foundLoop && foundLoop != currentLoop) {
+      return failure();
+    }
+    foundLoop = currentLoop;
+
+    // Destination must be a shared_out of the forall.
+    auto destArg = dyn_cast<BlockArgument>(insertSliceOp.getDest());
+    if (!destArg || destArg.getOwner() != &forallOp.getRegion().front()) {
+      return failure();
+    }
+
+    // Verify it's a shared_out (comes after induction vars).
+    if (destArg.getArgNumber() < forallOp.getRank()) {
+      return failure();
+    }
+  }
+
+  if (!foundLoop) {
+    return failure();
+  }
+
+  return foundLoop;
+}
+
+/// Validates the pcf.loop structure for folding:
+/// - Single count argument (linearized).
+/// - Loop is last op before terminator.
+static LogicalResult matchFoldPCFLoop(scf::ForallOp forallOp,
+                                      PCF::LoopOp loopOp) {
+  // Single count argument required.
+  if (loopOp.getCount().size() != 1) {
+    return failure();
+  }
+
+  // Loop must be last op before terminator.
+  Operation *lastOp =
+      forallOp.getRegion().front().getTerminator()->getPrevNode();
+  if (lastOp != loopOp) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Validates pcf.loop region ref args:
+/// - All users are pcf.write_slice ops.
+/// - Ref args have SyncOnReturnAttr sync scope.
+static LogicalResult matchFoldWriteSlices(PCF::LoopOp loopOp) {
+  for (BlockArgument refArg : loopOp.getRegionRefArgs()) {
+    // Check sync scope is sync_on_return.
+    auto srefType = cast<PCF::ShapedRefType>(refArg.getType());
+    Attribute syncScope = srefType.getSyncScope();
+    if (!isa_and_nonnull<PCF::SyncOnReturnAttr>(syncScope)) {
+      return failure();
+    }
+
+    // All users must be write_slice ops.
+    for (Operation *user : refArg.getUsers()) {
+      auto writeOp = dyn_cast<PCF::WriteSliceOp>(user);
+      if (!writeOp) {
+        return failure();
+      }
+    }
+  }
+
+  return success();
+}
+
+/// Computes the iteration count per dimension from forall bounds.
+/// Returns ceildiv(ub - lb, step) for each dimension.
+static SmallVector<OpFoldResult>
+computeForallIterCounts(RewriterBase &rewriter, Location loc,
+                        scf::ForallOp forallOp) {
+  SmallVector<OpFoldResult> lowerBounds = forallOp.getMixedLowerBound();
+  SmallVector<OpFoldResult> upperBounds = forallOp.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
+
+  int64_t numDims = upperBounds.size();
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr numItersExpr = (s0 - s1).ceilDiv(s2);
+
+  SmallVector<OpFoldResult> iterCountOFRs;
+  for (int64_t i = 0, e = numDims; i < e; ++i) {
+    OpFoldResult iterCount = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, numItersExpr,
+        {upperBounds[i], lowerBounds[i], steps[i]});
+    iterCountOFRs.push_back(iterCount);
+  }
+  return iterCountOFRs;
+}
+
+/// Computes actual forall induction variables from delinearized indices
+/// by applying lower bounds and steps: iv = delinearized * step + lb.
+static void computeForallIVs(RewriterBase &rewriter, Location loc,
+                             scf::ForallOp forallOp, ValueRange delinearizedIvs,
+                             IRMapping &forallMapping) {
+  SmallVector<OpFoldResult> lowerBounds = forallOp.getMixedLowerBound();
+  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
+  int64_t numDims = forallOp.getRank();
+
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr applyLbAndStep = s0 * s1 + s2;
+
+  for (int64_t i = 0, e = numDims; i < e; ++i) {
+    Value actualIv = getValueOrCreateConstantIndexOp(
+        rewriter, loc,
+        affine::makeComposedFoldedAffineApply(
+            rewriter, loc, applyLbAndStep,
+            {delinearizedIvs[i], steps[i], lowerBounds[i]}));
+    forallMapping.map(forallOp.getInductionVar(i), actualIv);
+  }
+}
+
+/// Composes pcf.write_slice ops with tensor.parallel_insert_slice ops from
+/// the forall terminator, creating new write_slice ops that write directly
+/// to the pcf.generic's sref arguments.
+///
+/// This reuses composeNestedSliceParameters() by treating the
+/// tensor.parallel_insert_slice as the outer slice and the pcf.write_slice as
+/// the inner slice.
+static void composeWriteSlicesIntoGeneric(RewriterBase &rewriter,
+                                          scf::ForallOp forallOp,
+                                          PCF::LoopOp loopOp,
+                                          PCF::GenericOp genericOp,
+                                          scf::InParallelOp terminator,
+                                          IRMapping &forallMapping) {
+  // Build mapping from pcf.loop result index -> generic ref arg index.
+  SmallVector<unsigned> resultToRefArgIdx(loopOp->getNumResults());
+  for (Operation &op : terminator.getYieldingOps()) {
+    auto insertOp = cast<tensor::ParallelInsertSliceOp>(&op);
+    auto loopResult = cast<OpResult>(insertOp.getSource());
+    unsigned resultIdx = loopResult.getResultNumber();
+    auto destArg = cast<BlockArgument>(insertOp.getDest());
+    unsigned argIdx = destArg.getArgNumber() - forallOp.getRank();
+    resultToRefArgIdx[resultIdx] = argIdx;
+  }
+
+  // Helper to remap OpFoldResult Values through forallMapping.
+  auto remapOFR = [&](OpFoldResult ofr) -> OpFoldResult {
+    if (auto val = dyn_cast<Value>(ofr)) {
+      return forallMapping.lookupOrDefault(val);
+    }
+    return ofr;
+  };
+
+  for (Operation &op :
+       llvm::make_early_inc_range(terminator.getYieldingOps())) {
+    auto insertOp = cast<tensor::ParallelInsertSliceOp>(&op);
+    auto loopResult = cast<OpResult>(insertOp.getSource());
+    unsigned resultIdx = loopResult.getResultNumber();
+
+    BlockArgument genericRefArg =
+        genericOp.getRegionRefArgs()[resultToRefArgIdx[resultIdx]];
+    BlockArgument movedRefArg = loopOp.getRegionRefArgs()[resultIdx];
+
+    // Compose all write_slice ops that write to this ref arg.
+    for (Operation *user : llvm::make_early_inc_range(movedRefArg.getUsers())) {
+      auto writeOp = dyn_cast<PCF::WriteSliceOp>(user);
+      if (!writeOp) {
+        continue;
+      }
+
+      rewriter.setInsertionPoint(writeOp);
+
+      SmallVector<OpFoldResult> writeOffsets = writeOp.getMixedOffsets();
+      SmallVector<OpFoldResult> writeStrides = writeOp.getMixedStrides();
+      SmallVector<OpFoldResult> writeSizes = writeOp.getMixedSizes();
+
+      SmallVector<OpFoldResult> insertOffsets =
+          llvm::map_to_vector(insertOp.getMixedOffsets(), remapOFR);
+      SmallVector<OpFoldResult> insertSizes =
+          llvm::map_to_vector(insertOp.getMixedSizes(), remapOFR);
+      SmallVector<OpFoldResult> insertStrides =
+          llvm::map_to_vector(insertOp.getMixedStrides(), remapOFR);
+
+      SmallVector<OpFoldResult> composedOffsets;
+      SmallVector<OpFoldResult> composedSizes;
+      SmallVector<OpFoldResult> composedStrides;
+      composeNestedSliceParameters(rewriter, writeOp.getLoc(), insertOffsets,
+                                   insertSizes, insertStrides, writeOffsets,
+                                   writeSizes, writeStrides, composedOffsets,
+                                   composedSizes, composedStrides);
+
+      // Expand source to match sref rank by adding unit dims.
+      auto sourceType = cast<RankedTensorType>(writeOp.getSource().getType());
+      auto srefType = cast<PCF::ShapedRefType>(genericRefArg.getType());
+      int64_t srefRank = srefType.getRank();
+      int64_t sourceRank = sourceType.getRank();
+
+      Value expandedSource = writeOp.getSource();
+      if (srefRank > sourceRank) {
+        SmallVector<int64_t> expandedShape;
+        SmallVector<ReassociationIndices> reassociation;
+
+        // First sourceRank-1 dimensions map 1:1.
+        for (int64_t i = 0; i < sourceRank - 1; ++i) {
+          expandedShape.push_back(sourceType.getDimSize(i));
+          reassociation.push_back({i});
+        }
+
+        // Last input dimension expands to include itself plus unit dims.
+        ReassociationIndices lastGroup;
+        if (sourceRank > 0) {
+          expandedShape.push_back(sourceType.getDimSize(sourceRank - 1));
+          lastGroup.push_back(sourceRank - 1);
+        }
+
+        // Add unit dimensions to reach sref rank.
+        int64_t numUnitDims = srefRank - sourceRank;
+        for (int64_t i = 0; i < numUnitDims; ++i) {
+          expandedShape.push_back(1);
+          lastGroup.push_back(sourceRank + i);
+        }
+
+        if (!lastGroup.empty()) {
+          reassociation.push_back(lastGroup);
+        }
+
+        auto expandedType =
+            RankedTensorType::get(expandedShape, sourceType.getElementType());
+        expandedSource = tensor::ExpandShapeOp::create(
+            rewriter, writeOp.getLoc(), expandedType, writeOp.getSource(),
+            reassociation);
+      }
+
+      // Replace old write_slice with composed one.
+      rewriter.replaceOpWithNewOp<PCF::WriteSliceOp>(
+          writeOp, expandedSource, genericRefArg, composedOffsets,
+          composedSizes, composedStrides);
+    }
+
+    rewriter.eraseOp(insertOp);
+  }
+}
+
+/// Core implementation of foldForallIntoPCFLoop after matching succeeds.
+static PCF::GenericOp foldForallIntoPCFLoopImpl(RewriterBase &rewriter,
+                                                scf::ForallOp forallOp,
+                                                PCF::LoopOp loopOp) {
+  Location loc = forallOp.getLoc();
+  scf::InParallelOp terminator = forallOp.getTerminator();
+
+  // Replace RegionIterArgs with initial values (except in terminator).
+  for (auto [iterArg, init] :
+       llvm::zip(forallOp.getRegionIterArgs(), forallOp.getOutputs())) {
+    rewriter.replaceUsesWithIf(iterArg, init, [&](OpOperand &use) {
+      return use.getOwner()->getParentOp() != terminator;
+    });
+  }
+
+  Value loopCount = loopOp.getCount()[0];
+
+  // Create pcf.generic.
+  auto genericOp = PCF::GenericOp::create(
+      rewriter, loc,
+      /*resultTypes=*/forallOp.getResultTypes(),
+      /*scope=*/loopOp.getScope(),
+      /*inits=*/forallOp.getOutputs(),
+      /*dynamic_sizes=*/ValueRange{},
+      /*is_tied=*/SmallVector<bool>(forallOp.getNumResults(), true),
+      /*num_iterators=*/1);
+
+  // Set sync scope to SyncOnReturn for the pcf.generic sref arguments.
+  Attribute syncScope = PCF::SyncOnReturnAttr::get(rewriter.getContext());
+  for (auto regionRefArg : genericOp.getRegionRefArgs()) {
+    auto srefType = cast<PCF::ShapedRefType>(regionRefArg.getType());
+    auto newSrefType = PCF::ShapedRefType::get(
+        rewriter.getContext(), srefType.getShape(), srefType.getElementType(),
+        srefType.getScope(), syncScope);
+    regionRefArg.setType(newSrefType);
+  }
+
+  Block *forallBody = &forallOp.getRegion().front();
+  Block *genericBody = &genericOp.getRegion().front();
+  rewriter.setInsertionPointToStart(genericBody);
+
+  // Compute per-dimension iteration counts.
+  SmallVector<OpFoldResult> iterCountOFRs =
+      computeForallIterCounts(rewriter, loc, forallOp);
+  int64_t numDims = iterCountOFRs.size();
+
+  SmallVector<Value> iterCountValues =
+      llvm::map_to_vector(iterCountOFRs, [&](OpFoldResult ofr) -> Value {
+        return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+      });
+
+  // Delinearize pcf.generic id into (forall dim IVs, pcf loop id) using
+  // a single delinearization with the full basis. This allows
+  // linearize/delinearize pairs to cancel during canonicalization.
+  SmallVector<OpFoldResult> fullBasis(iterCountOFRs);
+  fullBasis.push_back(loopCount);
+
+  BlockArgument linearId = genericOp.getIdArgs()[0];
+  auto delinOp = affine::AffineDelinearizeIndexOp::create(rewriter, loc,
+                                                          linearId, fullBasis);
+
+  // Results 0..numDims-1 are forall dimension indices, last is pcf loop id.
+  SmallVector<Value> forallDimIvs;
+  for (int64_t i = 0; i < numDims; ++i) {
+    forallDimIvs.push_back(delinOp.getResult(i));
+  }
+  Value pcfLoopLinearId = delinOp.getResult(numDims);
+
+  // Reconstruct forall linear id using affine.linearize_index so that the
+  // linearize/delinearize pair can be folded by canonicalization.
+  Value forallLinearId =
+      affine::AffineLinearizeIndexOp::create(
+          rewriter, loc, forallDimIvs, ArrayRef<OpFoldResult>(iterCountOFRs),
+          /*disjoint=*/true)
+          .getResult();
+
+  // Compute total forall iteration count.
+  AffineExpr s0, s1;
+  bindSymbols(rewriter.getContext(), s0, s1);
+  OpFoldResult totalItersOFR = iterCountOFRs[0];
+  for (int64_t i = 1, e = numDims; i < e; ++i) {
+    totalItersOFR = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, s0 * s1, {totalItersOFR, iterCountOFRs[i]});
+  }
+  Value totalIters =
+      getValueOrCreateConstantIndexOp(rewriter, loc, totalItersOFR);
+
+  Value totalWorkers = genericOp.getCountArgs()[0];
+  // The outer stride is the number of concurrent worker groups,
+  AffineExpr ceilDiv = s0.ceilDiv(s1);
+  Value outerStep = getValueOrCreateConstantIndexOp(
+      rewriter, loc,
+      affine::makeComposedFoldedAffineApply(rewriter, loc, ceilDiv,
+                                            {totalWorkers, loopCount}));
+
+  auto outerForall = scf::ForallOp::create(
+      rewriter, loc, ArrayRef<OpFoldResult>{forallLinearId},
+      ArrayRef<OpFoldResult>{totalIters}, ArrayRef<OpFoldResult>{outerStep},
+      /*outputs=*/ValueRange{}, /*mapping=*/std::nullopt);
+
+  // Compute actual forall induction vars inside the outer forall.
+  rewriter.setInsertionPointToStart(outerForall.getBody());
+
+  auto forallIvDelinOp = affine::AffineDelinearizeIndexOp::create(
+      rewriter, loc, outerForall.getInductionVar(0), iterCountValues);
+
+  IRMapping forallMapping;
+  computeForallIVs(rewriter, loc, forallOp, forallIvDelinOp.getMultiIndex(),
+                   forallMapping);
+
+  // Move forall body operations into outer forall (except terminator).
+  Block *outerForallBody = outerForall.getBody();
+  Operation *outerForallTerminator = outerForallBody->getTerminator();
+
+  for (Operation &op :
+       llvm::make_early_inc_range(forallBody->without_terminator())) {
+    op.moveBefore(outerForallTerminator);
+  }
+
+  // Remap induction var block arguments in moved operations.
+  for (Value iv : forallOp.getInductionVars()) {
+    Value mapped = forallMapping.lookup(iv);
+    iv.replaceAllUsesWith(mapped);
+  }
+
+  // Compose write_slice ops with parallel_insert_slice ops.
+  composeWriteSlicesIntoGeneric(rewriter, forallOp, loopOp, genericOp,
+                                terminator, forallMapping);
+
+  // Replace forall results with generic results.
+  for (auto [forallResult, genericResult] :
+       llvm::zip(forallOp.getResults(), genericOp.getResults())) {
+    rewriter.replaceAllUsesWith(forallResult, genericResult);
+  }
+
+  rewriter.eraseOp(forallOp);
+
+  // Convert moved pcf.loop to inner scf.forall (no mapping = parallel).
+  // The pcf.loop was required to have a single count arg (linearized), so
+  // pcfLoopLinearId is sufficient as the sole id.
+  rewriter.setInsertionPoint(loopOp);
+
+  auto innerForall = scf::ForallOp::create(
+      rewriter, loc, ArrayRef<OpFoldResult>{pcfLoopLinearId},
+      ArrayRef<OpFoldResult>{loopCount},
+      ArrayRef<OpFoldResult>{getValueOrCreateConstantIndexOp(
+          rewriter, loc,
+          affine::makeComposedFoldedAffineApply(rewriter, loc, ceilDiv,
+                                                {totalWorkers, loopCount}))},
+      /*outputs=*/ValueRange{}, /*mapping=*/std::nullopt);
+
+  // Replace loop's id arg with inner forall's induction variable.
+  rewriter.replaceAllUsesWith(loopOp.getIdArgs()[0],
+                              innerForall.getInductionVar(0));
+
+  // Move operations from loop body to inner forall.
+  Block *loopBody = &loopOp.getRegion().front();
+  Block *innerForallBody = innerForall.getBody();
+
+  innerForallBody->getOperations().splice(
+      std::prev(innerForallBody->end()), loopBody->getOperations(),
+      loopBody->begin(), std::prev(loopBody->end()));
+
+  rewriter.eraseOp(loopOp);
+
+  // Add terminator to generic's region.
+  rewriter.setInsertionPointToEnd(genericBody);
+  PCF::ReturnOp::create(rewriter, loc);
+
+  return genericOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Test pass: TestFoldForallIntoPCFLoopPass
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the forall op has LocalMappingAttr mapping attributes,
+/// or the mapping is empty/not present.
+static bool hasEmptyOrLocalMapping(scf::ForallOp forallOp) {
+  std::optional<ArrayAttr> mapping = forallOp.getMapping();
+  if (!mapping || mapping->empty()) {
+    return true;
+  }
+  return llvm::all_of(mapping.value(),
+                      llvm::IsaPred<IREE::Codegen::LocalMappingAttr>);
+}
+
+struct TestFoldForallIntoPCFLoopPass final
+    : impl::TestFoldForallIntoPCFLoopPassBase<TestFoldForallIntoPCFLoopPass> {
+  void runOnOperation() override {
+    SmallVector<scf::ForallOp> forallOps;
+    getOperation()->walk([&](scf::ForallOp forallOp) {
+      // Only match foralls with local_mapping attribute.
+      if (!hasEmptyOrLocalMapping(forallOp)) {
+        return;
+      }
+      // Check if there's a pcf.loop with sequential scope inside.
+      scf::InParallelOp terminator = forallOp.getTerminator();
+      Operation *lastOp = terminator->getPrevNode();
+      auto loopOp = dyn_cast_if_present<PCF::LoopOp>(lastOp);
+      if (!loopOp) {
+        return;
+      }
+      // Check for sequential scope.
+      if (!isa<PCF::SequentialAttr>(loopOp.getScope())) {
+        return;
+      }
+      forallOps.push_back(forallOp);
+    });
+
+    IRRewriter rewriter(&getContext());
+    for (scf::ForallOp forallOp : forallOps) {
+      rewriter.setInsertionPoint(forallOp);
+      FailureOr<PCF::GenericOp> result =
+          foldForallIntoPCFLoop(rewriter, forallOp);
+      if (failed(result)) {
+        forallOp.emitError("failed to fold forall into pcf.loop");
+        return signalPassFailure();
+      }
+    }
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public API: foldForallIntoPCFLoop
+//===----------------------------------------------------------------------===//
+
+FailureOr<GenericOp> foldForallIntoPCFLoop(RewriterBase &rewriter,
+                                           scf::ForallOp forallOp) {
+  // Step 1: Validate terminator structure.
+  FailureOr<LoopOp> loopOpOrFailure = matchFoldTerminator(forallOp);
+  if (failed(loopOpOrFailure)) {
+    return rewriter.notifyMatchFailure(
+        forallOp, "Failed to validate forall op terminator");
+  }
+  LoopOp loopOp = *loopOpOrFailure;
+
+  if (failed(matchFoldPCFLoop(forallOp, loopOp))) {
+    return rewriter.notifyMatchFailure(forallOp,
+                                       "Failed to validate pcf.loop structure");
+  }
+
+  // Step 3: Validate write_slice ops.
+  if (failed(matchFoldWriteSlices(loopOp))) {
+    return rewriter.notifyMatchFailure(forallOp,
+                                       "Failed to validate write_slice ops");
+  }
+
+  // Step 4: Move count definitions into place before rewriting.
+  if (failed(moveValueDefinitions(rewriter, loopOp.getCount(), forallOp))) {
+    return rewriter.notifyMatchFailure(
+        forallOp, "Failed to move loop trip count definitions");
+  }
+
+  // All validations passed, perform the fold.
+  return foldForallIntoPCFLoopImpl(rewriter, forallOp, loopOp);
+}
+
+} // namespace mlir::iree_compiler::IREE::PCF
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp
index f254fd1..1213dda 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp
@@ -8,6 +8,7 @@
 #include "iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.h"
 #include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h"
 #include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -16,6 +17,8 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
 
+#include <cassert>
+
 #define DEBUG_TYPE "iree-pcf-fuse-pcf-writes"
 
 namespace mlir::iree_compiler::IREE::PCF {
@@ -64,6 +67,53 @@
 
 } // namespace
 
+void composeNestedSliceParameters(
+    RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> outerOffsets,
+    ArrayRef<OpFoldResult> outerSizes, ArrayRef<OpFoldResult> outerStrides,
+    ArrayRef<OpFoldResult> innerOffsets, ArrayRef<OpFoldResult> innerSizes,
+    ArrayRef<OpFoldResult> innerStrides,
+    SmallVectorImpl<OpFoldResult> &composedOffsets,
+    SmallVectorImpl<OpFoldResult> &composedSizes,
+    SmallVectorImpl<OpFoldResult> &composedStrides) {
+  assert(outerOffsets.size() == outerSizes.size() &&
+         "outer slice offsets/sizes length mismatch");
+  assert(outerOffsets.size() == outerStrides.size() &&
+         "outer slice offsets/strides length mismatch");
+  assert(innerOffsets.size() == innerSizes.size() &&
+         "inner slice offsets/sizes length mismatch");
+  assert(innerOffsets.size() == innerStrides.size() &&
+         "inner slice offsets/strides length mismatch");
+  assert(outerOffsets.size() >= innerOffsets.size() &&
+         "inner slice rank cannot exceed outer slice rank");
+
+  composedOffsets.clear();
+  composedSizes.clear();
+  composedStrides.clear();
+  composedOffsets.reserve(outerOffsets.size());
+  composedSizes.reserve(outerOffsets.size());
+  composedStrides.reserve(outerOffsets.size());
+
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr composeOffExpr = s0 + s1 * s2;
+  AffineExpr mulExpr = s0 * s1;
+
+  for (int64_t i = 0, e = innerOffsets.size(); i < e; ++i) {
+    composedOffsets.push_back(affine::makeComposedFoldedAffineApply(
+        rewriter, loc, composeOffExpr,
+        {outerOffsets[i], innerOffsets[i], outerStrides[i]}));
+    composedSizes.push_back(innerSizes[i]);
+    composedStrides.push_back(affine::makeComposedFoldedAffineApply(
+        rewriter, loc, mulExpr, {outerStrides[i], innerStrides[i]}));
+  }
+
+  for (int64_t i = innerOffsets.size(), e = outerOffsets.size(); i < e; ++i) {
+    composedOffsets.push_back(outerOffsets[i]);
+    composedSizes.push_back(outerSizes[i]);
+    composedStrides.push_back(outerStrides[i]);
+  }
+}
+
 FailureOr<PCF::WriteSliceOp>
 composeWriteSliceWithParallelInsert(RewriterBase &rewriter,
                                     PCF::WriteSliceOp writeSliceOp) {
@@ -133,47 +183,24 @@
   // - sizes: insertSlice.sizes
   // - strides: writeSlice.strides * insertSlice.strides
 
-  SmallVector<OpFoldResult> composedOffsets;
-  SmallVector<OpFoldResult> composedSizes = insertSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> composedStrides;
-
   OpBuilder::InsertionGuard guard(rewriter);
   // Insert before the in_parallel terminator, not inside it.
   rewriter.setInsertionPoint(inParallelOp);
 
   SmallVector<OpFoldResult> writeOffsets = writeSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> writeSizes = writeSliceOp.getMixedSizes();
   SmallVector<OpFoldResult> insertOffsets = insertSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> insertSizes = insertSliceOp.getMixedSizes();
   SmallVector<OpFoldResult> writeStrides = writeSliceOp.getMixedStrides();
   SmallVector<OpFoldResult> insertStrides = insertSliceOp.getMixedStrides();
 
-  // Compose offsets: writeOffset + insertOffset * writeStride.
-  for (auto [writeOffset, insertOffset, writeStride] :
-       llvm::zip_equal(writeOffsets, insertOffsets, writeStrides)) {
-    Value writeOffsetVal = getValueOrCreateConstantIndexOp(
-        rewriter, insertSliceOp.getLoc(), writeOffset);
-    Value insertOffsetVal = getValueOrCreateConstantIndexOp(
-        rewriter, insertSliceOp.getLoc(), insertOffset);
-    Value writeStrideVal = getValueOrCreateConstantIndexOp(
-        rewriter, insertSliceOp.getLoc(), writeStride);
-
-    Value scaled = rewriter.createOrFold<arith::MulIOp>(
-        insertSliceOp.getLoc(), insertOffsetVal, writeStrideVal);
-    Value composed = rewriter.createOrFold<arith::AddIOp>(
-        insertSliceOp.getLoc(), writeOffsetVal, scaled);
-    composedOffsets.push_back(composed);
-  }
-
-  // Compose strides: writeStride * insertStride.
-  for (auto [writeStride, insertStride] :
-       llvm::zip_equal(writeStrides, insertStrides)) {
-    Value writeStrideVal = getValueOrCreateConstantIndexOp(
-        rewriter, insertSliceOp.getLoc(), writeStride);
-    Value insertStrideVal = getValueOrCreateConstantIndexOp(
-        rewriter, insertSliceOp.getLoc(), insertStride);
-    Value composed = rewriter.createOrFold<arith::MulIOp>(
-        insertSliceOp.getLoc(), writeStrideVal, insertStrideVal);
-    composedStrides.push_back(composed);
-  }
+  SmallVector<OpFoldResult> composedOffsets;
+  SmallVector<OpFoldResult> composedSizes;
+  SmallVector<OpFoldResult> composedStrides;
+  composeNestedSliceParameters(rewriter, insertSliceOp.getLoc(), writeOffsets,
+                               writeSizes, writeStrides, insertOffsets,
+                               insertSizes, insertStrides, composedOffsets,
+                               composedSizes, composedStrides);
 
   // Handle rank-reduced parallel_insert_slice sources.
   // The source may have fewer dimensions than the destination sref (e.g.,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
index 20e6187..a4e960a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
@@ -33,6 +33,36 @@
   let dependentDialects = ["::mlir::iree_compiler::IREE::PCF::PCFDialect"];
 }
 
+def TestFoldForallIntoPCFLoopPass : Pass<"iree-pcf-test-fold-forall-into-pcf-loop", ""> {
+  let summary = "Test pass for foldForallIntoPCFLoop transform";
+  let description = [{
+    Test pass for folding `scf.forall` ops containing `pcf.loop` ops into a
+    single `pcf.generic` operation.
+
+    The input is IR containing `scf.forall` ops that carry
+    `iree_codegen.local_mapping` attributes and contain a `pcf.loop` with
+    `#pcf.sequential` scope as the last operation before the terminator.
+
+    Structural requirements validated:
+    - All ops in `scf.forall.in_parallel` are `tensor.parallel_insert_slice`.
+    - All insert sources come from the same `pcf.loop` result.
+    - All insert destinations are `scf.forall` shared_outs.
+    - The `pcf.loop` is the last op before terminator in forall body.
+    - The `pcf.loop` has a single count argument (linearized).
+    - All `pcf.loop` region ref arg users are `pcf.write_slice` ops.
+    - Ref args have `SyncOnReturnAttr` sync scope.
+
+    The underlying transform is exposed via `foldForallIntoPCFLoop()` for use
+    in custom pipelines with different matching criteria.
+  }];
+  let dependentDialects = [
+    "::mlir::iree_compiler::IREE::PCF::PCFDialect",
+    "::mlir::arith::ArithDialect",
+    "::mlir::affine::AffineDialect",
+    "::mlir::scf::SCFDialect"
+  ];
+}
+
 def FuseConsumersPass : Pass<"iree-pcf-fuse-consumers", ""> {
   let summary = "Fuses all consumers of pcf.generic/loop ops";
   let description = [{
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
index deea43e..5a01628 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
@@ -192,6 +192,26 @@
                                      PCF::GenericOp genericOp,
                                      tensor::CollapseShapeOp collapseOp);
 
+// Composes two nested slice parameter sets so the resulting slice addresses the
+// outer base directly. The outer slice describes how an intermediate value is
+// embedded in the final destination, and the inner slice describes how a value
+// is embedded in that intermediate.
+//
+// For each dimension present in both slices:
+//   offsets = outerOffset + innerOffset * outerStride
+//   sizes = innerSizes
+//   strides = outerStride * innerStride
+//
+// Any remaining outer dimensions are forwarded unchanged.
+void composeNestedSliceParameters(
+    RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> outerOffsets,
+    ArrayRef<OpFoldResult> outerSizes, ArrayRef<OpFoldResult> outerStrides,
+    ArrayRef<OpFoldResult> innerOffsets, ArrayRef<OpFoldResult> innerSizes,
+    ArrayRef<OpFoldResult> innerStrides,
+    SmallVectorImpl<OpFoldResult> &composedOffsets,
+    SmallVectorImpl<OpFoldResult> &composedSizes,
+    SmallVectorImpl<OpFoldResult> &composedStrides);
+
 // Composes a pcf.write_slice with a tensor.parallel_insert_slice from an
 // scf.forall terminator. The write_slice's destination must be produced by the
 // forall op, and the parallel_insert_slice must be inserting into that result.
@@ -200,6 +220,24 @@
 composeWriteSliceWithParallelInsert(RewriterBase &rewriter,
                                     PCF::WriteSliceOp writeSliceOp);
 
+/// Folds an scf.forall containing a pcf.loop into a single pcf.generic.
+///
+/// Validates structural requirements:
+/// - All ops in scf.forall.in_parallel are tensor.parallel_insert_slice.
+/// - All insert sources come from the same pcf.loop result.
+/// - All insert destinations are scf.forall shared_outs.
+/// - The pcf.loop is the last op before terminator in forall body.
+/// - The pcf.loop has a single count argument (linearized).
+/// - All pcf.loop region ref arg users are pcf.write_slice ops.
+/// - Ref args have SyncOnReturnAttr sync scope.
+///
+/// Does NOT validate mapping or scope attributes (caller's responsibility).
+/// Does NOT create workgroup_count_hint (caller's responsibility).
+///
+/// On success, the forall and inner loop are replaced with a pcf.generic.
+FailureOr<PCF::GenericOp> foldForallIntoPCFLoop(RewriterBase &rewriter,
+                                                scf::ForallOp forallOp);
+
 } // namespace mlir::iree_compiler::IREE::PCF
 
 #endif // IREE_COMPILER_CODEGEN_DIALECT_PCF_TRANSFORMS_TRANSFORMS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
index d68b2e2..4bfc338 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
@@ -22,6 +22,7 @@
             "convert_forall_to_generic_nest.mlir",
             "convert_forall_to_loops.mlir",
             "convert_sref_to_memref.mlir",
+            "fold_forall_into_pcf_loop.mlir",
             "fuse_collapse_shape.mlir",
             "fuse_consumers.mlir",
             "fuse_pcf_writes.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
index 0d56caa..c10934d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
     "convert_forall_to_generic_nest.mlir"
     "convert_forall_to_loops.mlir"
     "convert_sref_to_memref.mlir"
+    "fold_forall_into_pcf_loop.mlir"
     "fuse_collapse_shape.mlir"
     "fuse_consumers.mlir"
     "fuse_pcf_writes.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fold_forall_into_pcf_loop.mlir b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fold_forall_into_pcf_loop.mlir
new file mode 100644
index 0000000..a3d35ec
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fold_forall_into_pcf_loop.mlir
@@ -0,0 +1,229 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-test-fold-forall-into-pcf-loop)" --mlir-print-local-scope --split-input-file | FileCheck %s
+
+// Test folding scf.forall containing pcf.loop into a single pcf.generic.
+// Forall has 2D iteration space (4, 8) with loop count 4.
+// The write_slice writes at [loop_id, 0] with size [1, 4] into the 4x4 tile.
+// The parallel_insert_slice inserts at [id0, id1] with size [4, 4].
+// Composed write should be at [loop_id + id0, id1] with size [1, 4] strides [1, 1].
+
+func.func @fold_forall_into_pcf_loop(%init: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %c4 = arith.constant 4 : index
+  %0 = scf.forall (%id0, %id1) in (4, 8) shared_outs(%iter = %init) -> (tensor<16x32xf32>) {
+    %tile_init = tensor.extract_slice %iter[%id0, %id1] [4, 4] [1, 1]
+        : tensor<16x32xf32> to tensor<4x4xf32>
+    %loop_result = pcf.loop scope(#pcf.sequential) count(%c4)
+        execute(%ref = %tile_init)[%loop_id: index]
+            : (!pcf.sref<4x4xf32, sync(#pcf.sequential)>)
+           -> (tensor<4x4xf32>) {
+      %slice = tensor.extract_slice %init[%id0, %loop_id] [1, 4] [1, 1]
+          : tensor<16x32xf32> to tensor<1x4xf32>
+      pcf.write_slice %slice into %ref[%loop_id, 0] [1, 4] [1, 1]
+          : tensor<1x4xf32> into !pcf.sref<4x4xf32, sync(#pcf.sequential)>
+      pcf.return
+    }
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %loop_result into %iter[%id0, %id1] [4, 4] [1, 1]
+          : tensor<4x4xf32> into tensor<16x32xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<0>, #iree_codegen.local_mapping<1>]}
+  return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @fold_forall_into_pcf_loop
+//  CHECK-SAME:   %[[INIT:[A-Za-z0-9_]+]]: tensor<16x32xf32>
+
+//       CHECK:   %[[GENERIC:.+]] = pcf.generic
+//       CHECK:     scope(#pcf.sequential)
+//       CHECK:     execute(%[[REF:[A-Za-z0-9_]+]] = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %[[GEN_COUNT:[A-Za-z0-9_]+]]: index]
+//       CHECK:          : (!pcf.sref<16x32xf32, sync(#pcf.sequential)>)
+//       CHECK:         -> (tensor<16x32xf32>) {
+
+// Delinearize generic id into (forall_dim0, forall_dim1, pcf_loop_id) with
+// full basis [4, 8, loop_count].
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[GEN_ID]] into
+//       CHECK-SAME:  : index, index, index
+
+// Linearize forall dim indices back to forall_linear_id (cancels with delinearize).
+//       CHECK:     %[[FORALL_LIN:.+]] = affine.linearize_index disjoint [%[[DELIN]]#0, %[[DELIN]]#1] by (4, 8)
+
+// Outer scf.forall starts at forall_linear_id, upper bound = 4*8 = 32.
+//       CHECK:     scf.forall (%[[OUTER_IV:.+]]) = (%[[FORALL_LIN]]) to
+//       CHECK-SAME:  {
+
+// Delinearize outer IV into 2D forall space (4, 8).
+//       CHECK:       %[[FORALL_DELIN:.+]]:2 = affine.delinearize_index %[[OUTER_IV]] into
+//       CHECK-SAME:    : index, index
+
+// Inner scf.forall starts at pcf_loop_id, upper bound = loop count.
+//       CHECK:       scf.forall (%[[INNER_IV:.+]]) = (%[[DELIN]]#2) to
+//       CHECK-SAME:    {
+
+// Composed write: sizes [1, 4] and strides [1, 1] from write_slice.
+// Offset dim 0 = insertOff(id0) + writeOff(loop_id) * insertStride(1).
+// Offset dim 1 = insertOff(id1) + writeOff(0) * insertStride(1) = id1.
+//       CHECK:         %[[COMPOSED_OFF_0:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[FORALL_DELIN]]#0, %[[INNER_IV]]]
+//       CHECK:         pcf.write_slice %{{.+}} into %[[REF]][%[[COMPOSED_OFF_0]], %[[FORALL_DELIN]]#1]
+//  CHECK-SAME:           [1, 4] [1, 1]
+//  CHECK-SAME:           into !pcf.sref<16x32xf32, sync(#pcf.sequential)>
+//       CHECK:     pcf.return
+//       CHECK:   return %[[GENERIC]]
+
+// -----
+
+// Test with multiple results from pcf.loop.
+// Both results get composed writes targeting different ref args.
+
+func.func @fold_forall_multiple_results(%init0: tensor<16xf32>, %init1: tensor<16xf32>)
+    -> (tensor<16xf32>, tensor<16xf32>) {
+  %c2 = arith.constant 2 : index
+  %0:2 = scf.forall (%id) in (4) shared_outs(%iter0 = %init0, %iter1 = %init1)
+      -> (tensor<16xf32>, tensor<16xf32>) {
+    %tile_init0 = tensor.extract_slice %iter0[%id] [4] [1]
+        : tensor<16xf32> to tensor<4xf32>
+    %tile_init1 = tensor.extract_slice %iter1[%id] [4] [1]
+        : tensor<16xf32> to tensor<4xf32>
+    %loop_result:2 = pcf.loop scope(#pcf.sequential) count(%c2)
+        execute(%ref0 = %tile_init0, %ref1 = %tile_init1)[%loop_id: index]
+            : (!pcf.sref<4xf32, sync(#pcf.sequential)>,
+               !pcf.sref<4xf32, sync(#pcf.sequential)>)
+           -> (tensor<4xf32>, tensor<4xf32>) {
+      %slice0 = tensor.extract_slice %init0[%loop_id] [2] [1]
+          : tensor<16xf32> to tensor<2xf32>
+      %slice1 = tensor.extract_slice %init1[%loop_id] [2] [1]
+          : tensor<16xf32> to tensor<2xf32>
+      pcf.write_slice %slice0 into %ref0[%loop_id] [2] [1]
+          : tensor<2xf32> into !pcf.sref<4xf32, sync(#pcf.sequential)>
+      pcf.write_slice %slice1 into %ref1[%loop_id] [2] [1]
+          : tensor<2xf32> into !pcf.sref<4xf32, sync(#pcf.sequential)>
+      pcf.return
+    }
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %loop_result#0 into %iter0[%id] [4] [1]
+          : tensor<4xf32> into tensor<16xf32>
+      tensor.parallel_insert_slice %loop_result#1 into %iter1[%id] [4] [1]
+          : tensor<4xf32> into tensor<16xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<0>]}
+  return %0#0, %0#1 : tensor<16xf32>, tensor<16xf32>
+}
+
+// CHECK-LABEL: @fold_forall_multiple_results
+//  CHECK-SAME:   %[[INIT0:[A-Za-z0-9_]+]]: tensor<16xf32>
+//  CHECK-SAME:   %[[INIT1:[A-Za-z0-9_]+]]: tensor<16xf32>
+
+//       CHECK:   %[[GENERIC:.+]]:2 = pcf.generic
+//       CHECK:     scope(#pcf.sequential)
+//       CHECK:     execute(%[[REF0:[A-Za-z0-9_]+]] = %[[INIT0]], %[[REF1:[A-Za-z0-9_]+]] = %[[INIT1]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %{{.*}}: index]
+//       CHECK:          : (!pcf.sref<16xf32, sync(#pcf.sequential)>, !pcf.sref<16xf32, sync(#pcf.sequential)>)
+//       CHECK:         -> (tensor<16xf32>, tensor<16xf32>) {
+
+// Delinearize generic id into (forall_dim0, pcf_loop_id).
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into
+//       CHECK-SAME:  : index, index
+
+// Linearize forall dim index (1D, so linearize is trivial).
+//       CHECK:     %[[FORALL_LIN:.+]] = affine.linearize_index disjoint [%[[DELIN]]#0] by (4)
+
+// Outer scf.forall from forall_linear_id.
+//       CHECK:     scf.forall (%{{.+}}) = (%[[FORALL_LIN]])
+
+// Inner scf.forall from pcf_loop_id.
+//       CHECK:       scf.forall (%{{.+}}) = (%[[DELIN]]#1)
+
+// Composed writes: write offset[loop_id] + insert offset[id].
+// Both writes have size 2 and stride 1, targeting different ref args.
+//       CHECK:         %[[COMPOSED_OFF_0:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.+}}, %{{.+}}]
+//       CHECK:         pcf.write_slice %{{.+}} into %[[REF0]][%[[COMPOSED_OFF_0]]] [2] [1]
+//  CHECK-SAME:           into !pcf.sref<16xf32, sync(#pcf.sequential)>
+//       CHECK:         %[[COMPOSED_OFF_1:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.+}}, %{{.+}}]
+//       CHECK:         pcf.write_slice %{{.+}} into %[[REF1]][%[[COMPOSED_OFF_1]]] [2] [1]
+//  CHECK-SAME:           into !pcf.sref<16xf32, sync(#pcf.sequential)>
+//       CHECK:     pcf.return
+//       CHECK:   return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Test write_slice + parallel_insert_slice composition with non-unit write
+// stride. The write has stride 2, the insert has stride 1.
+// Composed stride = 2 * 1 = 2. Write size = 2, stays 2.
+
+func.func @fold_compose_strides(%init: tensor<64xf32>) -> tensor<64xf32> {
+  %c3 = arith.constant 3 : index
+  %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<64xf32>) {
+    %tile_init = tensor.extract_slice %iter[%id] [16] [1]
+        : tensor<64xf32> to tensor<16xf32>
+    %loop_result = pcf.loop scope(#pcf.sequential) count(%c3)
+        execute(%ref = %tile_init)[%loop_id: index]
+            : (!pcf.sref<16xf32, sync(#pcf.sequential)>)
+           -> (tensor<16xf32>) {
+      %val = tensor.extract_slice %init[0] [2] [1]
+          : tensor<64xf32> to tensor<2xf32>
+      pcf.write_slice %val into %ref[1] [2] [2]
+          : tensor<2xf32> into !pcf.sref<16xf32, sync(#pcf.sequential)>
+      pcf.return
+    }
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %loop_result into %iter[%id] [16] [1]
+          : tensor<16xf32> into tensor<64xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<0>]}
+  return %0 : tensor<64xf32>
+}
+
+// Composed offset = insert(id) + write(1) * insert_stride(1) = id + 1.
+// Size = 2 (from write). Stride = 2 * 1 = 2.
+// CHECK-LABEL: @fold_compose_strides
+//  CHECK-SAME:   %[[INIT:[A-Za-z0-9_]+]]: tensor<64xf32>
+//       CHECK:   %[[GENERIC:.+]] = pcf.generic
+//       CHECK:     execute(%[[REF:[A-Za-z0-9_]+]] = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index
+//       CHECK:          : (!pcf.sref<64xf32, sync(#pcf.sequential)>)
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into
+//       CHECK:     %[[FORALL_LIN:.+]] = affine.linearize_index disjoint
+//       CHECK:     %[[OUTER_STEP:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 3)>()[%{{.+}}]
+//       CHECK:     scf.forall (%[[OUTER_IV:.+]]) = (%[[FORALL_LIN]]) to (%{{.+}}) step (%[[OUTER_STEP]])
+//       CHECK:       %[[INNER_STEP:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 3)>()[%{{.+}}]
+//       CHECK:       scf.forall (%[[INNER_IV:.+]]) = (%[[DELIN]]#1) to (%{{.+}}) step (%[[INNER_STEP]])
+// Composed: size 2, stride 2 (write_stride=2 * insert_stride=1).
+//       CHECK:         %[[COMPOSED_OFF:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%[[OUTER_IV]]]
+//       CHECK:         pcf.write_slice %{{.+}} into %[[REF]][%[[COMPOSED_OFF]]] [2] [2]
+//  CHECK-SAME:           into !pcf.sref<64xf32, sync(#pcf.sequential)>
+//       CHECK:     pcf.return
+
+// -----
+
+// Test folding when the pcf.loop count is computed in the forall body and
+// needs to be hoisted before rewriting.
+func.func @fold_hoists_loop_count(%init: tensor<8xf32>, %one: index)
+    -> tensor<8xf32> {
+  %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<8xf32>) {
+    %tile_init = tensor.extract_slice %iter[%id] [2] [1]
+        : tensor<8xf32> to tensor<2xf32>
+    %count = arith.addi %one, %one : index
+    %loop_result = pcf.loop scope(#pcf.sequential) count(%count)
+        execute(%ref = %tile_init)[%loop_id: index]
+            : (!pcf.sref<2xf32, sync(#pcf.sequential)>)
+           -> (tensor<2xf32>) {
+      %slice = tensor.extract_slice %init[%loop_id] [1] [1]
+          : tensor<8xf32> to tensor<1xf32>
+      pcf.write_slice %slice into %ref[%loop_id] [1] [1]
+          : tensor<1xf32> into !pcf.sref<2xf32, sync(#pcf.sequential)>
+      pcf.return
+    }
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %loop_result into %iter[%id] [2] [1]
+          : tensor<2xf32> into tensor<8xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<0>]}
+  return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: @fold_hoists_loop_count
+//  CHECK-SAME:   %[[INIT:[A-Za-z0-9_]+]]: tensor<8xf32>, %[[ONE:[A-Za-z0-9_]+]]: index
+//       CHECK:   %[[COUNT:[A-Za-z0-9_]+]] = arith.addi %[[ONE]], %[[ONE]] : index
+//       CHECK:   %[[GENERIC:.+]] = pcf.generic
+//       CHECK:     execute(%{{.+}} = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %[[GEN_COUNT:[A-Za-z0-9_]+]]: index]
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into (4, %[[COUNT]])
+//       CHECK:     %[[OUTER_STEP:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%[[GEN_COUNT]], %[[COUNT]]]
+//       CHECK:       scf.forall (%{{.+}}) = (%{{.+}}) to (%{{.+}}) step (%[[OUTER_STEP]])
+//       CHECK:       %[[INNER_STEP:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%[[GEN_COUNT]], %[[COUNT]]]
+//       CHECK:       scf.forall (%{{.+}}) = (%[[DELIN]]#1) to (%[[COUNT]]) step (%[[INNER_STEP]])
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir
index 55fcb90..5190118 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-fuse-pcf-writes)" --split-input-file | FileCheck %s
+// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-fuse-pcf-writes)" --mlir-print-local-scope --split-input-file | FileCheck %s
 
 func.func @fuse_write_slice_with_parallel_insert(%init: tensor<32x64xf32>, %dest: !pcf.sref<32x64xf32, sync(#pcf.sequential)>) {
   %result = scf.forall (%i, %j) in (4, 8) shared_outs(%iter = %init) -> tensor<32x64xf32> {
@@ -48,10 +48,9 @@
 //  CHECK-SAME:   %[[INIT:[A-Za-z0-9_]+]]: tensor<32x64xf32>
 //  CHECK-SAME:   %[[DEST:[A-Za-z0-9_]+]]: !pcf.sref<32x64xf32, sync(#pcf.sequential)>
 
-//   CHECK-DAG:   %[[C16:.+]] = arith.constant 16 : index
 //       CHECK:   scf.forall (%[[I:.+]], %[[J:.+]]) in (4, 8) {
 //       CHECK:     %[[TILE:.+]] = tensor.generate
-//       CHECK:     %[[COMPOSED_OFFSET:.+]] = arith.addi %[[I]], %[[C16]]
+//       CHECK:     %[[COMPOSED_OFFSET:.+]] = affine.apply affine_map<()[s0] -> (s0 + 16)>()[%[[I]]]
 //       CHECK:     pcf.write_slice %[[TILE]] into %[[DEST]][%[[COMPOSED_OFFSET]], %[[J]]] [8, 8] [1, 1]
 //       CHECK:   }
 //       CHECK-NOT:   pcf.write_slice
@@ -78,11 +77,10 @@
 //  CHECK-SAME:   %[[INIT:[A-Za-z0-9_]+]]: tensor<32x64xf32>
 //  CHECK-SAME:   %[[DEST:[A-Za-z0-9_]+]]: !pcf.sref<64x128xf32, sync(#pcf.sequential)>
 
-//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //       CHECK:   scf.forall (%[[I:.+]], %[[J:.+]]) in (4, 8) {
 //       CHECK:     %[[TILE:.+]] = tensor.generate
-//   CHECK-DAG:     %[[OFFSET_0:.+]] = arith.muli %[[I]], %[[C2]]
-//   CHECK-DAG:     %[[OFFSET_1:.+]] = arith.muli %[[J]], %[[C2]]
+//   CHECK-DAG:     %[[OFFSET_0:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[I]]]
+//   CHECK-DAG:     %[[OFFSET_1:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[J]]]
 //       CHECK:     pcf.write_slice %[[TILE]] into %[[DEST]][%[[OFFSET_0]], %[[OFFSET_1]]] [8, 8] [2, 2]
 //       CHECK:   }
 //       CHECK-NOT:   pcf.write_slice
@@ -140,10 +138,10 @@
 //  CHECK-SAME:   %[[OFFSET_BASE:[A-Za-z0-9_]+]]: index
 
 //   CHECK-DAG:   %[[C16:.+]] = arith.constant 16 : index
-//   CHECK-DAG:   %[[OFFSET:.+]] = arith.addi %[[OFFSET_BASE]], %[[C16]]
+//   CHECK-DAG:   %[[OFFSET:.+]] = arith.addi %[[OFFSET_BASE]], %[[C16]] : index
 //       CHECK:   scf.forall (%[[I:.+]], %[[J:.+]]) in (4, 8) {
 //       CHECK:     %[[TILE:.+]] = tensor.generate
-//       CHECK:     %[[COMPOSED_OFFSET:.+]] = arith.addi %[[OFFSET]], %[[I]]
+//       CHECK:     %[[COMPOSED_OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET]], %[[I]]]
 //       CHECK:     pcf.write_slice %[[TILE]] into %[[DEST]][%[[COMPOSED_OFFSET]], %[[J]]] [8, 8] [1, 1]
 //       CHECK:   }
 //       CHECK-NOT:   pcf.write_slice