Add `LinalgFusionInterface` to support fusion for linalg_ext ops (added `scatter` and `reverse`) (#17428)
`LinalgFusionOpInterface` allows for fusion of both `Linalg` and
`LinalgExt` operations. The new interface provides access to methods
essential for performing fusion, allowing existing fusion logic to be
used with `LinalgExt` operations.
As noted in #17392, it probably makes sense to move this into the
`TilingInterface` and to make it somewhat more abstract.
#### Changes
- **`LinalgFusionOpInterface`**: Interface for fusion operations for
both `Linalg` and `LinalgExt` ops.
- Implements methods to access indexing maps (or a null map when an
operand or result has no indexing-map representation).
- **Implementation for Linalg Ops**: The interface is implemented for
standard Linalg operations by forwarding to preexisting methods (e.g.,
`getIndexingMaps()`). No changes to the ops themselves.
- **Implementation for LinalgExt Ops**: The interface is currently only
implemented for `iree_linalg_ext.scatter` and `iree_linalg_ext.reverse`.
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 2c40f82..9889ca6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -100,6 +100,7 @@
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithUtils",
"@llvm-project//mlir:ComplexDialect",
+ "@llvm-project//mlir:DestinationStyleOpInterface",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 007891c..64ee979 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -78,6 +78,7 @@
MLIRArithDialect
MLIRArithUtils
MLIRComplexDialect
+ MLIRDestinationStyleOpInterface
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index beb2176..9f88f48 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -12,8 +12,12 @@
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -30,6 +34,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Pass/Pass.h"
@@ -390,8 +395,10 @@
// relationship through `operand` have compatible outer-parallel loops.
static bool hasCompatibleOuterParallelLoops(
OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops) {
- auto producer = operand.get().getDefiningOp<linalg::LinalgOp>();
- auto consumer = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+ auto producer =
+ operand.get().getDefiningOp<LinalgExt::LinalgFusionOpInterface>();
+ auto consumer =
+ dyn_cast<LinalgExt::LinalgFusionOpInterface>(operand.getOwner());
if (!producer || !consumer)
return false;
@@ -399,6 +406,10 @@
llvm::cast<OpResult>(operand.get()));
auto consumerIndexingMap = consumer.getMatchingIndexingMap(&operand);
+ if (!producerIndexingMap || !consumerIndexingMap) {
+ return false;
+ }
+
return hasCompatibleOuterParallelLoops(
cast<TilingInterface>(producer.getOperation()),
producerIndexingMap, rootOuterParallelLoops) &&
@@ -605,14 +616,16 @@
return false;
}
- auto producerLinalgOp = dyn_cast<linalg::LinalgOp>(producer);
- auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer);
- if (!producerLinalgOp || !consumerLinalgOp)
+ auto producerFusionOp =
+ dyn_cast<LinalgExt::LinalgFusionOpInterface>(producer);
+ auto consumerFusionOp =
+ dyn_cast<LinalgExt::LinalgFusionOpInterface>(consumer);
+ if (!producerFusionOp || !consumerFusionOp)
return false;
// Check that the consumer is all parallel.
- if (consumerLinalgOp.getNumLoops() !=
- consumerLinalgOp.getNumParallelLoops()) {
+ if (consumerFusionOp.getNumLoops() !=
+ consumerFusionOp.getNumParallelLoops()) {
return false;
}
@@ -623,8 +636,8 @@
// Check if the iteration spaces of the producer and consumer are same.
// TODO(#12664): This is unnecessary requirement, but we need a better config
// to tile the consumer with a larger iteration space.
- auto producerIterationSpace = producerLinalgOp.getStaticLoopRanges();
- auto consumerIterationSpace = consumerLinalgOp.getStaticLoopRanges();
+ auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
+ auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
if (producerIterationSpace.size() < consumerIterationSpace.size()) {
return false;
}
@@ -640,12 +653,18 @@
// While fusing with consumer, the result of the root might not be the final
// result of the dispatch. To avoid a stack allocation we have to ensure that
// all operations can bufferize without needing additional memory.
- for (OpOperand *inputOperand : consumerLinalgOp.getDpsInputOperands()) {
+ auto consumerDstOp =
+ dyn_cast<DestinationStyleOpInterface>(consumerFusionOp.getOperation());
+ if (!consumerDstOp) {
+ return true;
+ }
+
+ for (OpOperand *inputOperand : consumerDstOp.getDpsInputOperands()) {
if (inputOperand->get().getDefiningOp() != producer)
continue;
if (isa<linalg::ConvolutionOpInterface>(producer) &&
!llvm::any_of(
- consumerLinalgOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
+ consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) {
return canUseInOperandAsInitOperand(inputOperand, &initOperand);
})) {
return false;
@@ -744,13 +763,14 @@
.Default([](Operation *) { return false; });
}
- if (!isa<linalg::LinalgOp>(consumer) || !isa<linalg::LinalgOp>(producer)) {
+ if (!isa<LinalgExt::LinalgFusionOpInterface>(consumer) ||
+ !isa<LinalgExt::LinalgFusionOpInterface>(producer)) {
return false;
}
if (!options.aggressiveFusion) {
- auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
- if (!consumerLinalgOp.isDpsInit(&operand)) {
+ auto consumerFusionOp = dyn_cast<DestinationStyleOpInterface>(consumer);
+ if (consumerFusionOp && !consumerFusionOp.isDpsInit(&operand)) {
return false;
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 592d015..cf2a55c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -167,6 +167,7 @@
"mlir::scf::SCFDialect",
"mlir::tensor::TensorDialect",
"IREE::Flow::FlowDialect",
+ "IREE::LinalgExt::IREELinalgExtDialect",
];
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index b9ce61b..c74fc90 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -35,6 +35,7 @@
"form_dispatch_regions.mlir",
"form_dispatch_workgroups.mlir",
"form_scalar_dispatches.mlir",
+ "dispatch_linalg_ext_fusion.mlir",
"fusion_of_tensor_ops.mlir",
"fusion_preprocessing.mlir",
"initialize_empty_tensors.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index e7df9de..0a36a67 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -24,6 +24,7 @@
"collapse_reduction.mlir"
"convert_region_to_workgroups.mlir"
"deduplicate_executables.mlir"
+ "dispatch_linalg_ext_fusion.mlir"
"dispatch_linalg_on_tensors.mlir"
"dispatch_linalg_on_tensors_default.mlir"
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
new file mode 100644
index 0000000..0e03211
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
@@ -0,0 +1,82 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-form-dispatch-workgroups), cse, canonicalize, cse)" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
+ %0 = tensor.empty() : tensor<4x1xi32>
+ %1 = tensor.empty() : tensor<4x1xi64>
+ %2 = tensor.empty() : tensor<4x1x16x8x128xf32>
+ %3 = tensor.empty() : tensor<4x1x16x8x128xf32>
+ %4 = tensor.empty() : tensor<8192x16x8x128xf32>
+ %5 = tensor.empty() : tensor<8192x16x8x128xf32>
+ %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%1 : tensor<4x1xi64>) outs(%0 : tensor<4x1xi32>) {
+ ^bb0(%in: i64, %out: i32):
+ %10 = arith.trunci %in : i64 to i32
+ linalg.yield %10 : i32
+ } -> tensor<4x1xi32>
+
+ %7 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<4x1x16x8x128xf32>) outs(%3 : tensor<4x1x16x8x128xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %10 = arith.addf %in, %out : f32
+ linalg.yield %10 : f32
+ } -> tensor<4x1x16x8x128xf32>
+
+ %8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%7, %6 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%4 : tensor<8192x16x8x128xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ iree_linalg_ext.yield %arg0 : f32
+ } -> tensor<8192x16x8x128xf32>
+
+  // Don't fuse with scatter's consumer.
+ %9 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<8192x16x8x128xf32>) outs(%5 : tensor<8192x16x8x128xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %10 = arith.addf %in, %out : f32
+ linalg.yield %10 : f32
+ } -> tensor<8192x16x8x128xf32>
+ util.return %9 : tensor<8192x16x8x128xf32>
+}
+
+// CHECK: util.func public @linalgext_scatter_fusion
+// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
+// CHECK: %[[INDICES:.+]] = linalg.generic
+// CHECK: %[[UPDATE:.+]] = linalg.generic
+// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
+// CHECK: flow.dispatch.workgroups
+// CHECK: %[[GEN2:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>)
+
+
+
+// -----
+
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
+ %0 = tensor.empty() : tensor<10x10xi64>
+ %1 = tensor.empty() : tensor<10x10xi32>
+ %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x10xi64>) outs(%1 : tensor<10x10xi32>) {
+ ^bb0(%in: i64, %out: i32):
+ %7 = arith.trunci %in : i64 to i32
+ linalg.yield %7 : i32
+ } -> tensor<10x10xi32>
+ %3 = tensor.empty() : tensor<10x10xi32>
+ %4 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%2 : tensor<10x10xi32>) outs(%3 : tensor<10x10xi32>) : tensor<10x10xi32>
+
+  // Don't fuse with reverse's consumer.
+ %5 = tensor.empty() : tensor<10x10xi32>
+ %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<10x10xi32>) outs(%5 : tensor<10x10xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ %7 = arith.addi %in, %out : i32
+ linalg.yield %7 : i32
+ } -> tensor<10x10xi32>
+ util.return %6 : tensor<10x10xi32>
+}
+
+// CHECK: util.func public @linalgext_reverse_fusion
+// CHECK: flow.dispatch.workgroups
+// CHECK: %[[SHRUNK:.+]] = linalg.generic
+// CHECK: %[[REVERSED:.+]] = iree_linalg_ext.reverse
+// CHECK: ins(%[[SHRUNK]] : tensor<10x10xi32>)
+// CHECK: flow.dispatch.workgroups
+// CHECK: %[[GEN:.+]] = linalg.generic
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
index 7442d30..8675c43 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -82,6 +82,7 @@
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:InliningUtils",
"@llvm-project//mlir:LinalgDialect",
+ "@llvm-project//mlir:LinalgStructuredOpsIncGen",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index 1fb81e3..a6e6576 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -50,6 +50,7 @@
MLIRIR
MLIRInferTypeOpInterface
MLIRLinalgDialect
+ MLIRLinalgStructuredOpsIncGenLib
MLIRLinalgUtils
MLIRMathDialect
MLIRMemRefDialect
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
index 0b3215d..1ec216b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -6,14 +6,19 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
@@ -46,7 +51,83 @@
}
};
+// External model used to attach `LinalgFusionOpInterface` to the linalg ops
+// without modifying the ops themselves.
+namespace {
+template <typename ConcreteType>
+struct LinalgFusionOpInterfaceAdapter
+    : public LinalgFusionOpInterface::ExternalModel<
+          LinalgFusionOpInterfaceAdapter<ConcreteType>, ConcreteType> {
+  // Returns the op's indexing maps with the trailing `getNumResults()`
+  // entries dropped.
+  SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
+    ConcreteType linalgOp = llvm::cast<ConcreteType>(op);
+    auto allMaps =
+        linalgOp.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+    auto operandMapsEnd = allMaps.end() - linalgOp.getNumResults();
+    return SmallVector<AffineMap>(allMaps.begin(), operandMapsEnd);
+  }
+
+  // Returns the trailing `getNumResults()` entries of the op's indexing maps.
+  SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
+    ConcreteType linalgOp = llvm::cast<ConcreteType>(op);
+    auto allMaps =
+        linalgOp.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+    auto resultMapsBegin = allMaps.end() - linalgOp.getNumResults();
+    return SmallVector<AffineMap>(resultMapsBegin, allMaps.end());
+  }
+
+  // The methods below forward straight to the same-named method of the
+  // underlying linalg op.
+  unsigned getNumParallelLoops(mlir::Operation *op) const {
+    ConcreteType linalgOp = llvm::cast<ConcreteType>(op);
+    return linalgOp.getNumParallelLoops();
+  }
+
+  unsigned getNumLoops(mlir::Operation *op) const {
+    ConcreteType linalgOp = llvm::cast<ConcreteType>(op);
+    return linalgOp.getNumLoops();
+  }
+
+  SmallVector<int64_t, 4> getStaticLoopRanges(mlir::Operation *op) const {
+    ConcreteType linalgOp = llvm::cast<ConcreteType>(op);
+    return linalgOp.getStaticLoopRanges();
+  }
+
+  AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
+                                         OpResult result) const {
+    ConcreteType linalgOp = llvm::cast<ConcreteType>(op);
+    return linalgOp.getIndexingMapMatchingResult(result);
+  }
+
+  AffineMap getMatchingIndexingMap(mlir::Operation *op,
+                                   OpOperand *operand) const {
+    ConcreteType linalgOp = llvm::cast<ConcreteType>(op);
+    return linalgOp.getMatchingIndexingMap(operand);
+  }
+
+  // Unlike linalg's own `getIndexingMaps`, this is assembled from the two
+  // interface methods above: operand maps first, then result maps.
+  SmallVector<AffineMap> getIndexingMaps(mlir::Operation *op) const {
+    SmallVector<AffineMap> combinedMaps = getIndexingMapsForOperands(op);
+    llvm::append_range(combinedMaps, getIndexingMapsForResults(op));
+    return combinedMaps;
+  }
+};
+} // namespace
+
+// Attaches `LinalgFusionOpInterfaceAdapter<OpTy>` as an external model to
+// every op type `OpTy` in the `OpTys` pack, using `context`.
+template <typename... OpTys>
+static void registerOpsWithLinalgExtOpInterface(mlir::MLIRContext *context) {
+  mlir::MLIRContext &ctx = *context;
+  (OpTys::template attachInterface<LinalgFusionOpInterfaceAdapter<OpTys>>(ctx),
+   ...);
+}
+
void IREELinalgExtDialect::initialize() {
+ mlir::MLIRContext *context = getContext();
+ context->loadDialect<mlir::linalg::LinalgDialect>();
+
+#define GET_OP_LIST
+ declarePromisedInterfaces<LinalgFusionOpInterface,
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >();
+
+#define GET_OP_LIST
+ registerOpsWithLinalgExtOpInterface<
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(context);
addInterfaces<IREELinalgExtInlinerInterface>();
addAttributes<
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
index 0babf32..984db79 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
@@ -7,11 +7,20 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index df04c86..77c3266 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -8,6 +8,7 @@
#define IREE_DIALECT_LINALGEXT_INTERFACES
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
// The interface is a subset of LinalgStructuredInterface.
def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
@@ -40,4 +41,120 @@
let verify = [{ return detail::verifyLinalgExtOpInterface($_op); }];
}
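+// Interface that allows for fusion of both LinalgExt and Linalg ops.
+// `getIndexingMapsForOperands` / `getIndexingMapsForResults` must return one
+// entry per operand / result (a null map where none is representable): the
+// default `getMatchingIndexingMap` / `getIndexingMapMatchingResult`
+// implementations index into them by operand / result number.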
+def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface"> {
+ let methods = [
+ //===------------------------------------------------------------------===//
+ // Interface methods for fusion.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return an AffineMap for each operand or nullptr if the operand
+ does not have an indexing map representation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMapsForOperands",
+ /*args=*/(ins),
+ /*methodBody=*/""
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return an AffineMap for each result or nullptr if the result
+ does not have an indexing map representation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMapsForResults",
+ /*args=*/(ins),
+ /*methodBody=*/""
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ For each operand/result return indexing map or nullptr if an
+ operand or result does not have an indexing map representation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto inputMaps = $_op.getIndexingMapsForOperands();
+ llvm::append_range(inputMaps, $_op.getIndexingMapsForResults());
+ return inputMaps;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of parallel loops.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumParallelLoops",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return llvm::count($_op.getLoopIteratorTypes(), utils::IteratorType::parallel);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the total number of loops.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumLoops",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getLoopIteratorTypes().size();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the static loop ranges.
+ }],
+ /*retTy=*/"SmallVector<int64_t, 4>",
+ /*methodName=*/"getStaticLoopRanges",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<int64_t, 4> loopRanges;
+ llvm::for_each($_op.getOperands(), [&](Value operand) {
+ if (auto shapedType = dyn_cast<ShapedType>(operand.getType())) {
+ llvm::append_range(loopRanges, shapedType.getShape());
+ }
+ });
+ return loopRanges;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing map for an op's `result` or nullptr if
+ the indexing map is not representable.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getIndexingMapMatchingResult",
+ /*args=*/(ins "OpResult":$result),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(result.getOwner() == $_op);
+ return $_op.getIndexingMapsForResults()[result.getResultNumber()];
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing map for `opOperand` or nullptr if
+ the indexing map is not representable.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getMatchingIndexingMap",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == $_op);
+ return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()];
+ }]
+ >,
+
+
+ ];
+}
+
#endif // IREE_DIALECT_LINALGEXT_INTERFACES
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 950bf9f..c5c42ec 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -22,8 +23,11 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
@@ -31,11 +35,15 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
+#include <cstdint>
+#include <optional>
+
namespace mlir::iree_compiler::IREE::LinalgExt {
//===----------------------------------------------------------------------===//
@@ -245,6 +253,16 @@
.reifyResultShapes(b, reifiedReturnShapes);
}
+// Returns one indexing map per operand, in operand order: the update and
+// indices `ins` operands followed by the `outs` destination.
+SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
+  Builder builder(getContext());
+  // The interface's default `getMatchingIndexingMap` indexes this vector by
+  // operand number, so the `outs` operand needs an entry as well. It is
+  // reported as a null map (no indexing map representation) instead of being
+  // omitted, which would make that lookup read past the end of the vector.
+  return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
+          builder.getMultiDimIdentityMap(getIndicesType().getRank()),
+          /*output=*/AffineMap(nullptr)};
+}
+
+// Returns a single null map for the scatter result, i.e. the result is
+// reported as having no indexing map representation.
+SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
+  SmallVector<AffineMap> resultMaps;
+  resultMaps.push_back(AffineMap(nullptr));
+  return resultMaps;
+}
+
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
@@ -478,6 +496,15 @@
.reifyResultShapes(b, reifiedReturnShapes);
}
+// Returns one indexing map per operand, in operand order: the `ins` operand
+// followed by the `outs` destination.
+SmallVector<AffineMap> ReverseOp::getIndexingMapsForOperands() {
+  Builder builder(getContext());
+  // The interface's default `getMatchingIndexingMap` indexes this vector by
+  // operand number, so the `outs` operand needs an entry as well. It is
+  // reported as a null map (no indexing map representation) instead of being
+  // omitted, which would make that lookup read past the end of the vector.
+  return {builder.getMultiDimIdentityMap(getOperandRank()),
+          /*output=*/AffineMap(nullptr)};
+}
+
+// Returns a single null map for the reverse result, i.e. the result is
+// reported as having no indexing map representation.
+SmallVector<AffineMap> ReverseOp::getIndexingMapsForResults() {
+  SmallVector<AffineMap> resultMaps;
+  resultMaps.push_back(AffineMap(nullptr));
+  return resultMaps;
+}
+
//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index c75f709..bf9694d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -78,6 +78,7 @@
def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<LinalgFusionInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
@@ -369,6 +370,7 @@
def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<LinalgFusionInterface>,
DeclareOpInterfaceMethods<
TilingInterface,
["generateScalarImplementation",