Add `LinalgFusionInterface` to support fusion for linalg_ext ops (added `scatter` and `reverse`) (#17428)

`LinalgFusionOpInterface` allows for fusion of both `Linalg` and
`LinalgExt` operations. The new interface provides access to methods
essential for performing fusion, allowing existing fusion logic to be
used with `LinalgExt` operations.

As noted in #17392, it probably makes sense to move this into the
`TilingInterface` and to make it a bit more abstract.


#### Changes
- **`LinalgFusionOpInterface`**: Interface for fusing both `Linalg` and
`LinalgExt` ops.
  - Provides methods to access indexing maps (or a null map when an operand
    or result has no indexing map representation).
- **Implementation for Linalg Ops**: The interface is implemented for
standard Linalg operations by forwarding to preexisting methods (e.g.
`getIndexingMaps()`). No changes to the ops themselves.
- **Implementation for LinalgExt Ops**: The interface is currently only
implemented for `iree_linalg_ext.scatter` and `iree_linalg_ext.reverse`.

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 2c40f82..9889ca6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -100,6 +100,7 @@
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ArithUtils",
         "@llvm-project//mlir:ComplexDialect",
+        "@llvm-project//mlir:DestinationStyleOpInterface",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 007891c..64ee979 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -78,6 +78,7 @@
     MLIRArithDialect
     MLIRArithUtils
     MLIRComplexDialect
+    MLIRDestinationStyleOpInterface
     MLIRFuncDialect
     MLIRFunctionInterfaces
     MLIRIR
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index beb2176..9f88f48 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -12,8 +12,12 @@
 #include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -30,6 +34,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Pass/Pass.h"
@@ -390,8 +395,10 @@
 // relationship through `operand` have compatible outer-parallel loops.
 static bool hasCompatibleOuterParallelLoops(
     OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops) {
-  auto producer = operand.get().getDefiningOp<linalg::LinalgOp>();
-  auto consumer = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+  auto producer =
+      operand.get().getDefiningOp<LinalgExt::LinalgFusionOpInterface>();
+  auto consumer =
+      dyn_cast<LinalgExt::LinalgFusionOpInterface>(operand.getOwner());
   if (!producer || !consumer)
     return false;
 
@@ -399,6 +406,10 @@
       llvm::cast<OpResult>(operand.get()));
   auto consumerIndexingMap = consumer.getMatchingIndexingMap(&operand);
 
+  if (!producerIndexingMap || !consumerIndexingMap) {
+    return false;
+  }
+
   return hasCompatibleOuterParallelLoops(
              cast<TilingInterface>(producer.getOperation()),
              producerIndexingMap, rootOuterParallelLoops) &&
@@ -605,14 +616,16 @@
     return false;
   }
 
-  auto producerLinalgOp = dyn_cast<linalg::LinalgOp>(producer);
-  auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer);
-  if (!producerLinalgOp || !consumerLinalgOp)
+  auto producerFusionOp =
+      dyn_cast<LinalgExt::LinalgFusionOpInterface>(producer);
+  auto consumerFusionOp =
+      dyn_cast<LinalgExt::LinalgFusionOpInterface>(consumer);
+  if (!producerFusionOp || !consumerFusionOp)
     return false;
 
   // Check that the consumer is all parallel.
