[Flow] Fix dispatch naming for dynamic shaped fusions (#19439)
Currently all ops with dynamic shapes are assigned the same estimated
cost when naming dispatches. This means that in cases like fused
elementwise ops with matmuls, the elementwise and matmuls are assigned
the same priority, and because of traversal order, the dispatch ends up
being named after the elementwise op.
This patch hacks it by treating all dynamic shapes as moderately sized
static shapes, but in the future if we have more issues we can look at
adding some tensor size range analysis that can give us upper bounds for
the dynamic shapes.
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Co-authored-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp
index e71f856..ad93878 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp
@@ -32,13 +32,26 @@
namespace {
+// This function estimates the cost of a set of perfectly nested loop ranges
+// simply as the product of the ranges. Note that it does not take into
+// account the cost of the body of the op whose domain this computes.
static int64_t costOfDomain(ArrayRef<int64_t> domain) {
int64_t product = 1;
for (int64_t size : domain) {
+ int64_t multiplier = size;
if (ShapedType::isDynamic(size)) {
+ // HACK: Use a placeholder value for dynamic sizes. In practice, because
+ // we tend to require that iteration spaces of linalg ops line up for
+ // fusion to occur, more dynamic dims => a larger iteration domain.
+ // TODO: Query the upper bound of the dynamic size range instead.
+ multiplier = 1024;
+ }
+
+  // Perform saturating multiplication.
+ if (product > kMaxCost / multiplier) {
return kMaxCost;
}
- product *= size;
+ product *= multiplier;
}
return product;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/annotate_dispatches.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/annotate_dispatches.mlir
index 77aebcc..c795817 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/annotate_dispatches.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/annotate_dispatches.mlir
@@ -669,3 +669,39 @@
}
}
}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+
+flow.executable private @ex {
+ // CHECK: flow.executable.export public @dispatch_matmul_like_16xDx8_f32
+ flow.executable.export public @dispatch
+ builtin.module {
+ func.func @dispatch(%arg0: !flow.dispatch.tensor<readwrite:tensor<16x?xf32>>, %arg1: index) {
+ %0 = tensor.empty() : tensor<16x8xf32>
+ %1 = tensor.empty(%arg1) : tensor<8x?xf32>
+ %init = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [16, %arg1], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<16x?xf32>>{%arg1} -> tensor<16x?xf32>
+ %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%0, %1 : tensor<16x8xf32>, tensor<8x?xf32>) outs(%init : tensor<16x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %3 = arith.mulf %in, %in_0 : f32
+ %4 = arith.addf %out, %3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<16x?xf32>
+ %3 = linalg.generic {
+ indexing_maps = [#map3, #map3],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%2 : tensor<16x?xf32>) outs(%2 : tensor<16x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %4 = math.rsqrt %in : f32
+ linalg.yield %4 : f32
+ } -> tensor<16x?xf32>
+ flow.dispatch.tensor.store %3, %arg0, offsets = [0, 0], sizes = [16, %arg1], strides = [1, 1] : tensor<16x?xf32> -> !flow.dispatch.tensor<readwrite:tensor<16x?xf32>>{%arg1}
+ return
+ }
+ }
+}