[NFC] Cleanups to flow op folders. (#18974)

* Wraps the file-local folding helpers and rewrite patterns in anonymous namespaces.
* Moves TensorBitCastOp::fold and the splat-related patterns next to the canonicalization-pattern registration of the ops they act on.
* Reworks FoldSplatReshapeIntoSplat to match on the flow.tensor.reshape op and look through its source, instead of walking the uses of flow.tensor.splat, which also drops the explicit single-use check.
* Fixes stale FileCheck lines in the ElideRedundantTransfer test (uses the %[[OPERAND]] capture, checks CHECK-NOT: flow.tensor.transfer and CHECK-NEXT: util.return %[[BITCAST]]).
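
For reference, the reworked FoldSplatReshapeIntoSplat pattern turns a splat feeding a reshape into a single splat of the reshaped type. A rough before/after sketch (value names and shapes here are illustrative, not taken from the tests):

```mlir
// Before: a splat whose result is consumed by a reshape.
%0 = flow.tensor.splat %value : tensor<4x8xf32>
%1 = flow.tensor.reshape %0 : tensor<4x8xf32> -> tensor<32xf32>

// After: the reshape is replaced by a splat of the reshaped type; the
// original splat becomes dead and is cleaned up if it has no other uses.
%1 = flow.tensor.splat %value : tensor<32xf32>
```
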
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 855df65..3219c6e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -37,6 +37,8 @@
 // Folding utilities
 //===----------------------------------------------------------------------===//
 
+namespace {
+
 // Erases an op if it has no uses.
 // This is to support ops that are "pure" but can't be marked as such because
 // the MLIR CSE pass would deduplicate them.
@@ -170,6 +172,8 @@
   return newDims;
 }
 
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // flow.dispatch.workgroups
 //===----------------------------------------------------------------------===//
@@ -365,6 +369,8 @@
 // flow.dispatch.workload.ordinal
 //===----------------------------------------------------------------------===//
 
+namespace {
+
 // Bubble up the ordinal ops so that all uses go through this operation.
 struct BubbleUpOrdinalOp : public OpRewritePattern<DispatchWorkloadOrdinalOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -399,6 +405,8 @@
   }
 };
 
+} // namespace
+
 /// Fold away following sequence of `flow.dispatch.workload.ordinal`.
 ///
 /// ```mlir
@@ -863,25 +871,6 @@
   return {};
 }
 
-//===----------------------------------------------------------------------===//
-// flow.tensor.bitcast
-//===----------------------------------------------------------------------===//
-
-OpFoldResult TensorBitCastOp::fold(FoldAdaptor operands) {
-  auto sourceType = llvm::cast<ShapedType>(getSource().getType());
-  auto resultType = llvm::cast<ShapedType>(getResult().getType());
-  if (sourceType.getElementType() != resultType.getElementType()) {
-    // Element type mismatch, this is a bitcast.
-    return {};
-  }
-  if (compareShapesEqual(sourceType, getSourceDims(), resultType,
-                         getResultDims())) {
-    // Shapes match and this is a no-op so just fold to the source.
-    return getSource();
-  }
-  return {};
-}
-
 namespace {
 
 // Flatten a chain of reshapes or bitcasts (reshape/bitcast feeding into
@@ -930,48 +919,6 @@
   }
 };
 
-// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input
-// primitive value for the splat op.
-struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
-  using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TensorLoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto sourceOp =
-        dyn_cast_or_null<TensorSplatOp>(loadOp.getSource().getDefiningOp());
-
-    if (!sourceOp)
-      return failure();
-
-    rewriter.replaceOp(loadOp, sourceOp.getValue());
-    return success();
-  }
-};
-
-struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorSplatOp> {
-  using OpRewritePattern<TensorSplatOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TensorSplatOp splatOp,
-                                PatternRewriter &rewriter) const override {
-    if (!splatOp.getResult().hasOneUse())
-      return failure();
-
-    auto reshapeOp = dyn_cast_or_null<TensorReshapeOp>(
-        splatOp.getResult().use_begin()->getOwner());
-    if (!reshapeOp)
-      return failure();
-
-    PatternRewriter::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(reshapeOp);
-    rewriter.replaceOpWithNewOp<TensorSplatOp>(
-        reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(),
-        reshapeOp.getResultDims());
-    rewriter.eraseOp(splatOp);
-
-    return success();
-  }
-};
-
 struct ResolveShapedRank : public OpRewritePattern<tensor::RankOp> {
   using OpRewritePattern<tensor::RankOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(tensor::RankOp op,
@@ -1032,6 +979,25 @@
   results.insert<ResolveShapedDim>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// flow.tensor.bitcast
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorBitCastOp::fold(FoldAdaptor operands) {
+  auto sourceType = llvm::cast<ShapedType>(getSource().getType());
+  auto resultType = llvm::cast<ShapedType>(getResult().getType());
+  if (sourceType.getElementType() != resultType.getElementType()) {
+    // Element type mismatch, this is a bitcast.
+    return {};
+  }
+  if (compareShapesEqual(sourceType, getSourceDims(), resultType,
+                         getResultDims())) {
+    // Shapes match and this is a no-op so just fold to the source.
+    return getSource();
+  }
+  return {};
+}
+
 void TensorBitCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.insert<ReplaceOpIfTensorOperandZeroElements<TensorBitCastOp, 0>>(
@@ -1060,6 +1026,25 @@
   return {};
 }
 
+namespace {
+
+// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input
+// primitive value for the splat op.
+struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
+  using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceOp =
+        dyn_cast_or_null<TensorSplatOp>(loadOp.getSource().getDefiningOp());
+    if (!sourceOp)
+      return failure();
+    rewriter.replaceOp(loadOp, sourceOp.getValue());
+    return success();
+  }
+};
+
+} // namespace
+
 void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
   results.insert<FoldSplatLoadIntoPrimitive>(context);
@@ -1116,6 +1101,25 @@
 // flow.tensor.splat
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto splatOp = dyn_cast_if_present<TensorSplatOp>(
+        reshapeOp.getSource().getDefiningOp());
+    if (!splatOp)
+      return failure();
+    rewriter.replaceOpWithNewOp<TensorSplatOp>(
+        reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(),
+        reshapeOp.getResultDims());
+    return success();
+  }
+};
+
+} // namespace
+
 void TensorSplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   // TODO(benvanik): canonicalize splat+slice to smaller splat.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index ada2b37..6e92e0b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -411,14 +411,14 @@
 
 // CHECK-LABEL: @ElideRedundantTransfer
 //  CHECK-SAME: (%[[OPERAND:.+]]: tensor<4x?xf32>, %[[DIM:.+]]: index)
-util.func public @ElideRedundantTransfer(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> {
-  // CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %arg0
-  %transfer = flow.tensor.transfer %arg0 : tensor<4x?xf32>{%dim} to "target"
+util.func public @ElideRedundantTransfer(%operand: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> {
+  // CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %[[OPERAND]]
+  %transfer = flow.tensor.transfer %operand : tensor<4x?xf32>{%dim} to "target"
   // CHECK: %[[BITCAST:.+]] = flow.tensor.bitcast %[[TRANSFER]]
   %bitcast = flow.tensor.bitcast %transfer : tensor<4x?xf32>{%dim} -> tensor<4x?xi32>{%dim}
-  // CHECK-NOT: flow.transfer
+  // CHECK-NOT: flow.tensor.transfer
   %redundant = flow.tensor.transfer %bitcast : tensor<4x?xi32>{%dim} to "target"
-  // CHECK-NEXT: %[[BITCAST]]
+  // CHECK-NEXT: util.return %[[BITCAST]]
   util.return %redundant : tensor<4x?xi32>
 }