[DT] Fuse encoding ops more aggressively for multi-use, gather, and slice ops. (#21830)
The constraint that the producer dispatch have a single result is only
required by the SetEncoding pass, because that pass has to move
consumer dispatches around. It is not required by encoding fusion,
which only moves a SetEncoding op into its producer dispatch, so the
check moves out of the shared getProducerDispatchValueAndOpChain
utility and into SetEncoding.
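Concretely, encoding fusion performs a rewrite like the following (a
minimal sketch with hypothetical shapes and values, using the same
testing encoding as the lit tests below):

```mlir
#encoding = #iree_encoding.testing<>

// Before fusion: the set_encoding op consumes the dispatch result.
%0 = flow.dispatch.region -> (tensor<96x512xf32>) {
  %1 = linalg.exp ins(%in : tensor<96x512xf32>)
         outs(%empty : tensor<96x512xf32>) -> tensor<96x512xf32>
  flow.return %1 : tensor<96x512xf32>
}
%2 = iree_encoding.set_encoding %0
    : tensor<96x512xf32> -> tensor<96x512xf32, #encoding>

// After fusion: the set_encoding op moves into its producer dispatch.
%0 = flow.dispatch.region -> (tensor<96x512xf32, #encoding>) {
  %1 = linalg.exp ins(%in : tensor<96x512xf32>)
         outs(%empty : tensor<96x512xf32>) -> tensor<96x512xf32>
  %2 = iree_encoding.set_encoding %1
      : tensor<96x512xf32> -> tensor<96x512xf32, #encoding>
  flow.return %2 : tensor<96x512xf32, #encoding>
}
```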
The revision also allows fusion when the dispatch region contains
tensor.extract_slice or iree_linalg_ext.gather ops; see the sketch
below and the new lit tests. This reduces the number of dispatches in
the llama fp8 model to 644, the same count as without data tiling, and
drops latency by 25ms, from 378ms to 353ms.
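For example, a dispatch whose body is an iree_linalg_ext.gather op can
now have the encoding fused in (a sketch mirroring the new
gather_fusion lit test; the tensor.extract_slice case works the same
way, and the slice can later fold into the dispatch tensor load):

```mlir
%0 = flow.dispatch.region -> (tensor<1x10xi32, #encoding>) {
  %1 = iree_linalg_ext.gather dimension_map = [0]
         ins(%source, %indices : tensor<10x10xi32>, tensor<1xi32>)
         outs(%empty : tensor<1x10xi32>) -> tensor<1x10xi32>
  %2 = iree_encoding.set_encoding %1
      : tensor<1x10xi32> -> tensor<1x10xi32, #encoding>
  flow.return %2 : tensor<1x10xi32, #encoding>
}
```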
| | No Data Tiling | Data Tiling w/o the revision | Data Tiling w/ the revision |
| ------------- | ------------- | ------------- | ------------- |
| Benchmark latency | 243ms | 378ms | 353ms |
| Memory usage (HIP unpooled) | 15.9GB | 31.14GB | 31.11GB |
| Number of dispatches | 644 | 741 | 644 |
| Dispatch | No Data Tiling (ms) | Data Tiling w/o the revision (ms) | Data Tiling w/ the revision (ms) |
| ------------- | ------------- | ------------- | ------------- |
| dispatch_15_attention_4x8x4xDx128xf8 | 62.29 | 55.35 | 59.21 |
| dispatch_20_matmul_like_Dx14336x4096_f8xf8xf32 | 40.13 | 89.14 | 93.72 |
| dispatch_19_matmul_like_Dx14336x4096_f8xf8xf32 | 28.01 | 44.78 | 44.59 |
| dispatch_21_matmul_like_Dx4096x14336_f8xf8xf32 | 27.25 | 40.18 | 39.99 |
| dispatch_643_matmul_like_Dx128256x4096_f16xf16xf32 | 17.1 | 29.76 | 29.21 |
| dispatch_16_matmul_like_Dx4096x4096_f8xf8xf32 | 8.83 | 17.92 | 17.91 |
| dispatch_23_matmul_like_Dx4096x4096_f8xf8xf32 | 9.27 | 16.69 | 16.59 |
| encoding_10_encode_Dx4096xf8_to_Dx4096xf8 | - | 32.15 | - |
| encoding_6_encode_Dx14336xf32_to_Dx14336xf32 | - | 0.318 | - |
---------
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
index 9d5e876..cb1efe2 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
@@ -31,12 +31,13 @@
namespace {
/// Return true if the op is fusable with a SetEncodingOp consumer. The op's
-/// containing dispatch must contain only reshape ops, encoding ops, linalg ops,
-/// and attention ops. Non ShapedType ops (like arith ops, dim ops, etc.) are
-/// also allowed.
-/// TODO(#20179): It should be done by interface methods.
-static bool isFusableWithSetEncoding(Operation *op) {
- auto parentRegion = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
+/// containing dispatch must contain only:
+/// - Reshape ops, encoding ops, linalg ops, gather ops, and attention ops.
+/// - Non-ShapedType ops, e.g., arith ops, dim ops, etc.
+/// - tensor::ExtractSliceOp is allowed as it can be folded into dispatch
+/// tensor load ops.
+static bool isFusableWithSetEncoding(Operation *target) {
+ auto parentRegion = target->getParentOfType<IREE::Flow::DispatchRegionOp>();
// Make sure the dispatch region has only one block.
if (!llvm::hasSingleElement(parentRegion.getBody())) {
return false;
@@ -49,8 +50,9 @@
continue;
}
if (!isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::EmptyOp,
- IREE::Encoding::SetEncodingOp, IREE::Encoding::UnsetEncodingOp,
- linalg::LinalgOp, IREE::LinalgExt::AttentionOp>(op)) {
+ tensor::ExtractSliceOp, IREE::Encoding::SetEncodingOp,
+ IREE::Encoding::UnsetEncodingOp, linalg::LinalgOp,
+ IREE::LinalgExt::AttentionOp, IREE::LinalgExt::GatherOp>(op)) {
return false;
}
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index 1a321e4..cf51973 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -145,13 +145,8 @@
auto producerDispatch =
dyn_cast<IREE::Flow::DispatchRegionOp>(producerValue.getOwner());
- // TODO(MaheshRavishankar): Multi-result producer dispatches can be supported.
- // Will require to move the consumer dispatch immediately after the producer
- // instead of what is done below and move other operands of the consumer
- // dispatch before the producer dispatch.
if (!producerDispatch ||
- !llvm::hasSingleElement(producerDispatch.getBody()) ||
- producerDispatch->getNumResults() != 1) {
+ !llvm::hasSingleElement(producerDispatch.getBody())) {
return std::nullopt;
}
if (!llvm::hasSingleElement(producerValue.getUses())) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
index f422449..a0b483f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -28,8 +28,8 @@
/// Returns the closest producer dispatch region op result and the chain of
/// operations being looked past during the traversal to find the producer
-/// dispatch. Returns std::nullopt if the dispatch or any ops in the chain have
-/// multiple uses.
+/// dispatch. Returns std::nullopt if the dispatch cannot be found in the
+/// chain or any op in the chain is not a reshape-like op.
std::optional<std::pair<OpResult, SmallVector<Operation *>>>
getProducerDispatchValueAndOpChain(Value operand);
diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
index b2f9a56..1c44973 100644
--- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -392,16 +392,24 @@
OpOperand &operand = op->getOpOperand(operandNum);
std::optional<std::pair<OpResult, SmallVector<Operation *>>>
dispatchAndOpChain = getProducerDispatchValueAndOpChain(operand.get());
- if (dispatchAndOpChain.has_value()) {
- auto producerDispatch = cast<IREE::Flow::DispatchRegionOp>(
- dispatchAndOpChain->first.getOwner());
- WalkResult res =
- producerDispatch->walk([&](IREE::LinalgExt::AttentionOp op) {
- return WalkResult::interrupt();
- });
- if (res.wasInterrupted()) {
- return {};
- }
+ if (!dispatchAndOpChain.has_value()) {
+ continue;
+ }
+ auto producerDispatch = cast<IREE::Flow::DispatchRegionOp>(
+ dispatchAndOpChain->first.getOwner());
+ // TODO(MaheshRavishankar): Multi-result producer dispatches can be
+ // supported. This will require moving the consumer dispatch immediately
+ // after the producer instead of what is done below, and moving other
+ // operands of the consumer dispatch before the producer dispatch.
+ if (producerDispatch->getNumResults() != 1) {
+ continue;
+ }
+ WalkResult res =
+ producerDispatch->walk([&](IREE::LinalgExt::AttentionOp op) {
+ return WalkResult::interrupt();
+ });
+ if (res.wasInterrupted()) {
+ return {};
}
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
index 1d12291..d044636 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
@@ -281,6 +281,24 @@
// -----
#encoding = #iree_encoding.testing<>
+util.func public @extract_slice_fusion(%arg0: tensor<192x1024x64xf32>) -> tensor<96x512x32xf32, #encoding> {
+ %0 = flow.dispatch.region -> (tensor<96x512x32xf32>) {
+ %1 = tensor.extract_slice %arg0[0, 0, 0] [96, 512, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<96x512x32xf32>
+ flow.return %1 : tensor<96x512x32xf32>
+ }
+ %2 = iree_encoding.set_encoding %0 : tensor<96x512x32xf32> -> tensor<96x512x32xf32, #encoding>
+ util.return %2 : tensor<96x512x32xf32, #encoding>
+}
+// CHECK-LABEL: @extract_slice_fusion
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK: tensor.extract_slice
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK: flow.return %[[SET_ENCODING]] :
+// CHECK: util.return %[[DISPATCH0]]
+
+// -----
+
+#encoding = #iree_encoding.testing<>
util.func public @attention_fusion(
%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>,
%value: tensor<192x1024x64xf32>, %scale: f32) -> tensor<192x1024x64xf32, #encoding> {
@@ -309,3 +327,55 @@
// CHECK: flow.return %[[SET_ENCODING]] :
// CHECK: }
// CHECK: util.return %[[DISPATCH0]]
+
+// -----
+
+#encoding = #iree_encoding.testing<>
+util.func public @gather_fusion(
+ %source: tensor<10x10xi32>, %indices: tensor<1xi32>) -> tensor<1x10xi32, #encoding> {
+ %empty = tensor.empty() : tensor<1x10xi32>
+ %1 = flow.dispatch.region -> (tensor<1x10xi32>) {
+ %3 = iree_linalg_ext.gather dimension_map = [0]
+ ins(%source, %indices : tensor<10x10xi32>, tensor<1xi32>)
+ outs(%empty : tensor<1x10xi32>) -> tensor<1x10xi32>
+ flow.return %3 : tensor<1x10xi32>
+ }
+ %2 = iree_encoding.set_encoding %1 : tensor<1x10xi32> -> tensor<1x10xi32, #encoding>
+ util.return %2 : tensor<1x10xi32, #encoding>
+}
+// CHECK-LABEL: @gather_fusion
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK: iree_linalg_ext.gather
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK: flow.return %[[SET_ENCODING]] :
+// CHECK: util.return %[[DISPATCH0]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#encoding = #iree_encoding.testing<>
+util.func public @multi_result_fusion(%arg0: tensor<123x456xf32>) -> (tensor<123x456xf32>, tensor<123x456xf32, #encoding>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<123x456xf32>
+ %1:2 = flow.dispatch.region -> (tensor<123x456xf32>, tensor<123x456xf32>) {
+ %3:2 = linalg.generic {
+ indexing_maps = [#map, #map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg0 : tensor<123x456xf32>, tensor<123x456xf32>)
+ outs(%0, %0 : tensor<123x456xf32>, tensor<123x456xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32, %out2: f32):
+ %4 = arith.addf %in, %in_0 : f32
+ %5 = arith.mulf %in, %in_0 : f32
+ linalg.yield %4, %5 : f32, f32
+ } -> (tensor<123x456xf32>, tensor<123x456xf32>)
+ flow.return %3#0, %3#1 : tensor<123x456xf32>, tensor<123x456xf32>
+ }
+ %2 = iree_encoding.set_encoding %1#1 : tensor<123x456xf32> -> tensor<123x456xf32, #encoding>
+ util.return %1#0, %2 : tensor<123x456xf32>, tensor<123x456xf32, #encoding>
+}
+// CHECK-LABEL: @multi_result_fusion
+// CHECK: %[[DISPATCH0:.+]]:2 = flow.dispatch.region
+// CHECK: %[[ELEM:.+]]:2 = linalg.generic
+// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[ELEM]]#1
+// CHECK: flow.return %[[ELEM]]#0, %[[SET_ENCODING]]
+// CHECK: util.return %[[DISPATCH0]]#0, %[[DISPATCH0]]#1