Make reduction/matmul/conv matchers optionally partial (#13981)

When matching reductions, matmuls or convolutions, make the optional
constraint that all tileable operations are matched. This is useful to
enable matching prior to dispatch region formation and will remove code
duplication in the nvgpu plugin.

This should be a noop for the existing flows.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 0fc3d0b..692e632 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -53,6 +53,7 @@
             "tile_and_distribute_to_workgroups.mlir",
             "transform_buffer_opt.mlir",
             "transform_dialect_apply_pattern_op.mlir",
+            "transform_match_partial_reduction.mlir",
             "transform_ops_invalid.mlir",
             "transpose_canonicalization.mlir",
             "type_propagation.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index e814657..87e8c56 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -49,6 +49,7 @@
     "tile_and_distribute_to_workgroups.mlir"
     "transform_buffer_opt.mlir"
     "transform_dialect_apply_pattern_op.mlir"
+    "transform_match_partial_reduction.mlir"
     "transform_ops_invalid.mlir"
     "transpose_canonicalization.mlir"
     "type_propagation.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_match_partial_reduction.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_match_partial_reduction.mlir
new file mode 100644
index 0000000..61baa7f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_match_partial_reduction.mlir
@@ -0,0 +1,43 @@
+// RUN: iree-opt %s --iree-transform-dialect-interpreter --verify-diagnostics --split-input-file
+
+// This can be matched by "reduction_partial" but not by "reduction".
+
+func.func @reduction_with_extra_op_in_func(%arg0: tensor<8x479xf32>, %arg1: tensor<32x32xf32>) -> (tensor<8xf32>, tensor<32xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %empty = tensor.empty() : tensor<8xf32>
+  // expected-remark @below {{fill}}
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8xf32>) -> tensor<8xf32>
+  // expected-remark @below {{reduction}}
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]} 
+    ins(%arg0 : tensor<8x479xf32>)
+    outs(%fill : tensor<8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %6 = arith.addf %in, %out : f32
+    linalg.yield %6 : f32
+  } -> tensor<8xf32>
+
+  %empty2 = tensor.empty() : tensor<32xf32>
+  %fill2 = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<32xf32>) -> tensor<32xf32>
+  return %result, %fill2 : tensor<8xf32>, tensor<32xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.iree.register_match_callbacks
+
+  %leading, %fill, %reduction, %trailing =
+    transform.iree.match_callback failures(propagate) "reduction_partial"(%arg0)
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+  transform.iree.emit_remark "leading" at %leading : !transform.any_op
+  transform.iree.emit_remark "fill" at %fill : !transform.any_op
+  transform.iree.emit_remark "reduction" at %reduction : !transform.any_op
+  transform.iree.emit_remark "trailing" at %trailing : !transform.any_op
+
+  // expected-error @below {{failed to match}}
+  transform.iree.match_callback failures(propagate) "reduction"(%arg0)
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
index f87b335..75e533a 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
@@ -155,7 +155,8 @@
   StructuredOpMatcher *reduction;
   transform_ext::MatchedReductionCaptures captures;
   transform_ext::MatcherContext matcherContext;
-  makeReductionMatcher(matcherContext, reduction, captures);
+  makeReductionMatcher(matcherContext, reduction, captures,
+                       /*mustMatchEntireFunc=*/true);
   if (!matchPattern(op, *reduction)) return failure();
 
   // 2. Construct the configuration and the strategy builder.
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
index 87828e0..14b4ace 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
@@ -809,7 +809,8 @@
   StructuredOpMatcher *reduction;
   transform_ext::MatchedReductionCaptures captures;
   transform_ext::MatcherContext matcherContext;
