[LinalgExt] Retire RewriteForallToScfForOp transform op. (#16064)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
index b9d52cf..9fd2a32 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -50,44 +50,6 @@
}];
}
-def RewriteForallToScfForOp :
- Op<Transform_Dialect, "forall_to_scf_for",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
-
- let description = [{
- Rewrite a bufferized scf.forall to a sequential scf.for.
-
- Return modes:
- =============
- This operation ignores non-Linalg ops and drops them in the return.
- This transform is currently only implemented for 1-D scf.forall that
- have been bufferized and definitely fail for the rest.
-
- If all the operations referred to by the `target` operand lower
- properly, the transform succeeds. Otherwise the transform silently fails.
-
- The returned handle points to only the subset of successfully produced
- scf.for operations, which can all be empty.
- }];
- let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs TransformHandleTypeInterface:$transformed);
-
- let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
- let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::scf::ForallOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
def TileAttentionOp : Op<Transform_Dialect, "tile_attention",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 7bc5970..3ff44a5 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -40,20 +40,6 @@
}
};
-/// Pattern to rewrite a ForallOp to an scf::ForOp.
-struct ForallOpToScfForRewriter : public OpRewritePattern<scf::ForallOp> {
- using OpRewritePattern::OpRewritePattern;
-
- FailureOr<scf::ForOp>
- returningMatchAndRewrite(scf::ForallOp forallOp,
- PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(scf::ForallOp forallOp,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(forallOp, rewriter);
- }
-};
-
//===----------------------------------------------------------------------===//
// Transformations exposed as patterns, moved from upstream MLIR as IREE still
// heavily relies on patterns that compose through filters.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index f9e5215..b97d0b0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -45,20 +45,6 @@
return DiagnosedSilenceableFailure::success();
}
-DiagnosedSilenceableFailure LinalgExt::RewriteForallToScfForOp::applyToOne(
- transform::TransformRewriter &rewriter, scf::ForallOp target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- LinalgExt::ForallOpToScfForRewriter pattern(this->getContext());
- SimplePatternRewriter patternRewriter(target);
- FailureOr<Operation *> result =
- pattern.returningMatchAndRewrite(target, patternRewriter);
- if (failed(result))
- return emitDefaultDefiniteFailure(target);
- results.push_back(*result);
- return DiagnosedSilenceableFailure::success();
-}
-
//===---------------------------------------------------------------------===//
// TileAndDecomposeAttention
//===---------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
index e75f0c1..73c75d8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_library(IREELinalgExtTransforms
ForeachThreadToAsync.cpp
- ForeachThreadToSequentialFor.cpp
Utils.cpp
PARTIAL_SOURCES_INTENDED
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
deleted file mode 100644
index e02fc31..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-
-using namespace mlir;
-using namespace mlir::iree_compiler::IREE::LinalgExt;
-
-namespace {
-
-SmallVector<Value> getValuesToYield(scf::InParallelOp op) {
- return llvm::map_to_vector(op.getYieldingOps(), [](Operation &op) -> Value {
- return cast<tensor::ParallelInsertSliceOp>(&op).getDest();
- });
-}
-
-} // namespace
-
-FailureOr<scf::ForOp> ForallOpToScfForRewriter::returningMatchAndRewrite(
- scf::ForallOp forallOp, PatternRewriter &rewriter) const {
- if (forallOp.getNumResults() > 0)
- return forallOp->emitError("only bufferized scf.forall lowers to scf.for");
-
- if (forallOp.getRank() > 1)
- return forallOp->emitError(
- "only single-dimension scf.forall lowers to scf.for");
-
- // Construct the loop bounds based on the canonical arithmetic progression.
- Location loc = forallOp.getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- // TODO: allow multi-dim.
- Value numThreads = forallOp.getUpperBound(rewriter).front();
-
- // Construct the op without a body builder: we need to clone the ops in the
- // body explicitly after having access to the new bbArgs.
- // As a consequence, `ensureTerminator` is not called and the `forOp` body
- // has no terminator.
- scf::InParallelOp InParallelOp = forallOp.getTerminator();
- SmallVector<Value> valuesToYield = getValuesToYield(InParallelOp);
- scf::ForOp forOp =
- rewriter.create<scf::ForOp>(loc, zero, numThreads, one, valuesToYield);
-
- // Move the body while replacing the threadId by the forOp iv.
- SmallVector<Value> bbArgsTranslated{forOp.getInductionVar()};
- Block *body = forOp.getBody();
- bool hasTerminator =
- !body->empty() && body->back().hasTrait<OpTrait::IsTerminator>();
- if (hasTerminator) {
- rewriter.inlineBlockBefore(&forallOp.getRegion().front(),
- body->getTerminator(), bbArgsTranslated);
- } else {
- rewriter.mergeBlocks(&forallOp.getRegion().front(), body, bbArgsTranslated);
- }
-
- rewriter.setInsertionPointToStart(body);
- IRMapping bvm;
- bvm.map(valuesToYield, forOp.getRegionIterArgs());
-
- // Create sequential insertSlice ops.
- SmallVector<Value> toYield;
- rewriter.setInsertionPoint(InParallelOp);
- for (Operation &operation : InParallelOp.getYieldingOps()) {
- tensor::ParallelInsertSliceOp op =
- cast<tensor::ParallelInsertSliceOp>(&operation);
- toYield.push_back(rewriter.createOrFold<tensor::InsertSliceOp>(
- loc, op.getSource(), bvm.lookup(op.getDest()), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides()));
- }
-
- // InParallelOp.yieldedValues come from above, not from bbArgs.
- // There is no rewriter method to make mergeBlocks update non-bbArgs.
- // Need to manually clone + bvm all uses that are now nested under forOp.
- // Warning: this replacement is currently optimistic and may change the
- // semantics as explained in the pass description in Passes.td.
- SmallVector<Operation *> opsToReplace;
- for (Value toReplace : valuesToYield) {
- for (OpOperand &u : toReplace.getUses()) {
- Operation *op = u.getOwner();
- if (!forOp->isProperAncestor(op))
- continue;
- opsToReplace.push_back(op);
- }
- }
- for (Operation *op : opsToReplace) {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(op);
- Operation *cloned = rewriter.clone(*op, bvm);
- rewriter.replaceOp(op, cloned->getResults());
- }
-
- // Insert terminator.
- if (!hasTerminator) {
- rewriter.setInsertionPointToEnd(body);
- rewriter.create<scf::YieldOp>(loc, toYield);
- }
-
- // Cleanup and replace.
- rewriter.eraseOp(InParallelOp);
- rewriter.replaceOp(forallOp, forOp.getResults());
-
- return forOp;
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
deleted file mode 100644
index 6384abc..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
+++ /dev/null
@@ -1,48 +0,0 @@
-// RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s
-
-#map0 = affine_map<(d0)[s0] -> (d0 ceildiv s0)>
-#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
-#map2 = affine_map<(d0, d1) -> (d0 - d1)>
-#map3 = affine_map<(d0, d1) -> (d0, d1)>
-#map4 = affine_map<(d0) -> (d0)>
-
-// CHECK-LABEL: func.func @static_tile_buffers
-// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
-// CHECK-SAME: %[[IN:[0-9a-z]+]]: memref<?xf32>
-// CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<?xf32>
-func.func @static_tile_buffers(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- %cst = arith.constant 4.200000e+01 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = memref.dim %arg2, %c0 : memref<?xf32>
- %1 = affine.apply #map0(%0)[%arg0]
-
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32>
- // CHECK: scf.for %[[IV:.*]] = {{.*}}
- scf.forall (%arg3) in (%1) shared_outs() -> () {
- %3 = affine.apply #map1(%arg3)[%arg0]
- %4 = affine.apply #map2(%0, %3)
- %5 = affine.min #map3(%4, %arg0)
-
- %6 = memref.subview %arg2[%3] [%5] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
- %7 = memref.subview %arg1[%3] [%5] [1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
-
- linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]}
- ins(%7 : memref<?xf32, strided<[?], offset:?>>) outs(%6 : memref<?xf32, strided<[?], offset:?>>) {
- ^bb0(%arg4: f32, %arg5: f32): // no predecessors
- %9 = arith.mulf %arg4, %cst : f32
- linalg.yield %9 : f32
- }
-
- // Nothing is yielded, skip the terminator.
- // CHECK-NOT: scf.yield
- }
- return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
- %0 = transform.structured.match ops{["scf.forall"]} in %module_op : (!transform.any_op) -> !transform.any_op
- %1 = forall_to_scf_for %0 : (!transform.any_op) -> !transform.any_op
-}