[Codegen][PCF] Add foldForallIntoPCFLoop for split-k loop handling (#23452)
Adds a generic `foldForallIntoPCFLoop` function that folds an
`scf.forall` containing a `pcf.loop` into a single `pcf.generic`
operation. This makes it possible to incorporate user-expressed extra
levels of parallelism into a single scope. The immediate use case is
split-k: the split-k loop is folded into the workgroup loop.
Key changes:
- Add `foldForallIntoPCFLoop` API in PCF/Transforms/Transforms.h
- Implement structural matching helpers (matchFoldTerminator,
matchFoldPCFLoop, matchFoldWriteSlices) that validate requirements
without scope-specific logic
- Add TestFoldForallIntoPCFLoopPass for unit testing with local_mapping
+ sequential scope
- Add FoldSplitKWorkgroupLoop pattern in ConvertWorkgroupForallToPCF.cpp
that matches split_reduction_mapping + workgroup_scope and calls the
generic fold
- Run fold pattern as second pass in ConvertWorkgroupForallToPCFPass
The fold operation:
1. Creates a `pcf.generic` with the same scope as the inner `pcf.loop`
2. Linearizes the forall iteration space and delinearizes it inside the generic
3. Converts the `pcf.loop` to a nested `scf.forall` to handle spillover
4. Composes `tensor.parallel_insert_slice` ops with `pcf.write_slice` ops
The pattern has to generate a `pcf.generic` rather than a `pcf.loop` to
avoid re-executing code that was inside the `scf.forall` but outside the
`pcf.loop`.
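For illustration, a schematic before/after of the rewrite (operand lists
and types are elided; the lit tests below show the exact output):

```mlir
// Before: split-k scf.forall wrapping a workgroup pcf.loop.
%0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<16xf32>) {
  %r = pcf.loop scope(#iree_codegen.workgroup_scope<linearize>) count(%c4)
      execute(%ref = %tile)[%lid: index] ... {
    pcf.write_slice ... into %ref[%lid] [1] [1] ...
    pcf.return
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %r into %iter[%id] [4] [1] ...
  }
} {mapping = [#iree_linalg_ext.split_reduction_mapping<0>]}

// After: one pcf.generic over the combined iteration space.
%0 = pcf.generic scope(#iree_codegen.workgroup_scope<linearize>)
    execute(%ref = %init)[%gid: index, %count: index] ... {
  %ids:2 = affine.delinearize_index %gid into (4, %c4) : index, index
  scf.forall ... {   // forall iterations, strided to handle spillover
    scf.forall ... { // body of the original pcf.loop
      pcf.write_slice ... into %ref[...] [1] [1] ... // composed slice
    }
  }
  pcf.return
}
```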
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
index 80b3830..651109e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
@@ -9,6 +9,11 @@
#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
@@ -25,6 +30,17 @@
PatternRewriter &rewriter) const override;
};
+/// Folds an scf.forall with split-reduction mapping containing a pcf.loop
+/// with workgroup scope into a single pcf.generic. This handles the "split-k"
+/// pattern where the outer forall represents additional parallelism from
+/// K-dimension splitting and the inner pcf.loop represents the original
+/// workgroup-level iteration.
+struct FoldSplitKWorkgroupLoop : OpRewritePattern<scf::ForallOp> {
+ using Base::Base;
+ LogicalResult matchAndRewrite(scf::ForallOp op,
+ PatternRewriter &rewriter) const override;
+};
+
struct ConvertWorkgroupForallToPCFPass final
: impl::ConvertWorkgroupForallToPCFPassBase<
ConvertWorkgroupForallToPCFPass> {
@@ -71,10 +87,111 @@
return success();
}
+LogicalResult
+FoldSplitKWorkgroupLoop::matchAndRewrite(scf::ForallOp op,
+ PatternRewriter &rewriter) const {
+ // Match scf.forall with split-reduction mapping.
+ if (!forallOpHasMappingType<IREE::LinalgExt::SplitReductionMappingAttr>(op)) {
+ return failure();
+ }
+
+ // Find pcf.loop with workgroup scope in the forall body.
+ IREE::PCF::LoopOp loopOp = nullptr;
+ for (Operation &bodyOp : op.getBody()->without_terminator()) {
+ if (auto loop = dyn_cast<IREE::PCF::LoopOp>(&bodyOp)) {
+ if (isa<IREE::Codegen::WorkgroupScopeAttr>(loop.getScope())) {
+ loopOp = loop;
+ break;
+ }
+ }
+ }
+ if (!loopOp) {
+ return failure();
+ }
+
+ // Capture values needed for workgroup count computation before folding.
+ // The fold erases the forall op, so any mixed bounds/steps must be
+ // materialized up front.
+ Location loc = op.getLoc();
+ SmallVector<OpFoldResult> lowerBounds = op.getMixedLowerBound();
+ SmallVector<OpFoldResult> upperBounds = op.getMixedUpperBound();
+ SmallVector<OpFoldResult> steps = op.getMixedStep();
+ SmallVector<Value> loopCounts(loopOp.getCount());
+
+ // Fold forall + pcf.loop into pcf.generic.
+ FailureOr<IREE::PCF::GenericOp> result =
+ IREE::PCF::foldForallIntoPCFLoop(rewriter, op);
+ if (failed(result)) {
+ return failure();
+ }
+
+ // Compute total workgroup count after folding (forall iterations * loop
+ // count). Generate IR before the pcf.generic.
+ rewriter.setInsertionPoint(*result);
+
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ AffineExpr numItersExpr = (s0 - s1).ceilDiv(s2);
+
+ Value forallCount = nullptr;
+ for (int64_t i = 0, e = upperBounds.size(); i < e; ++i) {
+ OpFoldResult lb = i < (int64_t)lowerBounds.size()
+ ? lowerBounds[i]
+ : rewriter.getIndexAttr(0);
+ OpFoldResult ub = upperBounds[i];
+ OpFoldResult step =
+ i < (int64_t)steps.size() ? steps[i] : rewriter.getIndexAttr(1);
+
+ Value iterCount = getValueOrCreateConstantIndexOp(
+ rewriter, loc,
+ affine::makeComposedFoldedAffineApply(rewriter, loc, numItersExpr,
+ {ub, lb, step}));
+ if (!forallCount) {
+ forallCount = iterCount;
+ } else {
+ forallCount =
+ arith::MulIOp::create(rewriter, loc, forallCount, iterCount);
+ }
+ }
+
+ Value totalLoopCount = nullptr;
+ for (Value count : loopCounts) {
+ if (!totalLoopCount) {
+ totalLoopCount = count;
+ } else {
+ totalLoopCount =
+ arith::MulIOp::create(rewriter, loc, totalLoopCount, count);
+ }
+ }
+
+ Value totalCount =
+ arith::MulIOp::create(rewriter, loc, forallCount, totalLoopCount);
+
+ // Create workgroup count hint for the folded generic.
+ SmallVector<OpFoldResult> counts = {OpFoldResult(totalCount)};
+ [[maybe_unused]] LogicalResult hintRes = createWorkgroupCountHint(
+ rewriter, loc, counts, /*maxWorkgroupParallelDims=*/1,
+ /*reverse=*/false);
+ assert(succeeded(hintRes) &&
+ "Unexpected failure to construct workgroup count hint");
+
+ return success();
+}
+
void ConvertWorkgroupForallToPCFPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- patterns.add<ConvertWorkgroupForall>(&getContext());
- walkAndApplyPatterns(getOperation(), std::move(patterns));
+ // First pass: Convert workgroup foralls to pcf.loop.
+ {
+ RewritePatternSet patterns(&getContext());
+ patterns.add<ConvertWorkgroupForall>(&getContext());
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
+ }
+
+ // Second pass: Fold split-k foralls containing pcf.loop into pcf.generic.
+ {
+ RewritePatternSet patterns(&getContext());
+ patterns.add<FoldSplitKWorkgroupLoop>(&getContext());
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
+ }
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir
index 7773e4f..b827cfe 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_workgroup_forall_to_pcf.mlir
@@ -116,3 +116,80 @@
// CHECK-LABEL: @forall_gpu_thread_mapping_not_converted
// CHECK: scf.forall
// CHECK-NOT: pcf.loop
+
+// -----
+
+// Test folding split-reduction forall containing workgroup pcf.loop into pcf.generic.
+func.func @fold_split_reduction_into_pcf_generic(%init: tensor<16xf32>, %slice: tensor<1xf32>) -> tensor<16xf32> {
+ %c4 = arith.constant 4 : index
+ %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<16xf32>) {
+ %tile_init = tensor.extract_slice %iter[%id] [4] [1]
+ : tensor<16xf32> to tensor<4xf32>
+ %loop_result = pcf.loop scope(#iree_codegen.workgroup_scope<linearize>) count(%c4)
+ execute(%ref = %tile_init)[%loop_id: index]
+ : (!pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>)
+ -> (tensor<4xf32>) {
+ pcf.write_slice %slice into %ref[%loop_id] [1] [1]
+ : tensor<1xf32> into !pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>
+ pcf.return
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %loop_result into %iter[%id] [4] [1]
+ : tensor<4xf32> into tensor<16xf32>
+ }
+ } {mapping = [#iree_linalg_ext.split_reduction_mapping<0>]}
+ return %0 : tensor<16xf32>
+}
+
+// Total workgroup count = forall iterations (4) * loop count (4).
+// The arith.muli computes the product of the forall upper bound and the loop count.
+// CHECK-LABEL: @fold_split_reduction_into_pcf_generic
+// CHECK-SAME: %[[INIT:[A-Za-z0-9_]+]]: tensor<16xf32>
+// CHECK-SAME: %[[SLICE:[A-Za-z0-9_]+]]: tensor<1xf32>
+// CHECK-DAG: %[[C4_A:[a-zA-Z0-9_]+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C4_B:[a-zA-Z0-9_]+]] = arith.constant 4 : index
+// CHECK: %[[TOTAL:.+]] = arith.muli %[[C4_B]], %[[C4_A]] : index
+// CHECK: iree_codegen.workgroup_count_hint(%[[TOTAL]])
+// CHECK: %[[GENERIC:.+]] = pcf.generic
+// CHECK: scope(#iree_codegen.workgroup_scope<linearize>)
+// CHECK: execute(%[[REF:[A-Za-z0-9_]+]] = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %{{.*}}: index]
+// CHECK: : (!pcf.sref<16xf32, sync(#iree_codegen.workgroup_scope<linearize>)>)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into
+// CHECK: %[[FORALL_LIN:.+]] = affine.linearize_index disjoint
+// CHECK: scf.forall (%{{.+}}) = (%[[FORALL_LIN]])
+// CHECK: scf.forall (%{{.+}}) = (%[[DELIN]]#1)
+// CHECK: pcf.write_slice %[[SLICE]] into %[[REF]]{{.*}} [1] [1]
+// CHECK-SAME: into !pcf.sref<16xf32, sync(#iree_codegen.workgroup_scope<linearize>)>
+// CHECK: pcf.return
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+// Non-split-reduction mapping should not be folded by the split-k pattern.
+// The workgroup forall is converted to pcf.loop, but the outer forall (with
+// local mapping) should remain unconverted.
+func.func @non_split_reduction_not_folded(%init: tensor<16xf32>, %slice: tensor<1xf32>) -> tensor<16xf32> {
+ %c4 = arith.constant 4 : index
+ %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<16xf32>) {
+ %tile_init = tensor.extract_slice %iter[%id] [4] [1]
+ : tensor<16xf32> to tensor<4xf32>
+ %loop_result = pcf.loop scope(#iree_codegen.workgroup_scope<linearize>) count(%c4)
+ execute(%ref = %tile_init)[%loop_id: index]
+ : (!pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>)
+ -> (tensor<4xf32>) {
+ pcf.write_slice %slice into %ref[%loop_id] [1] [1]
+ : tensor<1xf32> into !pcf.sref<4xf32, sync(#iree_codegen.workgroup_scope<linearize>)>
+ pcf.return
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %loop_result into %iter[%id] [4] [1]
+ : tensor<4xf32> into tensor<16xf32>
+ }
+ } {mapping = [#iree_codegen.local_mapping<0>]}
+ return %0 : tensor<16xf32>
+}
+
+// CHECK-LABEL: @non_split_reduction_not_folded
+// CHECK: scf.forall
+// CHECK: pcf.loop
+// CHECK-NOT: pcf.generic
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel
index 225a744..580ad2e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel
@@ -37,6 +37,7 @@
srcs = [
"ConvertForallToPCF.cpp",
"ConvertSRefToMemRef.cpp",
+ "FoldForallIntoPCFLoop.cpp",
"FuseConsumers.cpp",
"FusePCFWrites.cpp",
"FuseProducers.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt
index 4a9e11e..64b852c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt
@@ -30,6 +30,7 @@
SRCS
"ConvertForallToPCF.cpp"
"ConvertSRefToMemRef.cpp"
+ "FoldForallIntoPCFLoop.cpp"
"FuseConsumers.cpp"
"FusePCFWrites.cpp"
"FuseProducers.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FoldForallIntoPCFLoop.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FoldForallIntoPCFLoop.cpp
new file mode 100644
index 0000000..5e937eb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FoldForallIntoPCFLoop.cpp
@@ -0,0 +1,559 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFAttrs.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFOps.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.h"
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir::iree_compiler::IREE::PCF {
+
+#define GEN_PASS_DEF_TESTFOLDFORALLINTOPCFLOOPPASS
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// scf.forall + pcf.loop -> pcf.generic fold helpers
+//===----------------------------------------------------------------------===//
+
+/// Validates that the forall terminator has the expected structure for folding:
+/// - All ops are tensor.parallel_insert_slice.
+/// - All insert sources come from the same pcf.loop result.
+/// - All insert destinations are forall shared_outs.
+/// Returns the found pcf.loop on success.
+static FailureOr<PCF::LoopOp> matchFoldTerminator(scf::ForallOp forallOp) {
+ auto terminator =
+ cast<scf::InParallelOp>(forallOp.getRegion().front().getTerminator());
+
+ PCF::LoopOp foundLoop = nullptr;
+
+ for (Operation &op : terminator.getYieldingOps()) {
+ // All ops must be tensor.parallel_insert_slice.
+ auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
+ if (!insertSliceOp) {
+ return failure();
+ }
+
+ // Source must be from a pcf.loop result.
+ auto loopResult = dyn_cast<OpResult>(insertSliceOp.getSource());
+ if (!loopResult) {
+ return failure();
+ }
+
+ auto currentLoop = dyn_cast<PCF::LoopOp>(loopResult.getOwner());
+ if (!currentLoop) {
+ return failure();
+ }
+
+ // All inserts must reference the same pcf.loop.
+ if (foundLoop && foundLoop != currentLoop) {
+ return failure();
+ }
+ foundLoop = currentLoop;
+
+ // Destination must be a shared_out of the forall.
+ auto destArg = dyn_cast<BlockArgument>(insertSliceOp.getDest());
+ if (!destArg || destArg.getOwner() != &forallOp.getRegion().front()) {
+ return failure();
+ }
+
+ // Verify it's a shared_out (comes after induction vars).
+ if (destArg.getArgNumber() < forallOp.getRank()) {
+ return failure();
+ }
+ }
+
+ if (!foundLoop) {
+ return failure();
+ }
+
+ return foundLoop;
+}
+
+/// Validates the pcf.loop structure for folding:
+/// - Single count argument (linearized).
+/// - Loop is last op before terminator.
+static LogicalResult matchFoldPCFLoop(scf::ForallOp forallOp,
+ PCF::LoopOp loopOp) {
+ // Single count argument required.
+ if (loopOp.getCount().size() != 1) {
+ return failure();
+ }
+
+ // Loop must be last op before terminator.
+ Operation *lastOp =
+ forallOp.getRegion().front().getTerminator()->getPrevNode();
+ if (lastOp != loopOp) {
+ return failure();
+ }
+
+ return success();
+}
+
+/// Validates pcf.loop region ref args:
+/// - All users are pcf.write_slice ops.
+/// - Ref args have SyncOnReturnAttr sync scope.
+static LogicalResult matchFoldWriteSlices(PCF::LoopOp loopOp) {
+ for (BlockArgument refArg : loopOp.getRegionRefArgs()) {
+ // Check sync scope is sync_on_return.
+ auto srefType = cast<PCF::ShapedRefType>(refArg.getType());
+ Attribute syncScope = srefType.getSyncScope();
+ if (!isa_and_nonnull<PCF::SyncOnReturnAttr>(syncScope)) {
+ return failure();
+ }
+
+ // All users must be write_slice ops.
+ for (Operation *user : refArg.getUsers()) {
+ auto writeOp = dyn_cast<PCF::WriteSliceOp>(user);
+ if (!writeOp) {
+ return failure();
+ }
+ }
+ }
+
+ return success();
+}
+
+/// Computes the iteration count per dimension from forall bounds.
+/// Returns ceildiv(ub - lb, step) for each dimension.
+static SmallVector<OpFoldResult>
+computeForallIterCounts(RewriterBase &rewriter, Location loc,
+ scf::ForallOp forallOp) {
+ SmallVector<OpFoldResult> lowerBounds = forallOp.getMixedLowerBound();
+ SmallVector<OpFoldResult> upperBounds = forallOp.getMixedUpperBound();
+ SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
+
+ int64_t numDims = upperBounds.size();
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ AffineExpr numItersExpr = (s0 - s1).ceilDiv(s2);
+
+ SmallVector<OpFoldResult> iterCountOFRs;
+ for (int64_t i = 0, e = numDims; i < e; ++i) {
+ OpFoldResult iterCount = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, numItersExpr,
+ {upperBounds[i], lowerBounds[i], steps[i]});
+ iterCountOFRs.push_back(iterCount);
+ }
+ return iterCountOFRs;
+}
+
+/// Computes actual forall induction variables from delinearized indices
+/// by applying lower bounds and steps: iv = delinearized * step + lb.
+static void computeForallIVs(RewriterBase &rewriter, Location loc,
+ scf::ForallOp forallOp, ValueRange delinearizedIvs,
+ IRMapping &forallMapping) {
+ SmallVector<OpFoldResult> lowerBounds = forallOp.getMixedLowerBound();
+ SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
+ int64_t numDims = forallOp.getRank();
+
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ AffineExpr applyLbAndStep = s0 * s1 + s2;
+
+ for (int64_t i = 0, e = numDims; i < e; ++i) {
+ Value actualIv = getValueOrCreateConstantIndexOp(
+ rewriter, loc,
+ affine::makeComposedFoldedAffineApply(
+ rewriter, loc, applyLbAndStep,
+ {delinearizedIvs[i], steps[i], lowerBounds[i]}));
+ forallMapping.map(forallOp.getInductionVar(i), actualIv);
+ }
+}
+
+/// Composes pcf.write_slice ops with tensor.parallel_insert_slice ops from
+/// the forall terminator, creating new write_slice ops that write directly
+/// to the pcf.generic's sref arguments.
+///
+/// This reuses composeNestedSliceParameters() by treating the
+/// tensor.parallel_insert_slice as the outer slice and the pcf.write_slice as
+/// the inner slice.
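+///
+/// For example (schematic), a tensor.parallel_insert_slice at offset %id with
+/// unit stride (the outer slice) and a pcf.write_slice at offset %lid with
+/// unit stride (the inner slice) compose into a single pcf.write_slice at
+/// offset %id + %lid on the generic's sref argument.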
+static void composeWriteSlicesIntoGeneric(RewriterBase &rewriter,
+ scf::ForallOp forallOp,
+ PCF::LoopOp loopOp,
+ PCF::GenericOp genericOp,
+ scf::InParallelOp terminator,
+ IRMapping &forallMapping) {
+ // Build mapping from pcf.loop result index -> generic ref arg index.
+ SmallVector<unsigned> resultToRefArgIdx(loopOp->getNumResults());
+ for (Operation &op : terminator.getYieldingOps()) {
+ auto insertOp = cast<tensor::ParallelInsertSliceOp>(&op);
+ auto loopResult = cast<OpResult>(insertOp.getSource());
+ unsigned resultIdx = loopResult.getResultNumber();
+ auto destArg = cast<BlockArgument>(insertOp.getDest());
+ unsigned argIdx = destArg.getArgNumber() - forallOp.getRank();
+ resultToRefArgIdx[resultIdx] = argIdx;
+ }
+
+ // Helper to remap OpFoldResult Values through forallMapping.
+ auto remapOFR = [&](OpFoldResult ofr) -> OpFoldResult {
+ if (auto val = dyn_cast<Value>(ofr)) {
+ return forallMapping.lookupOrDefault(val);
+ }
+ return ofr;
+ };
+
+ for (Operation &op :
+ llvm::make_early_inc_range(terminator.getYieldingOps())) {
+ auto insertOp = cast<tensor::ParallelInsertSliceOp>(&op);
+ auto loopResult = cast<OpResult>(insertOp.getSource());
+ unsigned resultIdx = loopResult.getResultNumber();
+
+ BlockArgument genericRefArg =
+ genericOp.getRegionRefArgs()[resultToRefArgIdx[resultIdx]];
+ BlockArgument movedRefArg = loopOp.getRegionRefArgs()[resultIdx];
+
+ // Compose all write_slice ops that write to this ref arg.
+ for (Operation *user : llvm::make_early_inc_range(movedRefArg.getUsers())) {
+ auto writeOp = dyn_cast<PCF::WriteSliceOp>(user);
+ if (!writeOp) {
+ continue;
+ }
+
+ rewriter.setInsertionPoint(writeOp);
+
+ SmallVector<OpFoldResult> writeOffsets = writeOp.getMixedOffsets();
+ SmallVector<OpFoldResult> writeStrides = writeOp.getMixedStrides();
+ SmallVector<OpFoldResult> writeSizes = writeOp.getMixedSizes();
+
+ SmallVector<OpFoldResult> insertOffsets =
+ llvm::map_to_vector(insertOp.getMixedOffsets(), remapOFR);
+ SmallVector<OpFoldResult> insertSizes =
+ llvm::map_to_vector(insertOp.getMixedSizes(), remapOFR);
+ SmallVector<OpFoldResult> insertStrides =
+ llvm::map_to_vector(insertOp.getMixedStrides(), remapOFR);
+
+ SmallVector<OpFoldResult> composedOffsets;
+ SmallVector<OpFoldResult> composedSizes;
+ SmallVector<OpFoldResult> composedStrides;
+ composeNestedSliceParameters(rewriter, writeOp.getLoc(), insertOffsets,
+ insertSizes, insertStrides, writeOffsets,
+ writeSizes, writeStrides, composedOffsets,
+ composedSizes, composedStrides);
+
+ // Expand source to match sref rank by adding unit dims.
+ auto sourceType = cast<RankedTensorType>(writeOp.getSource().getType());
+ auto srefType = cast<PCF::ShapedRefType>(genericRefArg.getType());
+ int64_t srefRank = srefType.getRank();
+ int64_t sourceRank = sourceType.getRank();
+
+ Value expandedSource = writeOp.getSource();
+ if (srefRank > sourceRank) {
+ SmallVector<int64_t> expandedShape;
+ SmallVector<ReassociationIndices> reassociation;
+
+ // First sourceRank-1 dimensions map 1:1.
+ for (int64_t i = 0; i < sourceRank - 1; ++i) {
+ expandedShape.push_back(sourceType.getDimSize(i));
+ reassociation.push_back({i});
+ }
+
+ // Last input dimension expands to include itself plus unit dims.
+ ReassociationIndices lastGroup;
+ if (sourceRank > 0) {
+ expandedShape.push_back(sourceType.getDimSize(sourceRank - 1));
+ lastGroup.push_back(sourceRank - 1);
+ }
+
+ // Add unit dimensions to reach sref rank.
+ int64_t numUnitDims = srefRank - sourceRank;
+ for (int64_t i = 0; i < numUnitDims; ++i) {
+ expandedShape.push_back(1);
+ lastGroup.push_back(sourceRank + i);
+ }
+
+ if (!lastGroup.empty()) {
+ reassociation.push_back(lastGroup);
+ }
+
+ auto expandedType =
+ RankedTensorType::get(expandedShape, sourceType.getElementType());
+ expandedSource = tensor::ExpandShapeOp::create(
+ rewriter, writeOp.getLoc(), expandedType, writeOp.getSource(),
+ reassociation);
+ }
+
+ // Replace old write_slice with composed one.
+ rewriter.replaceOpWithNewOp<PCF::WriteSliceOp>(
+ writeOp, expandedSource, genericRefArg, composedOffsets,
+ composedSizes, composedStrides);
+ }
+
+ rewriter.eraseOp(insertOp);
+ }
+}
+
+/// Core implementation of foldForallIntoPCFLoop after matching succeeds.
+static PCF::GenericOp foldForallIntoPCFLoopImpl(RewriterBase &rewriter,
+ scf::ForallOp forallOp,
+ PCF::LoopOp loopOp) {
+ Location loc = forallOp.getLoc();
+ scf::InParallelOp terminator = forallOp.getTerminator();
+
+ // Replace RegionIterArgs with initial values (except in terminator).
+ for (auto [iterArg, init] :
+ llvm::zip(forallOp.getRegionIterArgs(), forallOp.getOutputs())) {
+ rewriter.replaceUsesWithIf(iterArg, init, [&](OpOperand &use) {
+ return use.getOwner()->getParentOp() != terminator;
+ });
+ }
+
+ Value loopCount = loopOp.getCount()[0];
+
+ // Create pcf.generic.
+ auto genericOp = PCF::GenericOp::create(
+ rewriter, loc,
+ /*resultTypes=*/forallOp.getResultTypes(),
+ /*scope=*/loopOp.getScope(),
+ /*inits=*/forallOp.getOutputs(),
+ /*dynamic_sizes=*/ValueRange{},
+ /*is_tied=*/SmallVector<bool>(forallOp.getNumResults(), true),
+ /*num_iterators=*/1);
+
+ // Set sync scope to SyncOnReturn for the pcf.generic sref arguments.
+ Attribute syncScope = PCF::SyncOnReturnAttr::get(rewriter.getContext());
+ for (auto regionRefArg : genericOp.getRegionRefArgs()) {
+ auto srefType = cast<PCF::ShapedRefType>(regionRefArg.getType());
+ auto newSrefType = PCF::ShapedRefType::get(
+ rewriter.getContext(), srefType.getShape(), srefType.getElementType(),
+ srefType.getScope(), syncScope);
+ regionRefArg.setType(newSrefType);
+ }
+
+ Block *forallBody = &forallOp.getRegion().front();
+ Block *genericBody = &genericOp.getRegion().front();
+ rewriter.setInsertionPointToStart(genericBody);
+
+ // Compute per-dimension iteration counts.
+ SmallVector<OpFoldResult> iterCountOFRs =
+ computeForallIterCounts(rewriter, loc, forallOp);
+ int64_t numDims = iterCountOFRs.size();
+
+ SmallVector<Value> iterCountValues =
+ llvm::map_to_vector(iterCountOFRs, [&](OpFoldResult ofr) -> Value {
+ return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+ });
+
+ // Delinearize pcf.generic id into (forall dim IVs, pcf loop id) using
+ // a single delinearization with the full basis. This allows
+ // linearize/delinearize pairs to cancel during canonicalization.
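+  // For example, with forall iteration counts (4, 8) and loop count %c, the
+  // basis is (4, 8, %c) and the results are (iv0, iv1, pcf_loop_id).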
+ SmallVector<OpFoldResult> fullBasis(iterCountOFRs);
+ fullBasis.push_back(loopCount);
+
+ BlockArgument linearId = genericOp.getIdArgs()[0];
+ auto delinOp = affine::AffineDelinearizeIndexOp::create(rewriter, loc,
+ linearId, fullBasis);
+
+ // Results 0..numDims-1 are forall dimension indices, last is pcf loop id.
+ SmallVector<Value> forallDimIvs;
+ for (int64_t i = 0; i < numDims; ++i) {
+ forallDimIvs.push_back(delinOp.getResult(i));
+ }
+ Value pcfLoopLinearId = delinOp.getResult(numDims);
+
+ // Reconstruct forall linear id using affine.linearize_index so that the
+ // linearize/delinearize pair can be folded by canonicalization.
+ Value forallLinearId =
+ affine::AffineLinearizeIndexOp::create(
+ rewriter, loc, forallDimIvs, ArrayRef<OpFoldResult>(iterCountOFRs),
+ /*disjoint=*/true)
+ .getResult();
+
+ // Compute total forall iteration count.
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ OpFoldResult totalItersOFR = iterCountOFRs[0];
+ for (int64_t i = 1, e = numDims; i < e; ++i) {
+ totalItersOFR = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0 * s1, {totalItersOFR, iterCountOFRs[i]});
+ }
+ Value totalIters =
+ getValueOrCreateConstantIndexOp(rewriter, loc, totalItersOFR);
+
+ Value totalWorkers = genericOp.getCountArgs()[0];
+  // Together the workers cover ceil(totalWorkers / loopCount) forall
+  // iterations per round, so each worker strides the outer loop by that
+  // amount.
+ AffineExpr ceilDiv = s0.ceilDiv(s1);
+ Value outerStep = getValueOrCreateConstantIndexOp(
+ rewriter, loc,
+ affine::makeComposedFoldedAffineApply(rewriter, loc, ceilDiv,
+ {totalWorkers, loopCount}));
+
+ auto outerForall = scf::ForallOp::create(
+ rewriter, loc, ArrayRef<OpFoldResult>{forallLinearId},
+ ArrayRef<OpFoldResult>{totalIters}, ArrayRef<OpFoldResult>{outerStep},
+ /*outputs=*/ValueRange{}, /*mapping=*/std::nullopt);
+
+ // Compute actual forall induction vars inside the outer forall.
+ rewriter.setInsertionPointToStart(outerForall.getBody());
+
+ auto forallIvDelinOp = affine::AffineDelinearizeIndexOp::create(
+ rewriter, loc, outerForall.getInductionVar(0), iterCountValues);
+
+ IRMapping forallMapping;
+ computeForallIVs(rewriter, loc, forallOp, forallIvDelinOp.getMultiIndex(),
+ forallMapping);
+
+ // Move forall body operations into outer forall (except terminator).
+ Block *outerForallBody = outerForall.getBody();
+ Operation *outerForallTerminator = outerForallBody->getTerminator();
+
+ for (Operation &op :
+ llvm::make_early_inc_range(forallBody->without_terminator())) {
+ op.moveBefore(outerForallTerminator);
+ }
+
+ // Remap induction var block arguments in moved operations.
+ for (Value iv : forallOp.getInductionVars()) {
+ Value mapped = forallMapping.lookup(iv);
+ iv.replaceAllUsesWith(mapped);
+ }
+
+ // Compose write_slice ops with parallel_insert_slice ops.
+ composeWriteSlicesIntoGeneric(rewriter, forallOp, loopOp, genericOp,
+ terminator, forallMapping);
+
+ // Replace forall results with generic results.
+ for (auto [forallResult, genericResult] :
+ llvm::zip(forallOp.getResults(), genericOp.getResults())) {
+ rewriter.replaceAllUsesWith(forallResult, genericResult);
+ }
+
+ rewriter.eraseOp(forallOp);
+
+ // Convert moved pcf.loop to inner scf.forall (no mapping = parallel).
+ // The pcf.loop was required to have a single count arg (linearized), so
+ // pcfLoopLinearId is sufficient as the sole id.
+ rewriter.setInsertionPoint(loopOp);
+
+ auto innerForall = scf::ForallOp::create(
+ rewriter, loc, ArrayRef<OpFoldResult>{pcfLoopLinearId},
+ ArrayRef<OpFoldResult>{loopCount},
+ ArrayRef<OpFoldResult>{getValueOrCreateConstantIndexOp(
+ rewriter, loc,
+ affine::makeComposedFoldedAffineApply(rewriter, loc, ceilDiv,
+ {totalWorkers, loopCount}))},
+ /*outputs=*/ValueRange{}, /*mapping=*/std::nullopt);
+
+ // Replace loop's id arg with inner forall's induction variable.
+ rewriter.replaceAllUsesWith(loopOp.getIdArgs()[0],
+ innerForall.getInductionVar(0));
+
+ // Move operations from loop body to inner forall.
+ Block *loopBody = &loopOp.getRegion().front();
+ Block *innerForallBody = innerForall.getBody();
+
+ innerForallBody->getOperations().splice(
+ std::prev(innerForallBody->end()), loopBody->getOperations(),
+ loopBody->begin(), std::prev(loopBody->end()));
+
+ rewriter.eraseOp(loopOp);
+
+ // Add terminator to generic's region.
+ rewriter.setInsertionPointToEnd(genericBody);
+ PCF::ReturnOp::create(rewriter, loc);
+
+ return genericOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Test pass: TestFoldForallIntoPCFLoopPass
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the forall op's mapping is absent/empty or consists
+/// solely of LocalMappingAttr entries.
+static bool hasEmptyOrLocalMapping(scf::ForallOp forallOp) {
+ std::optional<ArrayAttr> mapping = forallOp.getMapping();
+ if (!mapping || mapping->empty()) {
+ return true;
+ }
+ return llvm::all_of(mapping.value(),
+ llvm::IsaPred<IREE::Codegen::LocalMappingAttr>);
+}
+
+struct TestFoldForallIntoPCFLoopPass final
+ : impl::TestFoldForallIntoPCFLoopPassBase<TestFoldForallIntoPCFLoopPass> {
+ void runOnOperation() override {
+ SmallVector<scf::ForallOp> forallOps;
+ getOperation()->walk([&](scf::ForallOp forallOp) {
+      // Only match foralls whose mapping is empty or local_mapping-only.
+ if (!hasEmptyOrLocalMapping(forallOp)) {
+ return;
+ }
+ // Check if there's a pcf.loop with sequential scope inside.
+ scf::InParallelOp terminator = forallOp.getTerminator();
+ Operation *lastOp = terminator->getPrevNode();
+ auto loopOp = dyn_cast_if_present<PCF::LoopOp>(lastOp);
+ if (!loopOp) {
+ return;
+ }
+ // Check for sequential scope.
+ if (!isa<PCF::SequentialAttr>(loopOp.getScope())) {
+ return;
+ }
+ forallOps.push_back(forallOp);
+ });
+
+ IRRewriter rewriter(&getContext());
+ for (scf::ForallOp forallOp : forallOps) {
+ rewriter.setInsertionPoint(forallOp);
+ FailureOr<PCF::GenericOp> result =
+ foldForallIntoPCFLoop(rewriter, forallOp);
+ if (failed(result)) {
+ forallOp.emitError("failed to fold forall into pcf.loop");
+ return signalPassFailure();
+ }
+ }
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public API: foldForallIntoPCFLoop
+//===----------------------------------------------------------------------===//
+
+FailureOr<GenericOp> foldForallIntoPCFLoop(RewriterBase &rewriter,
+ scf::ForallOp forallOp) {
+ // Step 1: Validate terminator structure.
+ FailureOr<LoopOp> loopOpOrFailure = matchFoldTerminator(forallOp);
+ if (failed(loopOpOrFailure)) {
+ return rewriter.notifyMatchFailure(
+ forallOp, "Failed to validate forall op terminator");
+ }
+ LoopOp loopOp = *loopOpOrFailure;
+
+  // Step 2: Validate pcf.loop structure.
+  if (failed(matchFoldPCFLoop(forallOp, loopOp))) {
+ return rewriter.notifyMatchFailure(forallOp,
+ "Failed to validate pcf.loop structure");
+ }
+
+ // Step 3: Validate write_slice ops.
+ if (failed(matchFoldWriteSlices(loopOp))) {
+ return rewriter.notifyMatchFailure(forallOp,
+ "Failed to validate write_slice ops");
+ }
+
+ // Step 4: Move count definitions into place before rewriting.
+ if (failed(moveValueDefinitions(rewriter, loopOp.getCount(), forallOp))) {
+ return rewriter.notifyMatchFailure(
+ forallOp, "Failed to move loop trip count definitions");
+ }
+
+ // All validations passed, perform the fold.
+ return foldForallIntoPCFLoopImpl(rewriter, forallOp, loopOp);
+}
+
+} // namespace mlir::iree_compiler::IREE::PCF
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp
index f254fd1..1213dda 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/FusePCFWrites.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.h"
#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -16,6 +17,8 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
+#include <cassert>
+
#define DEBUG_TYPE "iree-pcf-fuse-pcf-writes"
namespace mlir::iree_compiler::IREE::PCF {
@@ -64,6 +67,53 @@
} // namespace
+void composeNestedSliceParameters(
+ RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> outerOffsets,
+ ArrayRef<OpFoldResult> outerSizes, ArrayRef<OpFoldResult> outerStrides,
+ ArrayRef<OpFoldResult> innerOffsets, ArrayRef<OpFoldResult> innerSizes,
+ ArrayRef<OpFoldResult> innerStrides,
+ SmallVectorImpl<OpFoldResult> &composedOffsets,
+ SmallVectorImpl<OpFoldResult> &composedSizes,
+ SmallVectorImpl<OpFoldResult> &composedStrides) {
+ assert(outerOffsets.size() == outerSizes.size() &&
+ "outer slice offsets/sizes length mismatch");
+ assert(outerOffsets.size() == outerStrides.size() &&
+ "outer slice offsets/strides length mismatch");
+ assert(innerOffsets.size() == innerSizes.size() &&
+ "inner slice offsets/sizes length mismatch");
+ assert(innerOffsets.size() == innerStrides.size() &&
+ "inner slice offsets/strides length mismatch");
+ assert(outerOffsets.size() >= innerOffsets.size() &&
+ "inner slice rank cannot exceed outer slice rank");
+
+ composedOffsets.clear();
+ composedSizes.clear();
+ composedStrides.clear();
+ composedOffsets.reserve(outerOffsets.size());
+ composedSizes.reserve(outerOffsets.size());
+ composedStrides.reserve(outerOffsets.size());
+
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ AffineExpr composeOffExpr = s0 + s1 * s2;
+ AffineExpr mulExpr = s0 * s1;
+
+ for (int64_t i = 0, e = innerOffsets.size(); i < e; ++i) {
+ composedOffsets.push_back(affine::makeComposedFoldedAffineApply(
+ rewriter, loc, composeOffExpr,
+ {outerOffsets[i], innerOffsets[i], outerStrides[i]}));
+ composedSizes.push_back(innerSizes[i]);
+ composedStrides.push_back(affine::makeComposedFoldedAffineApply(
+ rewriter, loc, mulExpr, {outerStrides[i], innerStrides[i]}));
+ }
+
+ for (int64_t i = innerOffsets.size(), e = outerOffsets.size(); i < e; ++i) {
+ composedOffsets.push_back(outerOffsets[i]);
+ composedSizes.push_back(outerSizes[i]);
+ composedStrides.push_back(outerStrides[i]);
+ }
+}
+
FailureOr<PCF::WriteSliceOp>
composeWriteSliceWithParallelInsert(RewriterBase &rewriter,
PCF::WriteSliceOp writeSliceOp) {
@@ -133,47 +183,24 @@
// - sizes: insertSlice.sizes
// - strides: writeSlice.strides * insertSlice.strides
- SmallVector<OpFoldResult> composedOffsets;
- SmallVector<OpFoldResult> composedSizes = insertSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> composedStrides;
-
OpBuilder::InsertionGuard guard(rewriter);
// Insert before the in_parallel terminator, not inside it.
rewriter.setInsertionPoint(inParallelOp);
SmallVector<OpFoldResult> writeOffsets = writeSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> writeSizes = writeSliceOp.getMixedSizes();
SmallVector<OpFoldResult> insertOffsets = insertSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> insertSizes = insertSliceOp.getMixedSizes();
SmallVector<OpFoldResult> writeStrides = writeSliceOp.getMixedStrides();
SmallVector<OpFoldResult> insertStrides = insertSliceOp.getMixedStrides();
- // Compose offsets: writeOffset + insertOffset * writeStride.
- for (auto [writeOffset, insertOffset, writeStride] :
- llvm::zip_equal(writeOffsets, insertOffsets, writeStrides)) {
- Value writeOffsetVal = getValueOrCreateConstantIndexOp(
- rewriter, insertSliceOp.getLoc(), writeOffset);
- Value insertOffsetVal = getValueOrCreateConstantIndexOp(
- rewriter, insertSliceOp.getLoc(), insertOffset);
- Value writeStrideVal = getValueOrCreateConstantIndexOp(
- rewriter, insertSliceOp.getLoc(), writeStride);
-
- Value scaled = rewriter.createOrFold<arith::MulIOp>(
- insertSliceOp.getLoc(), insertOffsetVal, writeStrideVal);
- Value composed = rewriter.createOrFold<arith::AddIOp>(
- insertSliceOp.getLoc(), writeOffsetVal, scaled);
- composedOffsets.push_back(composed);
- }
-
- // Compose strides: writeStride * insertStride.
- for (auto [writeStride, insertStride] :
- llvm::zip_equal(writeStrides, insertStrides)) {
- Value writeStrideVal = getValueOrCreateConstantIndexOp(
- rewriter, insertSliceOp.getLoc(), writeStride);
- Value insertStrideVal = getValueOrCreateConstantIndexOp(
- rewriter, insertSliceOp.getLoc(), insertStride);
- Value composed = rewriter.createOrFold<arith::MulIOp>(
- insertSliceOp.getLoc(), writeStrideVal, insertStrideVal);
- composedStrides.push_back(composed);
- }
+ SmallVector<OpFoldResult> composedOffsets;
+ SmallVector<OpFoldResult> composedSizes;
+ SmallVector<OpFoldResult> composedStrides;
+ composeNestedSliceParameters(rewriter, insertSliceOp.getLoc(), writeOffsets,
+ writeSizes, writeStrides, insertOffsets,
+ insertSizes, insertStrides, composedOffsets,
+ composedSizes, composedStrides);
// Handle rank-reduced parallel_insert_slice sources.
// The source may have fewer dimensions than the destination sref (e.g.,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
index 20e6187..a4e960a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
@@ -33,6 +33,36 @@
let dependentDialects = ["::mlir::iree_compiler::IREE::PCF::PCFDialect"];
}
+def TestFoldForallIntoPCFLoopPass : Pass<"iree-pcf-test-fold-forall-into-pcf-loop", ""> {
+ let summary = "Test pass for foldForallIntoPCFLoop transform";
+ let description = [{
+ Test pass for folding `scf.forall` ops containing `pcf.loop` ops into a
+ single `pcf.generic` operation.
+
+    The pass matches `scf.forall` ops that carry `iree_codegen.local_mapping`
+    attributes and contain a `pcf.loop` with `#pcf.sequential` scope as the
+    last operation before the terminator.
+
+ Structural requirements validated:
+ - All ops in `scf.forall.in_parallel` are `tensor.parallel_insert_slice`.
+ - All insert sources come from the same `pcf.loop` result.
+ - All insert destinations are `scf.forall` shared_outs.
+ - The `pcf.loop` is the last op before terminator in forall body.
+ - The `pcf.loop` has a single count argument (linearized).
+ - All `pcf.loop` region ref arg users are `pcf.write_slice` ops.
+ - Ref args have `SyncOnReturnAttr` sync scope.
+
+ The underlying transform is exposed via `foldForallIntoPCFLoop()` for use
+ in custom pipelines with different matching criteria.
+ }];
+ let dependentDialects = [
+ "::mlir::iree_compiler::IREE::PCF::PCFDialect",
+ "::mlir::arith::ArithDialect",
+ "::mlir::affine::AffineDialect",
+ "::mlir::scf::SCFDialect"
+ ];
+}
+
def FuseConsumersPass : Pass<"iree-pcf-fuse-consumers", ""> {
let summary = "Fuses all consumers of pcf.generic/loop ops";
let description = [{
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
index deea43e..5a01628 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
@@ -192,6 +192,26 @@
PCF::GenericOp genericOp,
tensor::CollapseShapeOp collapseOp);
+// Composes two nested slice parameter sets so the resulting slice addresses the
+// outer base directly. The outer slice describes how an intermediate value is
+// embedded in the final destination, and the inner slice describes how a value
+// is embedded in that intermediate.
+//
+// For each dimension present in both slices:
+// offsets = outerOffset + innerOffset * outerStride
+// sizes = innerSizes
+// strides = outerStride * innerStride
+//
+// Any remaining outer dimensions are forwarded unchanged.
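+//
+// For example (a hypothetical 1-D case): composing an outer slice with
+// offset 8, size 4, stride 2 and an inner slice with offset 1, size 2,
+// stride 3 yields offset 8 + 1 * 2 = 10, size 2, and stride 2 * 3 = 6.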
+void composeNestedSliceParameters(
+ RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> outerOffsets,
+ ArrayRef<OpFoldResult> outerSizes, ArrayRef<OpFoldResult> outerStrides,
+ ArrayRef<OpFoldResult> innerOffsets, ArrayRef<OpFoldResult> innerSizes,
+ ArrayRef<OpFoldResult> innerStrides,
+ SmallVectorImpl<OpFoldResult> &composedOffsets,
+ SmallVectorImpl<OpFoldResult> &composedSizes,
+ SmallVectorImpl<OpFoldResult> &composedStrides);
+
// Composes a pcf.write_slice with a tensor.parallel_insert_slice from an
// scf.forall terminator. The write_slice's destination must be produced by the
// forall op, and the parallel_insert_slice must be inserting into that result.
@@ -200,6 +220,24 @@
composeWriteSliceWithParallelInsert(RewriterBase &rewriter,
PCF::WriteSliceOp writeSliceOp);
+/// Folds an scf.forall containing a pcf.loop into a single pcf.generic.
+///
+/// Validates structural requirements:
+/// - All ops in scf.forall.in_parallel are tensor.parallel_insert_slice.
+/// - All insert sources come from the same pcf.loop result.
+/// - All insert destinations are scf.forall shared_outs.
+/// - The pcf.loop is the last op before terminator in forall body.
+/// - The pcf.loop has a single count argument (linearized).
+/// - All pcf.loop region ref arg users are pcf.write_slice ops.
+/// - Ref args have SyncOnReturnAttr sync scope.
+///
+/// Does NOT validate mapping or scope attributes (caller's responsibility).
+/// Does NOT create workgroup_count_hint (caller's responsibility).
+///
+/// On success, the forall and inner loop are replaced with a pcf.generic.
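+///
+/// A sketch of intended usage (callers supply their own matching; the
+/// `matchesMappingAndScope` predicate is a placeholder):
+///
+///   if (matchesMappingAndScope(forallOp)) {
+///     FailureOr<PCF::GenericOp> generic =
+///         foldForallIntoPCFLoop(rewriter, forallOp);
+///     if (failed(generic))
+///       return failure();
+///   }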
+FailureOr<PCF::GenericOp> foldForallIntoPCFLoop(RewriterBase &rewriter,
+ scf::ForallOp forallOp);
+
} // namespace mlir::iree_compiler::IREE::PCF
#endif // IREE_COMPILER_CODEGEN_DIALECT_PCF_TRANSFORMS_TRANSFORMS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
index d68b2e2..4bfc338 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
@@ -22,6 +22,7 @@
"convert_forall_to_generic_nest.mlir",
"convert_forall_to_loops.mlir",
"convert_sref_to_memref.mlir",
+ "fold_forall_into_pcf_loop.mlir",
"fuse_collapse_shape.mlir",
"fuse_consumers.mlir",
"fuse_pcf_writes.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
index 0d56caa..c10934d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
"convert_forall_to_generic_nest.mlir"
"convert_forall_to_loops.mlir"
"convert_sref_to_memref.mlir"
+ "fold_forall_into_pcf_loop.mlir"
"fuse_collapse_shape.mlir"
"fuse_consumers.mlir"
"fuse_pcf_writes.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fold_forall_into_pcf_loop.mlir b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fold_forall_into_pcf_loop.mlir
new file mode 100644
index 0000000..a3d35ec
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fold_forall_into_pcf_loop.mlir
@@ -0,0 +1,229 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-test-fold-forall-into-pcf-loop)" --mlir-print-local-scope --split-input-file | FileCheck %s
+
+// Test folding scf.forall containing pcf.loop into a single pcf.generic.
+// Forall has 2D iteration space (4, 8) with loop count 4.
+// The write_slice writes at [loop_id, 0] with size [1, 4] into the 4x4 tile.
+// The parallel_insert_slice inserts at [id0, id1] with size [4, 4].
+// Composed write should be at [loop_id + id0, id1] with size [1, 4] strides [1, 1].
+
+func.func @fold_forall_into_pcf_loop(%init: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %c4 = arith.constant 4 : index
+ %0 = scf.forall (%id0, %id1) in (4, 8) shared_outs(%iter = %init) -> (tensor<16x32xf32>) {
+ %tile_init = tensor.extract_slice %iter[%id0, %id1] [4, 4] [1, 1]
+ : tensor<16x32xf32> to tensor<4x4xf32>
+ %loop_result = pcf.loop scope(#pcf.sequential) count(%c4)
+ execute(%ref = %tile_init)[%loop_id: index]
+ : (!pcf.sref<4x4xf32, sync(#pcf.sequential)>)
+ -> (tensor<4x4xf32>) {
+ %slice = tensor.extract_slice %init[%id0, %loop_id] [1, 4] [1, 1]
+ : tensor<16x32xf32> to tensor<1x4xf32>
+ pcf.write_slice %slice into %ref[%loop_id, 0] [1, 4] [1, 1]
+ : tensor<1x4xf32> into !pcf.sref<4x4xf32, sync(#pcf.sequential)>
+ pcf.return
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %loop_result into %iter[%id0, %id1] [4, 4] [1, 1]
+ : tensor<4x4xf32> into tensor<16x32xf32>
+ }
+ } {mapping = [#iree_codegen.local_mapping<0>, #iree_codegen.local_mapping<1>]}
+ return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @fold_forall_into_pcf_loop
+// CHECK-SAME: %[[INIT:[A-Za-z0-9_]+]]: tensor<16x32xf32>
+
+// CHECK: %[[GENERIC:.+]] = pcf.generic
+// CHECK: scope(#pcf.sequential)
+// CHECK: execute(%[[REF:[A-Za-z0-9_]+]] = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %[[GEN_COUNT:[A-Za-z0-9_]+]]: index]
+// CHECK: : (!pcf.sref<16x32xf32, sync(#pcf.sequential)>)
+// CHECK: -> (tensor<16x32xf32>) {
+
+// Delinearize generic id into (forall_dim0, forall_dim1, pcf_loop_id) with
+// full basis [4, 8, loop_count].
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[GEN_ID]] into
+// CHECK-SAME: : index, index, index
+
+// Linearize forall dim indices back to forall_linear_id (cancels with delinearize).
+// CHECK: %[[FORALL_LIN:.+]] = affine.linearize_index disjoint [%[[DELIN]]#0, %[[DELIN]]#1] by (4, 8)
+
+// Outer scf.forall starts at forall_linear_id, upper bound = 4*8 = 32.
+// CHECK: scf.forall (%[[OUTER_IV:.+]]) = (%[[FORALL_LIN]]) to
+// CHECK-SAME: {
+
+// Delinearize outer IV into 2D forall space (4, 8).
+// CHECK: %[[FORALL_DELIN:.+]]:2 = affine.delinearize_index %[[OUTER_IV]] into
+// CHECK-SAME: : index, index
+
+// Inner scf.forall starts at pcf_loop_id, upper bound = loop count.
+// CHECK: scf.forall (%[[INNER_IV:.+]]) = (%[[DELIN]]#2) to
+// CHECK-SAME: {
+
+// Composed write: sizes [1, 4] and strides [1, 1] from write_slice.
+// Offset dim 0 = insertOff(id0) + writeOff(loop_id) * insertStride(1).
+// Offset dim 1 = insertOff(id1) + writeOff(0) * insertStride(1) = id1.
+// CHECK: %[[COMPOSED_OFF_0:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[FORALL_DELIN]]#0, %[[INNER_IV]]]
+// CHECK: pcf.write_slice %{{.+}} into %[[REF]][%[[COMPOSED_OFF_0]], %[[FORALL_DELIN]]#1]
+// CHECK-SAME: [1, 4] [1, 1]
+// CHECK-SAME: into !pcf.sref<16x32xf32, sync(#pcf.sequential)>
+// CHECK: pcf.return
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+// Test with multiple results from pcf.loop.
+// Both results get composed writes targeting different ref args.
+
+func.func @fold_forall_multiple_results(%init0: tensor<16xf32>, %init1: tensor<16xf32>)
+ -> (tensor<16xf32>, tensor<16xf32>) {
+ %c2 = arith.constant 2 : index
+ %0:2 = scf.forall (%id) in (4) shared_outs(%iter0 = %init0, %iter1 = %init1)
+ -> (tensor<16xf32>, tensor<16xf32>) {
+ %tile_init0 = tensor.extract_slice %iter0[%id] [4] [1]
+ : tensor<16xf32> to tensor<4xf32>
+ %tile_init1 = tensor.extract_slice %iter1[%id] [4] [1]
+ : tensor<16xf32> to tensor<4xf32>
+ %loop_result:2 = pcf.loop scope(#pcf.sequential) count(%c2)
+ execute(%ref0 = %tile_init0, %ref1 = %tile_init1)[%loop_id: index]
+ : (!pcf.sref<4xf32, sync(#pcf.sequential)>,
+ !pcf.sref<4xf32, sync(#pcf.sequential)>)
+ -> (tensor<4xf32>, tensor<4xf32>) {
+ %slice0 = tensor.extract_slice %init0[%loop_id] [2] [1]
+ : tensor<16xf32> to tensor<2xf32>
+ %slice1 = tensor.extract_slice %init1[%loop_id] [2] [1]
+ : tensor<16xf32> to tensor<2xf32>
+ pcf.write_slice %slice0 into %ref0[%loop_id] [2] [1]
+ : tensor<2xf32> into !pcf.sref<4xf32, sync(#pcf.sequential)>
+ pcf.write_slice %slice1 into %ref1[%loop_id] [2] [1]
+ : tensor<2xf32> into !pcf.sref<4xf32, sync(#pcf.sequential)>
+ pcf.return
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %loop_result#0 into %iter0[%id] [4] [1]
+ : tensor<4xf32> into tensor<16xf32>
+ tensor.parallel_insert_slice %loop_result#1 into %iter1[%id] [4] [1]
+ : tensor<4xf32> into tensor<16xf32>
+ }
+ } {mapping = [#iree_codegen.local_mapping<0>]}
+ return %0#0, %0#1 : tensor<16xf32>, tensor<16xf32>
+}
+
+// CHECK-LABEL: @fold_forall_multiple_results
+// CHECK-SAME: %[[INIT0:[A-Za-z0-9_]+]]: tensor<16xf32>
+// CHECK-SAME: %[[INIT1:[A-Za-z0-9_]+]]: tensor<16xf32>
+
+// CHECK: %[[GENERIC:.+]]:2 = pcf.generic
+// CHECK: scope(#pcf.sequential)
+// CHECK: execute(%[[REF0:[A-Za-z0-9_]+]] = %[[INIT0]], %[[REF1:[A-Za-z0-9_]+]] = %[[INIT1]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %{{.*}}: index]
+// CHECK: : (!pcf.sref<16xf32, sync(#pcf.sequential)>, !pcf.sref<16xf32, sync(#pcf.sequential)>)
+// CHECK: -> (tensor<16xf32>, tensor<16xf32>) {
+
+// Delinearize generic id into (forall_dim0, pcf_loop_id).
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into
+// CHECK-SAME: : index, index
+
+// Linearize forall dim index (1D, so linearize is trivial).
+// CHECK: %[[FORALL_LIN:.+]] = affine.linearize_index disjoint [%[[DELIN]]#0] by (4)
+
+// Outer scf.forall from forall_linear_id.
+// CHECK: scf.forall (%{{.+}}) = (%[[FORALL_LIN]])
+
+// Inner scf.forall from pcf_loop_id.
+// CHECK: scf.forall (%{{.+}}) = (%[[DELIN]]#1)
+
+// Composed writes: write offset[loop_id] + insert offset[id].
+// Both writes have size 2 and stride 1, targeting different ref args.
+// CHECK: %[[COMPOSED_OFF_0:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.+}}, %{{.+}}]
+// CHECK: pcf.write_slice %{{.+}} into %[[REF0]][%[[COMPOSED_OFF_0]]] [2] [1]
+// CHECK-SAME: into !pcf.sref<16xf32, sync(#pcf.sequential)>
+// CHECK: %[[COMPOSED_OFF_1:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.+}}, %{{.+}}]
+// CHECK: pcf.write_slice %{{.+}} into %[[REF1]][%[[COMPOSED_OFF_1]]] [2] [1]
+// CHECK-SAME: into !pcf.sref<16xf32, sync(#pcf.sequential)>
+// CHECK: pcf.return
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Test write_slice + parallel_insert_slice composition with non-unit write
+// stride. The write has stride 2, the insert has stride 1.
+// Composed stride = 2 * 1 = 2. Write size = 2, stays 2.
+
+func.func @fold_compose_strides(%init: tensor<64xf32>) -> tensor<64xf32> {
+ %c3 = arith.constant 3 : index
+ %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<64xf32>) {
+ %tile_init = tensor.extract_slice %iter[%id] [16] [1]
+ : tensor<64xf32> to tensor<16xf32>
+ %loop_result = pcf.loop scope(#pcf.sequential) count(%c3)
+ execute(%ref = %tile_init)[%loop_id: index]
+ : (!pcf.sref<16xf32, sync(#pcf.sequential)>)
+ -> (tensor<16xf32>) {
+ %val = tensor.extract_slice %init[0] [2] [1]
+ : tensor<64xf32> to tensor<2xf32>
+ pcf.write_slice %val into %ref[1] [2] [2]
+ : tensor<2xf32> into !pcf.sref<16xf32, sync(#pcf.sequential)>
+ pcf.return
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %loop_result into %iter[%id] [16] [1]
+ : tensor<16xf32> into tensor<64xf32>
+ }
+ } {mapping = [#iree_codegen.local_mapping<0>]}
+ return %0 : tensor<64xf32>
+}
+
+// Composed offset = insert(id) + write(1) * insert_stride(1) = id + 1.
+// Size = 2 (from write). Stride = 2 * 1 = 2.
+// CHECK-LABEL: @fold_compose_strides
+// CHECK-SAME: %[[INIT:[A-Za-z0-9_]+]]: tensor<64xf32>
+// CHECK: %[[GENERIC:.+]] = pcf.generic
+// CHECK: execute(%[[REF:[A-Za-z0-9_]+]] = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index
+// CHECK: : (!pcf.sref<64xf32, sync(#pcf.sequential)>)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into
+// CHECK: %[[FORALL_LIN:.+]] = affine.linearize_index disjoint
+// CHECK: %[[OUTER_STEP:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 3)>()[%{{.+}}]
+// CHECK: scf.forall (%[[OUTER_IV:.+]]) = (%[[FORALL_LIN]]) to (%{{.+}}) step (%[[OUTER_STEP]])
+// CHECK: %[[INNER_STEP:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 3)>()[%{{.+}}]
+// CHECK: scf.forall (%[[INNER_IV:.+]]) = (%[[DELIN]]#1) to (%{{.+}}) step (%[[INNER_STEP]])
+// Composed: size 2, stride 2 (write_stride=2 * insert_stride=1).
+// CHECK: %[[COMPOSED_OFF:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%[[OUTER_IV]]]
+// CHECK: pcf.write_slice %{{.+}} into %[[REF]][%[[COMPOSED_OFF]]] [2] [2]
+// CHECK-SAME: into !pcf.sref<64xf32, sync(#pcf.sequential)>
+// CHECK: pcf.return
+
+// -----
+
+// Test folding when the pcf.loop count is computed in the forall body and
+// needs to be hoisted before rewriting.
+func.func @fold_hoists_loop_count(%init: tensor<8xf32>, %one: index)
+ -> tensor<8xf32> {
+ %0 = scf.forall (%id) in (4) shared_outs(%iter = %init) -> (tensor<8xf32>) {
+ %tile_init = tensor.extract_slice %iter[%id] [2] [1]
+ : tensor<8xf32> to tensor<2xf32>
+ %count = arith.addi %one, %one : index
+ %loop_result = pcf.loop scope(#pcf.sequential) count(%count)
+ execute(%ref = %tile_init)[%loop_id: index]
+ : (!pcf.sref<2xf32, sync(#pcf.sequential)>)
+ -> (tensor<2xf32>) {
+ %slice = tensor.extract_slice %init[%loop_id] [1] [1]
+ : tensor<8xf32> to tensor<1xf32>
+ pcf.write_slice %slice into %ref[%loop_id] [1] [1]
+ : tensor<1xf32> into !pcf.sref<2xf32, sync(#pcf.sequential)>
+ pcf.return
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %loop_result into %iter[%id] [2] [1]
+ : tensor<2xf32> into tensor<8xf32>
+ }
+ } {mapping = [#iree_codegen.local_mapping<0>]}
+ return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: @fold_hoists_loop_count
+// CHECK-SAME: %[[INIT:[A-Za-z0-9_]+]]: tensor<8xf32>, %[[ONE:[A-Za-z0-9_]+]]: index
+// CHECK: %[[COUNT:[A-Za-z0-9_]+]] = arith.addi %[[ONE]], %[[ONE]] : index
+// CHECK: %[[GENERIC:.+]] = pcf.generic
+// CHECK: execute(%{{.+}} = %[[INIT]])[%[[GEN_ID:[A-Za-z0-9_]+]]: index, %[[GEN_COUNT:[A-Za-z0-9_]+]]: index]
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[GEN_ID]] into (4, %[[COUNT]])
+// CHECK: %[[OUTER_STEP:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%[[GEN_COUNT]], %[[COUNT]]]
+// CHECK: scf.forall (%{{.+}}) = (%{{.+}}) to (%{{.+}}) step (%[[OUTER_STEP]])
+// CHECK: %[[INNER_STEP:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 ceildiv s1)>()[%[[GEN_COUNT]], %[[COUNT]]]
+// CHECK: scf.forall (%{{.+}}) = (%[[DELIN]]#1) to (%[[COUNT]]) step (%[[INNER_STEP]])
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir
index 55fcb90..5190118 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/fuse_pcf_writes.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-fuse-pcf-writes)" --split-input-file | FileCheck %s
+// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-fuse-pcf-writes)" --mlir-print-local-scope --split-input-file | FileCheck %s
func.func @fuse_write_slice_with_parallel_insert(%init: tensor<32x64xf32>, %dest: !pcf.sref<32x64xf32, sync(#pcf.sequential)>) {
%result = scf.forall (%i, %j) in (4, 8) shared_outs(%iter = %init) -> tensor<32x64xf32> {
@@ -48,10 +48,9 @@
// CHECK-SAME: %[[INIT:[A-Za-z0-9_]+]]: tensor<32x64xf32>
// CHECK-SAME: %[[DEST:[A-Za-z0-9_]+]]: !pcf.sref<32x64xf32, sync(#pcf.sequential)>
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK: scf.forall (%[[I:.+]], %[[J:.+]]) in (4, 8) {
// CHECK: %[[TILE:.+]] = tensor.generate
-// CHECK: %[[COMPOSED_OFFSET:.+]] = arith.addi %[[I]], %[[C16]]
+// CHECK: %[[COMPOSED_OFFSET:.+]] = affine.apply affine_map<()[s0] -> (s0 + 16)>()[%[[I]]]
// CHECK: pcf.write_slice %[[TILE]] into %[[DEST]][%[[COMPOSED_OFFSET]], %[[J]]] [8, 8] [1, 1]
// CHECK: }
// CHECK-NOT: pcf.write_slice
@@ -78,11 +77,10 @@
// CHECK-SAME: %[[INIT:[A-Za-z0-9_]+]]: tensor<32x64xf32>
// CHECK-SAME: %[[DEST:[A-Za-z0-9_]+]]: !pcf.sref<64x128xf32, sync(#pcf.sequential)>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: scf.forall (%[[I:.+]], %[[J:.+]]) in (4, 8) {
// CHECK: %[[TILE:.+]] = tensor.generate
-// CHECK-DAG: %[[OFFSET_0:.+]] = arith.muli %[[I]], %[[C2]]
-// CHECK-DAG: %[[OFFSET_1:.+]] = arith.muli %[[J]], %[[C2]]
+// CHECK-DAG: %[[OFFSET_0:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[I]]]
+// CHECK-DAG: %[[OFFSET_1:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[J]]]
// CHECK: pcf.write_slice %[[TILE]] into %[[DEST]][%[[OFFSET_0]], %[[OFFSET_1]]] [8, 8] [2, 2]
// CHECK: }
// CHECK-NOT: pcf.write_slice
@@ -140,10 +138,10 @@
// CHECK-SAME: %[[OFFSET_BASE:[A-Za-z0-9_]+]]: index
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[OFFSET:.+]] = arith.addi %[[OFFSET_BASE]], %[[C16]]
+// CHECK-DAG: %[[OFFSET:.+]] = arith.addi %[[OFFSET_BASE]], %[[C16]] : index
// CHECK: scf.forall (%[[I:.+]], %[[J:.+]]) in (4, 8) {
// CHECK: %[[TILE:.+]] = tensor.generate
-// CHECK: %[[COMPOSED_OFFSET:.+]] = arith.addi %[[OFFSET]], %[[I]]
+// CHECK: %[[COMPOSED_OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET]], %[[I]]]
// CHECK: pcf.write_slice %[[TILE]] into %[[DEST]][%[[COMPOSED_OFFSET]], %[[J]]] [8, 8] [1, 1]
// CHECK: }
// CHECK-NOT: pcf.write_slice