-  if (consumerLinalgOp.getNumLoops() !=
-      consumerLinalgOp.getNumParallelLoops()) {
+  if (consumerFusionOp.getNumLoops() !=
+      consumerFusionOp.getNumParallelLoops()) {
     return false;
   }
 
@@ -623,8 +636,8 @@
   // Check if the iteration spaces of the producer and consumer are same.
   // TODO(#12664): This is unnecessary requirement, but we need a better config
   // to tile the consumer with a larger iteration space.
-  auto producerIterationSpace = producerLinalgOp.getStaticLoopRanges();
-  auto consumerIterationSpace = consumerLinalgOp.getStaticLoopRanges();
+  auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
+  auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
   if (producerIterationSpace.size() < consumerIterationSpace.size()) {
     return false;
   }
@@ -640,12 +653,18 @@
   // While fusing with consumer, the result of the root might not be the final
   // result of the dispatch. To avoid a stack allocation we have to ensure that
   // all operations can bufferize without needing additional memory.
-  for (OpOperand *inputOperand : consumerLinalgOp.getDpsInputOperands()) {
+  auto consumerDstOp =
+      dyn_cast<DestinationStyleOpInterface>(consumerFusionOp.getOperation());
+  if (!consumerDstOp) {
+    return true;
+  }
+
+  for (OpOperand *inputOperand : consumerDstOp.getDpsInputOperands()) {
     if (inputOperand->get().getDefiningOp() != producer)
       continue;
     if (isa<linalg::ConvolutionOpInterface>(producer) &&
         !llvm::any_of(
-            consumerLinalgOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
+            consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
               return canUseInOperandAsInitOperand(inputOperand, &initOperand);
             })) {
       return false;
@@ -744,13 +763,14 @@
         .Default([](Operation *) { return false; });
   }
 
-  if (!isa<linalg::LinalgOp>(consumer) || !isa<linalg::LinalgOp>(producer)) {
+  if (!isa<LinalgExt::LinalgFusionOpInterface>(consumer) ||
+      !isa<LinalgExt::LinalgFusionOpInterface>(producer)) {
     return false;
   }
 
   if (!options.aggressiveFusion) {
-    auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
-    if (!consumerLinalgOp.isDpsInit(&operand)) {
+    auto consumerFusionOp = dyn_cast<DestinationStyleOpInterface>(consumer);
+    if (consumerFusionOp && !consumerFusionOp.isDpsInit(&operand)) {
       return false;
     }
   }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 592d015..cf2a55c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -167,6 +167,7 @@
     "mlir::scf::SCFDialect",
     "mlir::tensor::TensorDialect",
     "IREE::Flow::FlowDialect",
+    "IREE::LinalgExt::IREELinalgExtDialect",
   ];
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index b9ce61b..c74fc90 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -35,6 +35,7 @@
             "form_dispatch_regions.mlir",
             "form_dispatch_workgroups.mlir",
             "form_scalar_dispatches.mlir",
+            "dispatch_linalg_ext_fusion.mlir",
             "fusion_of_tensor_ops.mlir",
             "fusion_preprocessing.mlir",
             "initialize_empty_tensors.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index e7df9de..0a36a67 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -24,6 +24,7 @@
     "collapse_reduction.mlir"
     "convert_region_to_workgroups.mlir"
     "deduplicate_executables.mlir"
+    "dispatch_linalg_ext_fusion.mlir"
     "dispatch_linalg_on_tensors.mlir"
     "dispatch_linalg_on_tensors_default.mlir"
     "dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
new file mode 100644
index 0000000..0e03211
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
@@ -0,0 +1,82 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-form-dispatch-workgroups), cse, canonicalize, cse)" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
+  %0 = tensor.empty() : tensor<4x1xi32>
+  %1 = tensor.empty() : tensor<4x1xi64>
+  %2 = tensor.empty() : tensor<4x1x16x8x128xf32>
+  %3 = tensor.empty() : tensor<4x1x16x8x128xf32>
+  %4 = tensor.empty() : tensor<8192x16x8x128xf32>
+  %5 = tensor.empty() : tensor<8192x16x8x128xf32>
+  %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%1 : tensor<4x1xi64>) outs(%0 : tensor<4x1xi32>) {
+  ^bb0(%in: i64, %out: i32):
+    %10 = arith.trunci %in : i64 to i32
+    linalg.yield %10 : i32
+  } -> tensor<4x1xi32>
+
+  %7 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<4x1x16x8x128xf32>) outs(%3 : tensor<4x1x16x8x128xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %10 = arith.addf %in, %out : f32
+    linalg.yield %10 : f32
+  } -> tensor<4x1x16x8x128xf32>
+
+  %8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%7, %6 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%4 : tensor<8192x16x8x128xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    iree_linalg_ext.yield %arg0 : f32
+  } -> tensor<8192x16x8x128xf32>
+
+  // Don't fuse with scatter's consumer.
+  %9 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<8192x16x8x128xf32>) outs(%5 : tensor<8192x16x8x128xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %10 = arith.addf %in, %out : f32
+    linalg.yield %10 : f32
+  } -> tensor<8192x16x8x128xf32>
+  util.return %9 : tensor<8192x16x8x128xf32>
+}
+
+// CHECK:     util.func public @linalgext_scatter_fusion
+// CHECK:       %[[RESULT:.+]] = flow.dispatch.workgroups
+// CHECK:         %[[INDICES:.+]] = linalg.generic
+// CHECK:         %[[UPDATE:.+]] = linalg.generic
+// CHECK:         %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME:      ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
+// CHECK:       flow.dispatch.workgroups
+// CHECK:         %[[GEN2:.+]] = linalg.generic
+// CHECK-SAME:      ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>)
+
+
+
+// -----
+
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
+  %0 = tensor.empty() : tensor<10x10xi64>
+  %1 = tensor.empty() : tensor<10x10xi32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x10xi64>) outs(%1 : tensor<10x10xi32>) {
+  ^bb0(%in: i64, %out: i32):
+    %7 = arith.trunci %in : i64 to i32
+    linalg.yield %7 : i32
+  } -> tensor<10x10xi32>
+  %3 = tensor.empty() : tensor<10x10xi32>
+  %4 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%2 : tensor<10x10xi32>) outs(%3 : tensor<10x10xi32>) : tensor<10x10xi32>
+
+  // Don't fuse with reverse's consumer.
+  %5 = tensor.empty() : tensor<10x10xi32>
+  %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<10x10xi32>) outs(%5 : tensor<10x10xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    %7 = arith.addi %in, %out : i32
+    linalg.yield %7 : i32
+  } -> tensor<10x10xi32>
+  util.return %6 : tensor<10x10xi32>
+}
+
+// CHECK:     util.func public @linalgext_reverse_fusion
+// CHECK:       flow.dispatch.workgroups
+// CHECK:       %[[SHRUNK:.+]] = linalg.generic
+// CHECK:       %[[REVERSED:.+]] = iree_linalg_ext.reverse
+// CHECK:         ins(%[[SHRUNK]] : tensor<10x10xi32>)
+// CHECK:       flow.dispatch.workgroups
+// CHECK:       %[[GEN:.+]] = linalg.generic
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
index 7442d30..8675c43 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -82,6 +82,7 @@
         "@llvm-project//mlir:InferTypeOpInterface",
         "@llvm-project//mlir:InliningUtils",
         "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgStructuredOpsIncGen",
         "@llvm-project//mlir:LinalgUtils",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MemRefDialect",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index 1fb81e3..a6e6576 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -50,6 +50,7 @@
     MLIRIR
     MLIRInferTypeOpInterface
     MLIRLinalgDialect