-  makeReductionMatcher(matcherContext, reduction, captures);
+  makeReductionMatcher(matcherContext, reduction, captures,
+                       /*mustMatchEntireFunc=*/true);
   if (!matchPattern(op, *reduction)) {
     LDBG("--Reduction strategy failed to match\n");
     return failure();
@@ -855,7 +856,8 @@
   StructuredOpMatcher *trailing;
   transform_ext::MatchedMatmulCaptures captures;
   transform_ext::MatcherContext matcherContext;
-  makeMatmulMatcher(matcherContext, matmul, fill, trailing, captures);
+  makeMatmulMatcher(matcherContext, matmul, fill, trailing, captures,
+                    /*mustMatchEntireFunc=*/true);
   if (!matchPattern(op, *matmul)) {
     LDBG("--Matmul strategy fail to match\n");
     return failure();
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 0a90050..feec208 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -877,28 +877,35 @@
 ///     trailing(reduction(leading(), fill()))
 ///
 /// where trailing and leading are elementwise operations whose presence is
-/// optional. Each matcher will capture the corresponding operation.
+/// optional. Each matcher will capture the corresponding operation. If
+/// `mustMatchEntireFunc` is set, the matcher additionally checks if all
+/// tileable operations in the functions are captured.
 void makeReductionMatcher(transform_ext::MatcherContext &context,
                           StructuredOpMatcher *&reductionCapture,
                           StructuredOpMatcher *&fillCapture,
                           StructuredOpMatcher *&leadingCapture,
                           StructuredOpMatcher *&trailingCapture,
-                          MatchedReductionCaptures &captures);
+                          MatchedReductionCaptures &captures,
+                          bool mustMatchEntireFunc);
 void makeReductionMatcher(transform_ext::MatcherContext &context,
                           StructuredOpMatcher *&reductionCapture,
-                          MatchedReductionCaptures &captures);
+                          MatchedReductionCaptures &captures,
+                          bool mustMatchEntireFunc);
 
 /// Creates a group of matchers for:
 ///
 ///     trailing(matmul(*, *, fill()))
 ///
 /// where trailing and leading are elementwise operations whose presence is
-/// optional. Each matcher will capture the corresponding operation.
+/// optional. Each matcher will capture the corresponding operation. If
+/// `mustMatchEntireFunc` is set, the matcher additionally checks if all
+/// tileable operations in the functions are captured.
 void makeMatmulMatcher(transform_ext::MatcherContext &matcherContext,
                        StructuredOpMatcher *&matmulCapture,
                        StructuredOpMatcher *&fillCapture,
                        StructuredOpMatcher *&trailingCapture,
-                       MatchedMatmulCaptures &captures);
+                       MatchedMatmulCaptures &captures,
+                       bool mustMatchEntireFunc);
 
 /// Create a group of matchers for a different code sequence of operations
 /// matching exactly a softmax operation.
@@ -927,15 +934,43 @@
 ///     trailing(convolution(input, filter, fill()))
 ///
 /// where fill is a FillOp and trailing is an elementwise operation, both of
