[NFC] Cleanups to flow op folders. (#18974)

Wraps the file-local folding helpers and rewrite patterns in anonymous
namespaces and moves TensorBitCastOp::fold, FoldSplatLoadIntoPrimitive, and
FoldSplatReshapeIntoSplat into the sections of the ops they fold or
canonicalize. FoldSplatReshapeIntoSplat is reworked to match on a
flow.tensor.reshape fed by a splat rather than walking the splat's single
use, which drops the hasOneUse check and the manual insertion-point
handling. Also fixes the CHECK-NOT/CHECK-NEXT lines in the
ElideRedundantTransfer test in tensor_folding.mlir and switches it to named
captures.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 855df65..3219c6e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -37,6 +37,8 @@
// Folding utilities
//===----------------------------------------------------------------------===//
+namespace {
+
// Erases an op if it has no uses.
// This is to support ops that are "pure" but can't be marked as such because
// the MLIR CSE pass would deduplicate them.
@@ -170,6 +172,8 @@
return newDims;
}
+} // namespace
+
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroups
//===----------------------------------------------------------------------===//
@@ -365,6 +369,8 @@
// flow.dispatch.workload.ordinal
//===----------------------------------------------------------------------===//
+namespace {
+
// Bubble up the ordinal ops so that all uses go through this operation.
struct BubbleUpOrdinalOp : public OpRewritePattern<DispatchWorkloadOrdinalOp> {
using OpRewritePattern::OpRewritePattern;
@@ -399,6 +405,8 @@
}
};
+} // namespace
+
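As a rough illustration of what the comment above describes (a sketch only, not part of the patch; the op spelling is assumed from the Flow dialect, and the actual pattern has additional restrictions not shown here):

```mlir
// Before: the workload value %arg0 is used both directly and via the ordinal op.
%ordinal = flow.dispatch.workload.ordinal %arg0, 0 : index
%sum = arith.addi %arg0, %arg0 : index

// After bubbling up: every other use of %arg0 is routed through the ordinal op.
%ordinal = flow.dispatch.workload.ordinal %arg0, 0 : index
%sum = arith.addi %ordinal, %ordinal : index
```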
/// Fold away the following sequence of `flow.dispatch.workload.ordinal`.
///
/// ```mlir
@@ -863,25 +871,6 @@
return {};
}
-//===----------------------------------------------------------------------===//
-// flow.tensor.bitcast
-//===----------------------------------------------------------------------===//
-
-OpFoldResult TensorBitCastOp::fold(FoldAdaptor operands) {
- auto sourceType = llvm::cast<ShapedType>(getSource().getType());
- auto resultType = llvm::cast<ShapedType>(getResult().getType());
- if (sourceType.getElementType() != resultType.getElementType()) {
- // Element type mismatch, this is a bitcast.
- return {};
- }
- if (compareShapesEqual(sourceType, getSourceDims(), resultType,
- getResultDims())) {
- // Shapes match and this is a no-op so just fold to the source.
- return getSource();
- }
- return {};
-}
-
namespace {
// Flatten a chain of reshapes or bitcasts (reshape/bitcast feeding into
@@ -930,48 +919,6 @@
}
};
-// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input
-// primitive value for the splat op.
-struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
- using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TensorLoadOp loadOp,
- PatternRewriter &rewriter) const override {
- auto sourceOp =
- dyn_cast_or_null<TensorSplatOp>(loadOp.getSource().getDefiningOp());
-
- if (!sourceOp)
- return failure();
-
- rewriter.replaceOp(loadOp, sourceOp.getValue());
- return success();
- }
-};
-
-struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorSplatOp> {
- using OpRewritePattern<TensorSplatOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TensorSplatOp splatOp,
- PatternRewriter &rewriter) const override {
- if (!splatOp.getResult().hasOneUse())
- return failure();
-
- auto reshapeOp = dyn_cast_or_null<TensorReshapeOp>(
- splatOp.getResult().use_begin()->getOwner());
- if (!reshapeOp)
- return failure();
-
- PatternRewriter::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(reshapeOp);
- rewriter.replaceOpWithNewOp<TensorSplatOp>(
- reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(),
- reshapeOp.getResultDims());
- rewriter.eraseOp(splatOp);
-
- return success();
- }
-};
-
struct ResolveShapedRank : public OpRewritePattern<tensor::RankOp> {
using OpRewritePattern<tensor::RankOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::RankOp op,
@@ -1032,6 +979,25 @@
results.insert<ResolveShapedDim>(context);
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.bitcast
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorBitCastOp::fold(FoldAdaptor operands) {
+ auto sourceType = llvm::cast<ShapedType>(getSource().getType());
+ auto resultType = llvm::cast<ShapedType>(getResult().getType());
+ if (sourceType.getElementType() != resultType.getElementType()) {
+ // Element type mismatch, this is a bitcast.
+ return {};
+ }
+ if (compareShapesEqual(sourceType, getSourceDims(), resultType,
+ getResultDims())) {
+ // Shapes match and this is a no-op so just fold to the source.
+ return getSource();
+ }
+ return {};
+}
+
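For reference, a minimal IR sketch of what this fold does (not part of the patch; the op syntax mirrors the tensor_folding.mlir tests below): when the element types and the shapes, including dynamic dims, match, the bitcast is a no-op and folds to its source.

```mlir
// Before: a bitcast whose source and result element types and shapes are identical.
util.func public @NoOpBitcast(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xf32> {
  %0 = flow.tensor.bitcast %arg0 : tensor<4x?xf32>{%dim} -> tensor<4x?xf32>{%dim}
  util.return %0 : tensor<4x?xf32>
}

// After folding: the no-op bitcast is replaced by its source value.
util.func public @NoOpBitcast(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xf32> {
  util.return %arg0 : tensor<4x?xf32>
}
```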
void TensorBitCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ReplaceOpIfTensorOperandZeroElements<TensorBitCastOp, 0>>(
@@ -1060,6 +1026,25 @@
return {};
}
+namespace {
+
+// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input
+// primitive value for the splat op.
+struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
+ using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(TensorLoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceOp =
+ dyn_cast_or_null<TensorSplatOp>(loadOp.getSource().getDefiningOp());
+ if (!sourceOp)
+ return failure();
+ rewriter.replaceOp(loadOp, sourceOp.getValue());
+ return success();
+ }
+};
+
+} // namespace
+
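An illustrative before/after for FoldSplatLoadIntoPrimitive (a sketch only; the flow.tensor.splat and flow.tensor.load spellings are assumed and may not be exact): loading any element of a splat yields the splatted primitive directly.

```mlir
// Before: a load from a splatted tensor.
util.func public @SplatLoad(%value: f32) -> f32 {
  %c0 = arith.constant 0 : index
  %splat = flow.tensor.splat %value : tensor<4xf32>
  %elem = flow.tensor.load %splat[%c0] : tensor<4xf32>
  util.return %elem : f32
}

// After canonicalization: the load is replaced by the primitive value
// (the now-unused splat can then be dropped separately).
util.func public @SplatLoad(%value: f32) -> f32 {
  util.return %value : f32
}
```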
void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<FoldSplatLoadIntoPrimitive>(context);
@@ -1116,6 +1101,25 @@
// flow.tensor.splat
//===----------------------------------------------------------------------===//
+namespace {
+
+struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorReshapeOp> {
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ auto splatOp = dyn_cast_if_present<TensorSplatOp>(
+ reshapeOp.getSource().getDefiningOp());
+ if (!splatOp)
+ return failure();
+ rewriter.replaceOpWithNewOp<TensorSplatOp>(
+ reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(),
+ reshapeOp.getResultDims());
+ return success();
+ }
+};
+
+} // namespace
+
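And an illustrative before/after for the reworked FoldSplatReshapeIntoSplat (again a sketch; the flow.tensor.reshape spelling is assumed): a reshape of a splat becomes a splat of the target shape.

```mlir
// Before: reshaping a splatted tensor.
util.func public @SplatReshape(%value: f32) -> tensor<2x8xf32> {
  %splat = flow.tensor.splat %value : tensor<16xf32>
  %reshape = flow.tensor.reshape %splat : tensor<16xf32> -> tensor<2x8xf32>
  util.return %reshape : tensor<2x8xf32>
}

// After canonicalization: the reshape folds into a splat of the result type.
util.func public @SplatReshape(%value: f32) -> tensor<2x8xf32> {
  %splat = flow.tensor.splat %value : tensor<2x8xf32>
  util.return %splat : tensor<2x8xf32>
}
```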
void TensorSplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): canonicalize splat+slice to smaller splat.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index ada2b37..6e92e0b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -411,14 +411,14 @@
// CHECK-LABEL: @ElideRedundantTransfer
// CHECK-SAME: (%[[OPERAND:.+]]: tensor<4x?xf32>, %[[DIM:.+]]: index)
-util.func public @ElideRedundantTransfer(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> {
- // CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %arg0
- %transfer = flow.tensor.transfer %arg0 : tensor<4x?xf32>{%dim} to "target"
+util.func public @ElideRedundantTransfer(%operand: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> {
+ // CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %[[OPERAND]]
+ %transfer = flow.tensor.transfer %operand : tensor<4x?xf32>{%dim} to "target"
// CHECK: %[[BITCAST:.+]] = flow.tensor.bitcast %[[TRANSFER]]
%bitcast = flow.tensor.bitcast %transfer : tensor<4x?xf32>{%dim} -> tensor<4x?xi32>{%dim}
- // CHECK-NOT: flow.transfer
+ // CHECK-NOT: flow.tensor.transfer
%redundant = flow.tensor.transfer %bitcast : tensor<4x?xi32>{%dim} to "target"
- // CHECK-NEXT: %[[BITCAST]]
+ // CHECK-NEXT: util.return %[[BITCAST]]
util.return %redundant : tensor<4x?xi32>
}