[Flow] Raise special `linalg.generic` ops to `linalg.fill` ops (#14773)

diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index 6fcddd6..8d4c4d9 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -76,6 +76,34 @@
   return std::nullopt;
 }
 
+// Recognizes a linalg.generic op that is equivalent to a linalg.fill op and,
+// on a match, returns the fill value (the would-be linalg.fill input operand).
+std::optional<Value> matchGenericFill(linalg::LinalgOp linalgOp) {
+  // Candidates: generic op, no inputs, one init, all-parallel, identity map.
+  if (!isa<linalg::GenericOp>(linalgOp.getOperation()) ||
+      linalgOp.getNumDpsInputs() != 0 || linalgOp.getNumDpsInits() != 1 ||
+      linalgOp.getNumParallelLoops() != linalgOp.getNumLoops() ||
+      !linalgOp.getIndexingMapsArray()[0].isIdentity()) {
+    return std::nullopt;
+  }
+  auto *body = linalgOp.getBlock();
+  // The body must consist solely of linalg.yield; remember what it yields.
+  Value filledValue;
+  for (Operation &op : body->getOperations()) {
+    if (!isa<linalg::YieldOp>(op)) {
+      return std::nullopt;
+    }
+    filledValue = op.getOperand(0);
+  }
+  // Yielding one of the body's own block arguments is not a fill.
+  for (Value arg : body->getArguments()) {
+    if (filledValue == arg) {
+      return std::nullopt;
+    }
+  }
+  return filledValue;
+}
+
 /// Matches a linalg.generic operation reading data from a tensor `source` using
 /// tensor.extract, and raises the `source` tensor to an input of the linalg
 /// operation.
@@ -333,6 +361,7 @@
 
     SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
     SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
+    SmallVector<std::pair<linalg::GenericOp, Value>> genericFills;
     getOperation()->walk([&](linalg::LinalgOp op) {
       {
         transform_ext::MatcherContext matcherContext;
@@ -347,6 +376,10 @@
           transposeMatmulRoots.push_back(std::make_pair(
               cast<linalg::MatmulOp>(op.getOperation()), newRhs.value()));
         }
+        if (std::optional<Value> fillInput = matchGenericFill(op)) {
+          genericFills.push_back(
+              std::make_pair(cast<linalg::GenericOp>(op), fillInput.value()));
+        }
       }
     });
 
@@ -369,6 +402,15 @@
       rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
           matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
     }
+    for (std::pair<linalg::GenericOp, Value> genericFill : genericFills) {
+      auto genericOp = genericFill.first;
+      Value fillInput = genericFill.second;
+      Value init = genericOp.getDpsInitOperand(0)->get();
+      rewriter.setInsertionPoint(genericOp);
+      SmallVector<NamedAttribute> attrs = getPrunedAttributeList(genericOp);
+      rewriter.replaceOpWithNewOp<linalg::FillOp>(
+          genericOp, ValueRange{fillInput}, ValueRange{init}, attrs);
+    }
   }
 };
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index 49389b6..54835d5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -187,6 +187,32 @@
 //  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 //       CHECK:   return %[[RESULT]]
 
+func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
+    %zero = arith.constant 0.000000e+00 : f32
+    %idx0 = arith.constant 0 : index
+    %idx1 = arith.constant 1 : index
+    %dim0 = tensor.dim %arg0, %idx0 : tensor<?x?xf32>
+    %dim1 = tensor.dim %arg0, %idx1 : tensor<?x?xf32>
+    %empty = tensor.empty(%dim0, %dim1) : tensor<1x1x?x?xf32>
+    %filled = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        outs(%empty : tensor<1x1x?x?xf32>) {
+      ^bb0(%unused: f32):
+        linalg.yield %zero : f32
+    } -> tensor<1x1x?x?xf32>
+    return %filled : tensor<1x1x?x?xf32>
+}
+// CHECK-LABEL: func @generic_fill
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty
+//  CHECK-SAME:   : tensor<1x1x?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = linalg.fill
+//  CHECK-SAME:       ins(%[[CST]] : f32)
+//  CHECK-SAME:       outs(%[[EMPTY]] : tensor<1x1x?x?xf32>)
+//       CHECK:   return %[[RESULT]]
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>