-/// which is optional. Each matcher will capture the corresponding operation.
+/// which is optional. Each matcher will capture the corresponding operation. If
+/// `mustMatchEntireFunc` is set, the matcher additionally checks if all
+/// tileable operations in the functions are captured.
 void makeConvolutionMatcher(transform_ext::MatcherContext &context,
                             StructuredOpMatcher *&convolutionCapture,
                             StructuredOpMatcher *&fillCapture,
                             StructuredOpMatcher *&trailingCapture,
-                            MatchedConvolutionCaptures &captures);
+                            MatchedConvolutionCaptures &captures,
+                            bool mustMatchEntireFunc);
 void makeConvolutionMatcher(transform_ext::MatcherContext &context,
                             StructuredOpMatcher *&convolutionCapture,
-                            MatchedConvolutionCaptures &captures);
+                            MatchedConvolutionCaptures &captures,
+                            bool mustMatchEntireFunc);
+
+/// Wraps the given matcher callback to indicate that it must capture all
+/// tilable ops in the parent function. Expects the callback to accept the same
+/// arguments as what is expected by MatchCallbacksRegistry::register, followed
+/// by a bool.
+template <typename Fn>
+auto wrapAsEntireFuncMatch(Fn &&fn) {
+  return [fn = std::move(fn)](
+             transform_ext::MatchCallbackResult &res, Location loc,
+             const mlir::transform::TransformState &state,
+             ValueRange handles) { return fn(res, loc, state, handles, true); };
+}
+
+/// Wraps the given matcher callback to indicate that it can match subgraphs.
+/// Expects the callback to accept the same arguments as what is expected by
+/// MatchCallbacksRegistry::register, followed by a bool.
+template <typename Fn>
+auto wrapAsPartialMatch(Fn &&fn) {
+  return [fn = std::move(fn)](
+             transform_ext::MatchCallbackResult &res, Location loc,
+             const mlir::transform::TransformState &state, ValueRange handles) {
+    return fn(res, loc, state, handles, false);
+  };
+}
 
 } // namespace transform_ext
 } // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 14e660b..2a41a20 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -684,7 +684,7 @@
 static DiagnosedSilenceableFailure
 reductionCallback(transform_ext::MatchCallbackResult &res, Location loc,
                   const mlir::transform::TransformState &state,
-                  ValueRange handles) {
+                  ValueRange handles, bool mustMatchEntireFunc) {
   if (handles.size() != 1 ||
       !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) {
     return emitSilenceableFailure(loc)
@@ -694,8 +694,8 @@
   transform_ext::StructuredOpMatcher *pattern, *fill, *leading, *trailing;
   transform_ext::MatchedReductionCaptures ignore;
   transform_ext::MatcherContext matcherContext;
-  makeReductionMatcher(matcherContext, pattern, fill, leading, trailing,
-                       ignore);
+  makeReductionMatcher(matcherContext, pattern, fill, leading, trailing, ignore,
+                       mustMatchEntireFunc);
 
   // TODO: need a mechanism for this to go around the entire IR,
   // potentially with list matches for each group.
@@ -756,7 +756,8 @@
   transform_ext::StructuredOpMatcher *pattern, *fill, *trailing;
   transform_ext::MatchedConvolutionCaptures ignore;
   transform_ext::MatcherContext matcherContext;
-  makeConvolutionMatcher(matcherContext, pattern, fill, trailing, ignore);
+  makeConvolutionMatcher(matcherContext, pattern, fill, trailing, ignore,
+                         /*mustMatchEntireFunc=*/true);
 
   // TODO: need a mechanism for this to go around the entire IR,
   // potentially with list matches for each group.
@@ -815,7 +816,8 @@
   transform_ext::StructuredOpMatcher *pattern, *fill, *trailing;
   transform_ext::MatchedMatmulCaptures ignore;
   transform_ext::MatcherContext matcherContext;
-  makeMatmulMatcher(matcherContext, pattern, fill, trailing, ignore);
+  makeMatmulMatcher(matcherContext, pattern, fill, trailing, ignore,
+                    /*mustMatchEntireFunc=*/true);
 
   // TODO: need a mechanism for this to go around the entire IR,
   // potentially with list matches for each group.
@@ -859,7 +861,10 @@
                             testValueMatcherCallback);
   registry.registerCallback("_test_shaped_value_matcher_callback",
                             testShapedValueMatcherCallback);
-  registry.registerCallback("reduction", reductionCallback);
+  registry.registerCallback("reduction",
+                            wrapAsEntireFuncMatch(reductionCallback));
+  registry.registerCallback("reduction_partial",
+                            wrapAsPartialMatch(reductionCallback));
   registry.registerCallback("convolution", convolutionCallback);
   registry.registerCallback("matmul", matmulCallback);
   return DiagnosedSilenceableFailure::success();
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index d3cefa1..17f2edc 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -1211,7 +1211,7 @@
     transform_ext::StructuredOpMatcher *&fillCapture,
     transform_ext::StructuredOpMatcher *&leadingCapture,
     transform_ext::StructuredOpMatcher *&trailingCapture,
-    MatchedReductionCaptures &captures) {
+    MatchedReductionCaptures &captures, bool mustMatchEntireFunc) {
   // The core part of the matcher is anchored on a particular reduction op.
   auto &reduction =
       m_StructuredOp(matcherContext)
@@ -1311,17 +1311,29 @@
           // Capture output elemental type.
           .output(0, CaptureElementTypeBitWidth(
                          captures.maybeTrailingOutputElementalTypeBitWidth));
-  reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch())
-                  .allTilableOpsCaptured<func::FuncOp>();
+  reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch());
+  if (mustMatchEntireFunc)
+    reduction = reduction.allTilableOpsCaptured<func::FuncOp>();
   trailingCapture = &trailing;
 }
 