+    MLIRLinalgStructuredOpsIncGenLib
     MLIRLinalgUtils
     MLIRMathDialect
     MLIRMemRefDialect
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
index 0b3215d..1ec216b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -6,14 +6,19 @@
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
@@ -46,7 +51,83 @@
   }
 };
 
+// External model that implements LinalgFusionOpInterface for a linalg op.
+namespace {
+template <typename ConcreteType>
+struct LinalgFusionOpInterfaceAdapter
+    : public LinalgFusionOpInterface::ExternalModel<
+          LinalgFusionOpInterfaceAdapter<ConcreteType>, ConcreteType> {
+public:
+  SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
+    auto maps = llvm::cast<ConcreteType>(op)
+                    .getIndexingMaps()
+                    .template getAsValueRange<AffineMapAttr>();
+    return {maps.begin(),
+            maps.end() - llvm::cast<ConcreteType>(op).getNumResults()};
+  }
+  // Result maps: the trailing `getNumResults()` entries of the op's maps.
+  SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
+    auto maps = llvm::cast<ConcreteType>(op)
+                    .getIndexingMaps()
+                    .template getAsValueRange<AffineMapAttr>();
+    return {maps.end() - llvm::cast<ConcreteType>(op).getNumResults(),
+            maps.end()};
+  }
+
+  // The methods below forward directly to ConcreteType's own implementation.
+  unsigned getNumParallelLoops(mlir::Operation *op) const {
+    return (llvm::cast<ConcreteType>(op).getNumParallelLoops());
+  }
+
+  unsigned getNumLoops(mlir::Operation *op) const {
+    return (llvm::cast<ConcreteType>(op).getNumLoops());
+  }
+
+  SmallVector<int64_t, 4> getStaticLoopRanges(mlir::Operation *op) const {
+    return (llvm::cast<ConcreteType>(op).getStaticLoopRanges());
+  }
+
+  AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
+                                         OpResult result) const {
+    return (llvm::cast<ConcreteType>(op).getIndexingMapMatchingResult(result));
+  }
+
+  AffineMap getMatchingIndexingMap(mlir::Operation *op,
+                                   OpOperand *operand) const {
+    return (llvm::cast<ConcreteType>(op).getMatchingIndexingMap(operand));
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(mlir::Operation *op) const {
+    // Note: unlike the forwarding methods above, this does not call linalg's
+    // `getIndexingMaps` directly. It concatenates the operand maps and the
+    // result maps returned by this model's own methods.
+    auto inputMaps = getIndexingMapsForOperands(op);
+    llvm::append_range(inputMaps, getIndexingMapsForResults(op));
+    return inputMaps;
+  }
+};
+} // namespace
+// Attaches LinalgFusionOpInterfaceAdapter<Op> to every op type in `Args`.
+template <typename... Args>
+static void registerOpsWithLinalgExtOpInterface(mlir::MLIRContext *context) {
+  (Args::template attachInterface<LinalgFusionOpInterfaceAdapter<Args>>(
+       *context),
+   ...);
+}
+
 void IREELinalgExtDialect::initialize() {
+  mlir::MLIRContext *context = getContext();
+  context->loadDialect<mlir::linalg::LinalgDialect>();
+
+#define GET_OP_LIST
+  declarePromisedInterfaces<LinalgFusionOpInterface,
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+                            >();
+
+#define GET_OP_LIST
+  registerOpsWithLinalgExtOpInterface<
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+      >(context);
   addInterfaces<IREELinalgExtInlinerInterface>();
 
   addAttributes<
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
index 0babf32..984db79 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
@@ -7,11 +7,20 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
 
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Support/LLVM.h"
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index df04c86..77c3266 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -8,6 +8,7 @@
 #define IREE_DIALECT_LINALGEXT_INTERFACES
 
 include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 
 // The interface is a subset of LinalgStructuredInterface.
 def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
@@ -40,4 +41,120 @@
   let verify = [{ return detail::verifyLinalgExtOpInterface($_op); }];
 }
 
