[LinalgExt] Delete fuse_producers transform op. (#16044)
An equivalent op already exists upstream:
transform.structured.fuse_into_containing_op.
The corresponding upstream test can be found at
https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
index c6fdce9..b9d52cf 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -12,23 +12,6 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
-def FuseProducersOp : Op<Transform_Dialect, "fuse_producers",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>,
- ReportTrackingListenerFailuresOpTrait]> {
- let description = [{Fuses the producers for the operands to fuse.}];
-
- let arguments = (ins TransformHandleTypeInterface:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$operands_to_fuse);
- let results = (outs TransformHandleTypeInterface:$transformed,
- Variadic<TransformHandleTypeInterface>:$fused_ops);
-
- let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
- let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
-}
-
def RewriteForallToAsyncOp :
Op<Transform_Dialect, "forall_to_async",
[FunctionalStyleTransformOpTrait,
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index d835f73..7bc5970 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -54,31 +54,6 @@
}
};
-struct FusionResult {
- TilingInterface consumerOp;
- SmallVector<TilingInterface> fusedOps;
-};
-
-/// Pattern to fuse the producers of a tileable op.
-struct LinalgExtFusionPattern
- : public OpInterfaceRewritePattern<TilingInterface> {
- LinalgExtFusionPattern(MLIRContext *context, ArrayRef<int64_t> operandsToFuse)
- : OpInterfaceRewritePattern<TilingInterface>(context),
- operandsToFuse(operandsToFuse.begin(), operandsToFuse.end()) {}
-
- FailureOr<FusionResult>
- returningMatchAndRewrite(TilingInterface consumerOp,
- PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(TilingInterface consumerOp,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(consumerOp, rewriter);
- }
-
-private:
- SmallVector<int64_t> operandsToFuse;
-};
-
//===----------------------------------------------------------------------===//
// Transformations exposed as patterns, moved from upstream MLIR as IREE still
// heavily relies on patterns that compose through filters.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index fa18c40..f9e5215 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -31,104 +31,6 @@
// Utility functions
//===---------------------------------------------------------------------===//
-/// Extracts a vector of int64_t from an array attribute. Asserts if the
-/// attribute contains values other than integers.
-static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
- SmallVector<int64_t> result;
- result.reserve(attr.size());
- for (APInt value : attr.getAsValueRange<IntegerAttr>())
- result.push_back(value.getSExtValue());
- return result;
-}
-
-//===---------------------------------------------------------------------===//
-// FuseProducersOp
-//===---------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-LinalgExt::FuseProducersOp::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &transformResults,
- transform::TransformState &state) {
- SmallVector<int64_t> operandsToFuse = extractI64Array(getOperandsToFuse());
- LinalgExt::LinalgExtFusionPattern pattern(getContext(), operandsToFuse);
- size_t numProducers = operandsToFuse.size();
-
- SmallVector<Operation *> transformedOps;
- SmallVector<SmallVector<Operation *>> fusedOps(numProducers);
- for (Operation *target : state.getPayloadOps(getTarget())) {
- // Apply the pattern.
- SimplePatternRewriter patternRewriter(target);
- FailureOr<LinalgExt::FusionResult> result =
- pattern.returningMatchAndRewrite(cast<TilingInterface>(target),
- patternRewriter);
- if (failed(result))
- return emitDefaultDefiniteFailure(target);
-
- // Update the fused operations.
- transformedOps.push_back(result->consumerOp);
- for (size_t i = 0; i < numProducers; ++i)
- fusedOps[i].push_back(result->fusedOps[i]);
- }
-
- transformResults.set(getTransformed().cast<OpResult>(), transformedOps);
- for (size_t i = 0; i < numProducers; ++i)
- transformResults.set(getFusedOps()[i], fusedOps[i]);
- return DiagnosedSilenceableFailure::success();
-}
-
-LogicalResult LinalgExt::FuseProducersOp::verify() {
- SmallVector<int64_t> operandsToFuse = extractI64Array(getOperandsToFuse());
- llvm::SmallDenseSet<int64_t> operandsSet;
- for (int64_t operandToFuse : operandsToFuse) {
- if (operandToFuse < 0) {
- return emitOpError() << "expects positive operand numbers, found "
- << operandToFuse;
- }
- if (operandsSet.count(operandToFuse) != 0) {
- return emitOpError() << "expects unique operand numbers, found "
- << operandToFuse << " multiple times";
- }
- operandsSet.insert(operandToFuse);
- }
- return success();
-}
-
-ParseResult LinalgExt::FuseProducersOp::parse(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::UnresolvedOperand targetOperand;
- SMLoc opLoc;
- if (parser.getCurrentLocation(&opLoc))
- return failure();
- if (parser.parseOperand(targetOperand))
- return parser.emitError(opLoc, "expected `target` operand");
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
- StringRef operandsToFuseAttrName("operands_to_fuse");
- Attribute operandsToFuseAttr = result.attributes.get(operandsToFuseAttrName);
- if (!operandsToFuseAttr) {
- return parser.emitError(opLoc, llvm::formatv("expected `{0}` attribute",
- operandsToFuseAttrName));
- }
- auto operandsToFuseArrayAttr = operandsToFuseAttr.dyn_cast<ArrayAttr>();
- if (!operandsToFuseArrayAttr) {
- return parser.emitError(opLoc,
- llvm::formatv("`{0}` attribute must be an array",
- operandsToFuseAttrName));
- }
- Type anyOpType = transform::AnyOpType::get(parser.getBuilder().getContext());
- size_t numProducers = operandsToFuseArrayAttr.size();
- result.addTypes(SmallVector<Type>(numProducers + 1, anyOpType));
- if (parser.resolveOperand(targetOperand, anyOpType, result.operands))
- return failure();
- return success();
-}
-
-void LinalgExt::FuseProducersOp::print(OpAsmPrinter &p) {
- p << ' ';
- p << getTarget();
- p.printOptionalAttrDict((*this)->getAttrs());
-}
-
DiagnosedSilenceableFailure LinalgExt::RewriteForallToAsyncOp::applyToOne(
transform::TransformRewriter &rewriter, scf::ForallOp target,
transform::ApplyToEachResultList &results,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
index a8f81ce..e75f0c1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_library(IREELinalgExtTransforms
ForeachThreadToAsync.cpp
ForeachThreadToSequentialFor.cpp
- Fusion.cpp
Utils.cpp
PARTIAL_SOURCES_INTENDED
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
deleted file mode 100644
index a73a90b..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Fusion.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/TilingInterface.h"
-
-using namespace mlir;
-using namespace mlir::iree_compiler::IREE::LinalgExt;
-
-FailureOr<FusionResult> LinalgExtFusionPattern::returningMatchAndRewrite(
- TilingInterface consumerOp, PatternRewriter &rewriter) const {
- // Try to fuse the producers of all operands to fuse.
- SmallVector<TilingInterface> fusedOps;
- for (int64_t operandToFuse : operandsToFuse) {
- // Check the operand exists.
- if (operandToFuse >= consumerOp->getNumOperands())
- return failure();
-
- // Check the operand is a slice of a producer result.
- auto sliceOp = consumerOp->getOperand(operandToFuse)
- .getDefiningOp<tensor::ExtractSliceOp>();
- if (!sliceOp)
- return failure();
- auto producerOp = sliceOp.getSource().getDefiningOp<TilingInterface>();
- if (!producerOp || producerOp->getNumResults() != 1)
- return failure();
-
- // Tile the producer.
- FailureOr<TilingResult> tileAndFuseResult =
- producerOp.generateResultTileValue(rewriter, /*resultNumber=*/0,
- sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes());
- if (failed(tileAndFuseResult))
- return failure();
- for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
- auto interfaceOp = dyn_cast<TilingInterface>(tileAndFusedOp);
- if (!interfaceOp)
- continue;
- fusedOps.push_back(interfaceOp);
- }
- }
-
- // Update the consumer in-place using the tiled producer results.
- SmallVector<Value> newOperands = consumerOp->getOperands();
- for (auto it : llvm::zip(operandsToFuse, fusedOps)) {
- int64_t operandToFuse = std::get<0>(it);
- TilingInterface fusedOp = std::get<1>(it);
- newOperands[operandToFuse] = fusedOp->getResult(0);
- }
- rewriter.updateRootInPlace(consumerOp,
- [&]() { consumerOp->setOperands(newOperands); });
-
- return FusionResult{consumerOp, fusedOps};
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
deleted file mode 100644
index 6e7b6bc..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
+++ /dev/null
@@ -1,140 +0,0 @@
-// RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s
-// TODO(#11765): Fix and re-enable this.
-// REQUIRES: dont-run
-
-#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
-#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
-#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
-
-module {
- // CHECK-LABEL: func.func @fuse_static
- // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
- // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
- // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
- func.func @fuse_static(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
- %cst = arith.constant 4.200000e+01 : f32
- %cst2 = arith.constant 4.300000e+01 : f32
- %0 = linalg.generic
- {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
- outs(%arg1 : tensor<64xf32>) {
- ^bb0(%arg3: f32):
- linalg.yield %cst : f32
- } -> tensor<64xf32>
- %1 = linalg.generic
- {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
- outs(%arg2 : tensor<64xf32>) {
- ^bb0(%arg3: f32):
- linalg.yield %cst : f32
- } -> tensor<64xf32>
-
- %2 = affine.apply #map0()[%arg0]
- // CHECK: scf.forall
- %3 = scf.forall (%arg3) in (%2) shared_outs(%O = %arg2) -> (tensor<64xf32>) {
- // CHECK: %[[OFFSET:.*]] = affine.apply
- // CHECK: %[[SIZE:.*]] = affine.min
- %4 = affine.apply #map1(%arg3)[%arg0]
- %5 = affine.min #map2(%arg3)[%arg0]
- %6 = tensor.extract_slice %0[%4] [%5] [1] : tensor<64xf32> to tensor<?xf32>
-
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%[[OFFSET]]] [%[[SIZE]]] [{{.*}}]
- // CHECK: %[[T1:.*]] = linalg.generic {{.*}} outs(%[[T0]]
- // CHECK: %[[T2:.*]] = tensor.extract_slice %[[OUT]][%[[OFFSET]]] [%[[SIZE]]] [{{.*}}]
- // CHECK: %[[T3:.*]] = linalg.generic {{.*}} outs(%[[T2]]
- %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<64xf32> to tensor<?xf32>
-
- // CHECK: %[[T4:.*]] = linalg.elemwise_unary ins(%[[T1]] {{.*}} outs(%[[T3]]
- %8 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%7 : tensor<?xf32>) -> tensor<?xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %8 into %O[%4] [%5] [1] : tensor<?xf32> into tensor<64xf32>
- }
- }
- func.return %3 : tensor<64xf32>
- }
-
- transform.with_pdl_patterns {
- ^bb0(%arg0: !pdl.operation):
- pdl.pattern @match_elemwise : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "transform.dialect"
- }
- pdl.pattern @match_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "scf.forall"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "transform.dialect"
- }
- transform.sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @match_elemwise in %arg1 : (!pdl.operation) -> !pdl.operation
- %1, %fusedOps:2 = fuse_producers %0 {operands_to_fuse=[0, 1]}
- }
- }
-}
-
-// -----
-
-#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
-#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
-#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
-
-module {
- // CHECK-LABEL: func.func @fuse_dynamic
- // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
- // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
- // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
- func.func @fuse_dynamic(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
- %cst = arith.constant 4.200000e+01 : f32
- %c0 = arith.constant 0 : index
- %0 = linalg.generic
- {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
- outs(%arg1 : tensor<?xf32>) {
- ^bb0(%arg3: f32):
- linalg.yield %cst : f32
- } -> tensor<?xf32>
- // TODO: Choosing %arg2 here complicates the size computation.
- %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
- %1 = affine.apply #map0()[%d0, %arg0]
- // CHECK: scf.forall
- %2 = scf.forall (%arg3) in (%1) shared_outs(%O = %arg2) -> (tensor<?xf32>) {
- // CHECK: %[[OFFSET:.*]] = affine.apply
- // CHECK: %[[SIZE:.*]] = affine.min
- %3 = affine.apply #map1(%arg3)[%arg0]
- %4 = affine.min #map2(%arg3)[%d0, %arg0]
- %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
-
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%[[OFFSET]]] [%[[SIZE]]] [{{.*}}]
- // CHECK: %[[T1:.*]] = linalg.generic {{.*}} outs(%[[T0]]
- %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
-
- // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %7 into %O[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
- }
- }
- func.return %2 : tensor<?xf32>
- }
-
- transform.with_pdl_patterns {
- ^bb0(%arg0: !pdl.operation):
- pdl.pattern @match_elemwise : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "transform.dialect"
- }
- pdl.pattern @match_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "scf.forall"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "transform.dialect"
- }
- transform.sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @match_elemwise in %arg1 : (!pdl.operation) -> !pdl.operation
- %1, %fusedOps = fuse_producers %0 {operands_to_fuse=[0]}
- }
- }
-}