Update resource placement and transfer for barrier operations (#19725)
Barriers indicate within device blocking. The results of a barrier
should not transfer to another location, otherwise there would be a
transfer and not a barrier.
---------
Co-authored-by: Ben Vanik <ben.vanik@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
index 4ff656c..e8dbaae 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
@@ -394,6 +394,12 @@
DFX::Resolution::REQUIRED);
getState() ^= targetUsage.getState();
})
+ .Case([&](IREE::Stream::AsyncBarrierOp op) {
+ auto &tiedUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(op.getOperand(0)),
+ DFX::Resolution::REQUIRED);
+ getState() ^= tiedUsage.getState();
+ })
.Case([&](IREE::Stream::AsyncTransferOp op) {
removeAssumedBits(NOT_TRANSFER_WRITE);
auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
@@ -716,6 +722,12 @@
getState() ^= resultUsage.getState();
}
})
+ .Case([&](IREE::Stream::AsyncBarrierOp op) {
+ auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(op.getResult()),
+ DFX::Resolution::OPTIONAL);
+ getState() ^= resultUsage.getState();
+ })
.Case([&](IREE::Stream::AsyncTransferOp op) {
removeAssumedBits(NOT_TRANSFER_READ);
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
index 4fb2216..df9e548 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -138,6 +138,26 @@
util.global private @device : !hal.device
+// CHECK-LABEL: @tensorBarrierDispatch
+// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+util.func public @tensorBarrierDispatch(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
+ %c0 = arith.constant 0 : index
+ %barrier = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
+ %0 = flow.dispatch @ex::@entry[%c0](%barrier) : (tensor<?x128xi8>{%dim0}) -> tensor<?x128xi8>{%dim0}
+
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[BARRIER:.+]] = stream.async.barrier %[[INPUT]] : !stream.resource<*>{%[[DIM0]]} -> !stream.resource<*>
+ // CHECK: %[[C0_2:.+]] = arith.constant 0 : index
+ // CHECK: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device>) tensor<?x128xi8>{%arg2} : index
+ // CHECK: %[[DISP:.+]] = stream.async.dispatch on(#hal.device.affinity<@device>) @ex::@entry[%[[C0]]](%[[BARRIER]][%[[C0_2]] to %[[DIM0]] for %[[DIM0]]])
+ // CHECK: util.return %[[DISP]], %[[SIZE]]
+ util.return %0 : tensor<?x128xi8>
+}
+
+// -----
+
+util.global private @device : !hal.device
+
// CHECK-LABEL: @tensorTransfer
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM0:.+]]: index)
util.func public @tensorTransfer(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
index 4512245..051b6a1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
@@ -64,7 +64,9 @@
Value resource = convertedOperand[0];
Value resourceSize = convertedOperand[1];
auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand);
- if (affinityAttr != requiredAffinityAttr) {
+ bool isBarrier = resource.getDefiningOp() &&
+ isa<IREE::Stream::AsyncBarrierOp>(resource.getDefiningOp());
+ if (affinityAttr != requiredAffinityAttr && !isBarrier) {
resource = builder.create<IREE::Stream::AsyncTransferOp>(
loc, resource.getType(), resource, resourceSize, resourceSize,
affinityAttr, requiredAffinityAttr);