[LinalgExt] Add Interfaces for implementing fusion support for `iree_linalg_ext.custom_op`. (#18647)

These methods allow the dispatch region formation to automatically pick
up fusion of `custom_op` with producers/consumers similar to what is
supported with `LinalgOp`s.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 85c6da0..d4d7344 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1766,6 +1766,35 @@
   return success();
 }
 
+SmallVector<AffineMap> CustomOp::getIndexingMapsForOperands() {
+  return llvm::map_to_vector(
+      getIndexingMaps().getValue().take_front(getNumDpsInputs()),
+      [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
+}
+
+SmallVector<AffineMap> CustomOp::getIndexingMapsForResults() {
+  return llvm::map_to_vector(
+      getIndexingMaps().getValue().take_back(getNumDpsInits()),
+      [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
+}
+
+SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
+  return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
+    return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
+  });
+}
+
+LogicalResult
+CustomOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  for (auto init : getOutputs()) {
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(builder, getLoc(), init);
+    reifiedReturnShapes.emplace_back(std::move(sizes));
+  }
+  return success();
+}
+
 #define DEFINE_OP_GET_EFFECTS(OP_NAME)                                         \
   void OP_NAME::getEffects(                                                    \
       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 9de0ae5..eb66a38 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1574,7 +1574,12 @@
 // Custom tilable op
 //===---------------------------------------------------------------------===//
 
-def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op"> {
+def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
+  DeclareOpInterfaceMethods<LinalgFusionInterface>,
+  DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+  DeclareOpInterfaceMethods<TilingInterface,
+    ["getLoopIteratorTypes"]>
+  ]> {
   let summary = "Custom operation for compiling with IREE";
   let description = [{
     This operation is meant to allow computation sequences that are fused at
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 62bca48..3f3c91b 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -795,3 +795,116 @@
 //  CHECK-SAME:       outs(%[[INIT0]] : tensor<1x1x32x1x4xf32>)
 //       CHECK:   flow.return %[[GEN]] : tensor<1x1x32x1x4xf32>
 //       CHECK:   util.return %[[DISP1]] : tensor<1x1x32x1x4xf32>
+
+// -----
+
+util.func @custom_op_consumer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
+      ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
+      %1 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+          iterator_types = ["parallel", "reduction"]}
+          ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
+        ^bb1(%bb0 : f32, %bb1 : f32) :
+          %2 = arith.addf %bb0, %bb1 : f32
+          linalg.yield %2 : f32
+      } -> tensor<?xf32>
+      iree_linalg_ext.yield %1 : tensor<?xf32>
+  } -> tensor<?xf32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %4 = arith.mulf %b0, %b0 : f32
+      linalg.yield %4 :f32
+  } -> tensor<?xf32>
+  util.return %3 : tensor<?xf32>
+}
+// CHECK-LABEL: func public @custom_op_consumer_fusion
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:     %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+//       CHECK:       linalg.generic
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[CUSTOM_OP]] :
+//       CHECK:     flow.return %[[GENERIC]]
+//       CHECK:   util.return %[[DISPATCH]]
+
+// -----
+
+util.func @custom_op_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 :f32
+  } -> tensor<?x?xf32>
+  %2 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
+      ins(%0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
+      %3 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+          iterator_types = ["parallel", "reduction"]}
+          ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
+        ^bb1(%bb0 : f32, %bb1 : f32) :
+          %4 = arith.addf %bb0, %bb1 : f32
+          linalg.yield %4 : f32
+      } -> tensor<?xf32>
+      iree_linalg_ext.yield %3 : tensor<?xf32>
+  } -> tensor<?xf32>
+  util.return %2 : tensor<?xf32>
+}
+// CHECK-LABEL: func public @custom_op_producer_fusion
+//       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//       CHECK:     %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+//  CHECK-SAME:         ins(%[[GENERIC]] :
+//       CHECK:     flow.return %[[CUSTOM_OP]]
+//       CHECK:   util.return %[[DISPATCH]]
+
+// -----
+
+util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 :f32
+  } -> tensor<?x?xf32>
+  %2 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
+                       affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<parallel>]}
+      ins(%0, %arg1, %arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg4 : tensor<?x?xf32>) {
+    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?x?xf32>, %b2 : tensor<?x?xf32>, %b3 : tensor<?x?xf32>, %b4 : tensor<?x?xf32>):
+      %3 = linalg.matmul ins(%b0, %b1 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%b2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %4 = linalg.matmul ins(%3, %b3 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%b4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      iree_linalg_ext.yield %4 : tensor<?x?xf32>
+  } -> tensor<?x?xf32>
+  util.return %2 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func public @custom_op_no_producer_fusion
+//       CHECK:   %[[DISPATCH1:.+]] = flow.dispatch.region
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//       CHECK:     flow.return %[[GENERIC]]
+//       CHECK:   %[[DISPATCH2:.+]] = flow.dispatch.region
+//       CHECK:     %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+//  CHECK-SAME:         ins(%[[DISPATCH1]],
+//       CHECK:     flow.return %[[CUSTOM_OP]]
+//       CHECK:   util.return %[[DISPATCH2]]