+// Interface that allows for fusion of both LinalgExt and Linalg ops.
+def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface"> {
+  let methods = [
+    //===------------------------------------------------------------------===//
+    // Interface methods for fusion. 
+    //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/[{
+        Return an AffineMap for each operand or nullptr if the operand
+        does not have an indexing map representation.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMapsForOperands",
+      /*args=*/(ins),
+      /*methodBody=*/""
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return an AffineMap for each result or nullptr if the result
+        does not have an indexing map representation.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMapsForResults",
+      /*args=*/(ins),
+      /*methodBody=*/""
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        For each operand/result return indexing map or nullptr if an
+        operand or result does not have an indexing map representation.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto inputMaps = $_op.getIndexingMapsForOperands();
+        llvm::append_range(inputMaps, $_op.getIndexingMapsForResults());
+        return inputMaps;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of parallel loops.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumParallelLoops",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::count($_op.getLoopIteratorTypes(), utils::IteratorType::parallel);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the total number of loops.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumLoops",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getLoopIteratorTypes().size(); 
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the static loop ranges (default: all operand shapes concatenated).
+      }],
+      /*retTy=*/"SmallVector<int64_t, 4>",
+      /*methodName=*/"getStaticLoopRanges",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+          SmallVector<int64_t, 4> loopRanges;
+          llvm::for_each($_op.getOperands(), [&](Value operand) {
+          if (auto shapedType = dyn_cast<ShapedType>(operand.getType())) {
+            llvm::append_range(loopRanges, shapedType.getShape());
+            }
+          });
+          return loopRanges;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing map for an op's `result` or nullptr if
+        the indexing map is not representable.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getIndexingMapMatchingResult",
+      /*args=*/(ins "OpResult":$result),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(result.getOwner() == $_op);
+        return $_op.getIndexingMapsForResults()[result.getResultNumber()];
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing map for `opOperand` or nullptr if
+        the indexing map is not representable.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getMatchingIndexingMap",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == $_op);
+        return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()];
+      }]
+    >,
+
+ 
+  ];
+}
+
 #endif  // IREE_DIALECT_LINALGEXT_INTERFACES
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 950bf9f..c5c42ec 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -22,8 +23,11 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/OpDefinition.h"
@@ -31,11 +35,15 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 
+#include <cstdint>
+#include <optional>
+
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
 //===----------------------------------------------------------------------===//
