[LinalgExt] Add Interfaces for implementing fusion support for `iree_linalg_ext.custom_op`. (#18647)
These methods allow dispatch region formation to automatically pick up
fusion of `custom_op` with its producers and consumers, similar to what
is already supported for `LinalgOp`s.
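
For illustration, dispatch region formation can now query these methods on a
`custom_op` the same way it queries a `LinalgOp`. The sketch below is
hypothetical (it is not code from this change, `canFuseAlongOperand` is an
invented name, and the projected-permutation check is only an assumed
heuristic); it just shows the shape of such a query:

```cpp
#include <cassert>

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::iree_compiler;

// Hypothetical predicate: a producer may fuse along `operand` when that
// operand is indexed by a projected permutation of the custom_op's loops.
// Maps that use symbols, e.g. `(d0, d1)[s0, s1] -> (d0, s0)`, fail this
// test, in line with the `custom_op_no_producer_fusion` case below.
static bool canFuseAlongOperand(IREE::LinalgExt::CustomOp consumerOp,
                                OpOperand &operand) {
  // DPS inputs come first in the operand list, so the operand number indexes
  // directly into the list of input indexing maps.
  assert(operand.getOperandNumber() <
             static_cast<unsigned>(consumerOp.getNumDpsInputs()) &&
         "expected a DPS input operand");
  SmallVector<AffineMap> inputMaps = consumerOp.getIndexingMapsForOperands();
  return inputMaps[operand.getOperandNumber()].isProjectedPermutation();
}
```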
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 85c6da0..d4d7344 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1766,6 +1766,42 @@
return success();
}
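+// Returns the indexing maps for the DPS input operands; the `indexing_maps`
+// attribute stores the input maps first, followed by the init maps.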
+SmallVector<AffineMap> CustomOp::getIndexingMapsForOperands() {
+ return llvm::map_to_vector(
+ getIndexingMaps().getValue().take_front(getNumDpsInputs()),
+ [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
+}
+
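+// Returns the indexing maps for the DPS inits (and hence the results), which
+// are the trailing entries of the `indexing_maps` attribute.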
+SmallVector<AffineMap> CustomOp::getIndexingMapsForResults() {
+ return llvm::map_to_vector(
+ getIndexingMaps().getValue().take_back(getNumDpsInits()),
+ [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
+}
+
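+// Converts the `iterator_types` attribute into `utils::IteratorType` values.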
+SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
+ return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
+ return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
+ });
+}
+
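+// The shape of each result matches the shape of the corresponding init
+// operand, so reify the result shapes from the init operands.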
+LogicalResult
+CustomOp::reifyResultShapes(OpBuilder &builder,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ for (auto init : getOutputs()) {
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(builder, getLoc(), init);
+ reifiedReturnShapes.emplace_back(std::move(sizes));
+ }
+ return success();
+}
+
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 9de0ae5..eb66a38 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1574,7 +1574,12 @@
// Custom tilable op
//===---------------------------------------------------------------------===//
-def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op"> {
+def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
+ DeclareOpInterfaceMethods<LinalgFusionInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getLoopIteratorTypes"]>
+ ]> {
let summary = "Custom operation for compiling with IREE";
let description = [{
This operation is meant to allow computation sequences that are fused at
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 62bca48..3f3c91b 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -795,3 +795,123 @@
// CHECK-SAME: outs(%[[INIT0]] : tensor<1x1x32x1x4xf32>)
// CHECK: flow.return %[[GEN]] : tensor<1x1x32x1x4xf32>
// CHECK: util.return %[[DISP1]] : tensor<1x1x32x1x4xf32>
+
+// -----
+
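+// Consumer fusion: the trailing elementwise linalg.generic that reads the
+// custom_op result is expected to land in the same dispatch region.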
+util.func @custom_op_consumer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+ %0 = iree_linalg_ext.custom_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
+ ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
+      ^bb1(%bb0 : f32, %bb1 : f32):
+ %2 = arith.addf %bb0, %bb1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ iree_linalg_ext.yield %1 : tensor<?xf32>
+ } -> tensor<?xf32>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %4 = arith.mulf %b0, %b0 : f32
+    linalg.yield %4 : f32
+ } -> tensor<?xf32>
+ util.return %3 : tensor<?xf32>
+}
+// CHECK-LABEL: func public @custom_op_consumer_fusion
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+// CHECK: linalg.generic
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CUSTOM_OP]] :
+// CHECK: flow.return %[[GENERIC]]
+// CHECK: util.return %[[DISPATCH]]
+
+// -----
+
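+// Producer fusion: the elementwise linalg.generic feeding the custom_op is
+// expected to land in the same dispatch region.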
+util.func @custom_op_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %1 = arith.mulf %b0, %b0 : f32
+    linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ %2 = iree_linalg_ext.custom_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
+ ins(%0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
+      ^bb1(%bb0 : f32, %bb1 : f32):
+ %4 = arith.addf %bb0, %bb1 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?xf32>
+ iree_linalg_ext.yield %3 : tensor<?xf32>
+ } -> tensor<?xf32>
+ util.return %2 : tensor<?xf32>
+}
+// CHECK-LABEL: func public @custom_op_producer_fusion
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+// CHECK-SAME: ins(%[[GENERIC]] :
+// CHECK: flow.return %[[CUSTOM_OP]]
+// CHECK: util.return %[[DISPATCH]]
+
+// -----
+
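+// No producer fusion expected here: the indexing map for the first operand,
+// `(d0, d1)[s0, s1] -> (d0, s0)`, uses symbols, so that operand is not
+// indexed purely by the custom_op's loops.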
+util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %1 = arith.mulf %b0, %b0 : f32
+    linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ %2 = iree_linalg_ext.custom_op {
+ indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
+ affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+ affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
+ affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
+ affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
+ iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<parallel>]}
+ ins(%0, %arg1, %arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg4 : tensor<?x?xf32>) {
+ ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?x?xf32>, %b2 : tensor<?x?xf32>, %b3 : tensor<?x?xf32>, %b4 : tensor<?x?xf32>):
+ %3 = linalg.matmul ins(%b0, %b1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%b2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.matmul ins(%3, %b3 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%b4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ iree_linalg_ext.yield %4 : tensor<?x?xf32>
+ } -> tensor<?x?xf32>
+ util.return %2 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func public @custom_op_no_producer_fusion
+// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: flow.return %[[GENERIC]]
+// CHECK: %[[DISPATCH2:.+]] = flow.dispatch.region
+// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+// CHECK-SAME: ins(%[[DISPATCH1]],
+// CHECK: flow.return %[[CUSTOM_OP]]
+// CHECK: util.return %[[DISPATCH2]]