+void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context,
+                                         StructuredOpMatcher *&reductionCapture,
+                                         MatchedReductionCaptures &captures,
+                                         bool mustMatchEntireFunc) {
+  StructuredOpMatcher *fill;
+  StructuredOpMatcher *leading;
+  StructuredOpMatcher *trailing;
+  makeReductionMatcher(context, reductionCapture, fill, leading, trailing,
+                       captures, mustMatchEntireFunc);
+}
+
 void transform_ext::makeMatmulMatcher(
     transform_ext::MatcherContext &matcherContext,
     transform_ext::StructuredOpMatcher *&matmulCapture,
     transform_ext::StructuredOpMatcher *&fillCapture,
     transform_ext::StructuredOpMatcher *&trailingCapture,
-    transform_ext::MatchedMatmulCaptures &captures) {
+    transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) {
   auto &matmul = transform_ext::m_StructuredOp<linalg::MatmulOp>(matcherContext)
                      // Capture op sizes.
                      .dim(AllDims(), CaptureDims(captures.matmulOpSizes))
@@ -1336,21 +1348,12 @@
   fillCapture = &fill;
 
   auto &trailing = m_StructuredOp<linalg::GenericOp>(matcherContext);
-  matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch())
-               .allTilableOpsCaptured<func::FuncOp>();
+  matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch());
+  if (mustMatchEntireFunc)
+    matmul = matmul.allTilableOpsCaptured<func::FuncOp>();
   trailingCapture = &trailing;
 }
 
-void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context,
-                                         StructuredOpMatcher *&reductionCapture,
-                                         MatchedReductionCaptures &captures) {
-  StructuredOpMatcher *fill;
-  StructuredOpMatcher *leading;
-  StructuredOpMatcher *trailing;
-  makeReductionMatcher(context, reductionCapture, fill, leading, trailing,
-                       captures);
-}
-
 /// Match sum(%src, broadcast(%reduction))
 static void
 matchSubBroadcast(transform_ext::MatcherContext &matcherContext,
@@ -1531,7 +1534,7 @@
     transform_ext::StructuredOpMatcher *&convolutionCapture,
     transform_ext::StructuredOpMatcher *&fillCapture,
     transform_ext::StructuredOpMatcher *&trailingCapture,
-    MatchedConvolutionCaptures &captures) {
+    MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) {
   // The core part of the matcher is anchored on a particular convolution op.
   auto &convolution =
       m_StructuredOp<linalg::Conv2DNchwFchwOp, linalg::Conv2DNhwcHwcfOp>(
@@ -1571,16 +1574,18 @@
 
   // Optional trailing can be any map, transpose, broadcast but not reduce or
   // windowing operation for now.
-  convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch())
-                    .allTilableOpsCaptured<func::FuncOp>();
+  convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch());
+  if (mustMatchEntireFunc)
+    convolution = convolution.allTilableOpsCaptured<func::FuncOp>();
   trailingCapture = &trailing;
 }
 
 void transform_ext::makeConvolutionMatcher(
     transform_ext::MatcherContext &context,
     StructuredOpMatcher *&convolutionCapture,
-    MatchedConvolutionCaptures &captures) {
+    MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) {
   StructuredOpMatcher *fill;
   StructuredOpMatcher *trailing;
-  makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures);
+  makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures,
+                         mustMatchEntireFunc);
 }