@@ -245,6 +253,24 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+/// Returns one indexing map per operand, in operand order: updates, indices,
+/// and the destination (`outs`) tensor. Updates and indices use identity maps.
+/// The destination has no indexing-map representation, so its entry is a null
+/// map. The entry must still be present because the interface's default
+/// `getMatchingIndexingMap` indexes this vector by operand number.
+SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
+  Builder builder(getContext());
+  return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
+          builder.getMultiDimIdentityMap(getIndicesType().getRank()),
+          /*output=*/AffineMap(nullptr)};
+}
+
+/// Reports the single result as a null map, i.e. as having no indexing-map
+/// representation.
+SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
+  return {AffineMap(nullptr)};
+}
+
 //===----------------------------------------------------------------------===//
 // SortOp
 //===----------------------------------------------------------------------===//
@@ -478,6 +504,23 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+/// Returns one indexing map per operand, in operand order: the input and the
+/// destination (`outs`) tensor. The input uses an identity map. The
+/// destination has no indexing-map representation, so its entry is a null
+/// map. The entry must still be present because the interface's default
+/// `getMatchingIndexingMap` indexes this vector by operand number.
+SmallVector<AffineMap> ReverseOp::getIndexingMapsForOperands() {
+  Builder builder(getContext());
+  return {builder.getMultiDimIdentityMap(getOperandRank()),
+          /*output=*/AffineMap(nullptr)};
+}
+
+/// Reports the single result as a null map, i.e. as having no indexing-map
+/// representation.
+SmallVector<AffineMap> ReverseOp::getIndexingMapsForResults() {
+  return {AffineMap(nullptr)};
+}
+
 //===----------------------------------------------------------------------===//
 // TopkOp
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index c75f709..bf9694d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -78,6 +78,7 @@
 
 def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
     [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<LinalgFusionInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
         ["generateScalarImplementation",
          "getIterationDomain",
@@ -369,6 +370,7 @@
 
 def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [
   DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+  DeclareOpInterfaceMethods<LinalgFusionInterface>,
   DeclareOpInterfaceMethods<
       TilingInterface,
       ["generateScalarImplementation",