[Flow] Raise special `linalg.generic` ops to `linalg.fill` ops (#14773)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index 6fcddd6..8d4c4d9 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -76,6 +76,34 @@
return std::nullopt;
}
+/// Matches a linalg.generic op that represents a linalg.fill op. Returns
+/// the fill value (the input operand to linalg.fill) on success.
+std::optional<Value> matchGenericFill(linalg::LinalgOp linalgOp) {
+ if (isa<linalg::GenericOp>(linalgOp.getOperation()) &&
+ linalgOp.getNumDpsInputs() == 0 && linalgOp.getNumDpsInits() == 1 &&
+ linalgOp.getNumParallelLoops() == linalgOp.getNumLoops() &&
+ linalgOp.getIndexingMapsArray()[0].isIdentity()) {
+ // Check that the op body is only a linalg.yield op.
+ Value yieldOperand;
+ for (Operation &bodyOp : linalgOp.getBlock()->getOperations()) {
+ if (isa<linalg::YieldOp>(bodyOp)) {
+ yieldOperand = bodyOp.getOperand(0);
+ } else {
+ return std::nullopt;
+ }
+ }
+    // Check that the operand of the linalg.yield op is not an argument of
+    // the linalg.generic basic block.
+ for (Value blockArg : linalgOp.getBlock()->getArguments()) {
+ if (yieldOperand == blockArg) {
+ return std::nullopt;
+ }
+ }
+ return yieldOperand;
+ }
+ return std::nullopt;
+}
+
/// Matches a linalg.generic operation reading data from a tensor `source` using
/// tensor.extract, and raises the `source` tensor to an input of the linalg
/// operation.
@@ -333,6 +361,7 @@
SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
+ SmallVector<std::pair<linalg::GenericOp, Value>> genericFills;
getOperation()->walk([&](linalg::LinalgOp op) {
{
transform_ext::MatcherContext matcherContext;
@@ -347,6 +376,10 @@
transposeMatmulRoots.push_back(std::make_pair(
cast<linalg::MatmulOp>(op.getOperation()), newRhs.value()));
}
+ if (std::optional<Value> fillInput = matchGenericFill(op)) {
+ genericFills.push_back(
+ std::make_pair(cast<linalg::GenericOp>(op), fillInput.value()));
+ }
}
});
@@ -369,6 +402,15 @@
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
}
+ for (std::pair<linalg::GenericOp, Value> genericFill : genericFills) {
+ auto genericOp = genericFill.first;
+ Value fillInput = genericFill.second;
+ Value init = genericOp.getDpsInitOperand(0)->get();
+ rewriter.setInsertionPoint(genericOp);
+ SmallVector<NamedAttribute> attrs = getPrunedAttributeList(genericOp);
+ rewriter.replaceOpWithNewOp<linalg::FillOp>(
+ genericOp, ValueRange{fillInput}, ValueRange{init}, attrs);
+ }
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index 49389b6..54835d5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -187,6 +187,32 @@
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK: return %[[RESULT]]
+func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %0 = tensor.empty(%dim, %dim_0) : tensor<1x1x?x?xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ outs(%0 : tensor<1x1x?x?xf32>) {
+ ^bb0(%out: f32):
+ linalg.yield %cst : f32
+ } -> tensor<1x1x?x?xf32>
+ return %1 : tensor<1x1x?x?xf32>
+}
+// CHECK-LABEL: func @generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK-SAME: : tensor<1x1x?x?xf32>
+// CHECK: %[[RESULT:.+]] = linalg.fill
+// CHECK-SAME: ins(%[[CST]] : f32)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x?x?xf32>)
+// CHECK: return %[[RESULT]]
+
// -----
#map = affine_map<(d0) -> (d0)>