Adding flow.tensor.alloc op for unique allocations. (#13081)
Each op instance will lower into its own stream.resource.alloc. Since
this is a big pessimization, it should always be avoided unless
absolutely required — and if it is required, performance-sensitive
users should try a different approach.
Reverts a30beed741f58ffaff04c52e7abf33f07d80a082.
Fixes #11251.
Fixes #13022.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index d35e3a6..0e24d89 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -41,6 +41,20 @@
// Folding utilities
//===----------------------------------------------------------------------===//
+// Erases an op if it has no uses.
+// This is to support ops that are "pure" but can't be marked as such because
+// the MLIR CSE pass would deduplicate them.
+template <typename Op>
+struct ElideUnusedOp : public OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const override {
+ if (!op.use_empty()) return failure();
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
// Returns true if |value| is definitely empty at runtime.
static bool isTensorZeroElements(Value value) {
auto type = value.getType().dyn_cast<ShapedType>();
@@ -508,6 +522,10 @@
return true;
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.constant
+//===----------------------------------------------------------------------===//
+
OpFoldResult TensorConstantOp::fold(FoldAdaptor operands) {
auto dynamicType = getType();
if (dynamicType.getNumDynamicDims() == 0) {
@@ -550,6 +568,10 @@
results.insert<ExpandDynamicShapeConstant>(context);
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.tie_shape
+//===----------------------------------------------------------------------===//
+
OpFoldResult TensorTieShapeOp::fold(FoldAdaptor operands) {
if (getDynamicDims().empty()) {
return getOperand();
@@ -563,6 +585,10 @@
context);
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.reshape
+//===----------------------------------------------------------------------===//
+
OpFoldResult TensorReshapeOp::fold(FoldAdaptor operands) {
auto sourceType = getSource().getType().cast<ShapedType>();
auto resultType = getResult().getType().cast<ShapedType>();
@@ -692,10 +718,9 @@
results.insert<ResolveShapedDim>(context);
}
-void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<FoldSplatLoadIntoPrimitive>(context);
-}
+//===----------------------------------------------------------------------===//
+// flow.tensor.load
+//===----------------------------------------------------------------------===//
OpFoldResult TensorLoadOp::fold(FoldAdaptor operands) {
if (auto source = operands.getSource().dyn_cast_or_null<ElementsAttr>()) {
@@ -710,6 +735,15 @@
return {};
}
+void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldSplatLoadIntoPrimitive>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.store
+//===----------------------------------------------------------------------===//
+
OpFoldResult TensorStoreOp::fold(FoldAdaptor operands) {
auto value = operands.getValue();
if (!value) return {};
@@ -733,11 +767,28 @@
return {};
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.alloc
+//===----------------------------------------------------------------------===//
+
+void TensorAllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideUnusedOp<TensorAllocOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.empty
+//===----------------------------------------------------------------------===//
+
void TensorEmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): fold static shapes into dims.
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.splat
+//===----------------------------------------------------------------------===//
+
void TensorSplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): canonicalize splat+slice to smaller splat.
@@ -746,6 +797,10 @@
results.insert<FoldSplatReshapeIntoSplat>(context);
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.clone
+//===----------------------------------------------------------------------===//
+
OpFoldResult TensorCloneOp::fold(FoldAdaptor operands) {
if (auto operand = operands.getOperand()) {
// Constants always fold.
@@ -772,6 +827,10 @@
results.insert<ReplaceOpIfTensorOperandEmpty<TensorCloneOp, 0, 0>>(context);
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.slice
+//===----------------------------------------------------------------------===//
+
// Slices tensor from start to (start + length) exclusively at dim.
static ElementsAttr tensorSlice(ElementsAttr tensor, uint64_t dim,
uint64_t start, uint64_t length) {
@@ -830,6 +889,10 @@
results.insert<ReplaceOpIfTensorOperandEmpty<TensorSliceOp, 0, 0>>(context);
}
+//===----------------------------------------------------------------------===//
+// flow.tensor.update
+//===----------------------------------------------------------------------===//
+
static ElementsAttr tensorUpdate(ElementsAttr update, ElementsAttr target,
ArrayRef<Attribute> startIndicesAttrs) {
auto updateType = update.getType().cast<ShapedType>();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 66798bc..7a5b021 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -1395,82 +1395,6 @@
}
//===----------------------------------------------------------------------===//
-// flow.tensor.clone
-//===----------------------------------------------------------------------===//
-
-LogicalResult TensorCloneOp::verify() {
- if (failed(verifyOpDynamicDims(getOperation(), {getOperand()},
- getArgumentDims())) ||
- failed(verifyOpDynamicDims(getOperation(), {getResult()},
- getArgumentDims()))) {
- return failure();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// flow.tensor.empty
-//===----------------------------------------------------------------------===//
-
-LogicalResult TensorEmptyOp::verify() {
- if (failed(verifyOpDynamicDims(getOperation(), {getResult()},
- getResultDims()))) {
- return failure();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// flow.tensor.load
-//===----------------------------------------------------------------------===//
-
-LogicalResult TensorLoadOp::verify() {
- if (failed(verifyOpDynamicDims(getOperation(), {getSource()},
- getSourceDims()))) {
- return failure();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// flow.tensor.slice
-//===----------------------------------------------------------------------===//
-
-LogicalResult TensorSliceOp::verify() {
- if (failed(verifyOpDynamicDims(getOperation(), {getSource()},
- getSourceDims())) ||
- failed(verifyOpDynamicDims(getOperation(), {getResult()},
- getResultDims()))) {
- return failure();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// flow.tensor.splat
-//===----------------------------------------------------------------------===//
-
-LogicalResult TensorSplatOp::verify() {
- if (failed(verifyOpDynamicDims(getOperation(), {getResult()},
- getResultDims()))) {
- return failure();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// flow.tensor.store
-//===----------------------------------------------------------------------===//
-
-LogicalResult TensorStoreOp::verify() {
- if (failed(verifyOpDynamicDims(getOperation(), {getTarget()},
- getTargetDims()))) {
- return failure();
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// flow.tensor.tie_shape
//===----------------------------------------------------------------------===//
@@ -1526,6 +1450,94 @@
}
//===----------------------------------------------------------------------===//
+// flow.tensor.load
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorLoadOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getSource()},
+ getSourceDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.store
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorStoreOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getTarget()},
+ getTargetDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.alloc
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorAllocOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getResult()},
+ getResultDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.empty
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorEmptyOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getResult()},
+ getResultDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.splat
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorSplatOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getResult()},
+ getResultDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.clone
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorCloneOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getOperand()},
+ getArgumentDims())) ||
+ failed(verifyOpDynamicDims(getOperation(), {getResult()},
+ getArgumentDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.slice
+//===----------------------------------------------------------------------===//
+
+LogicalResult TensorSliceOp::verify() {
+ if (failed(verifyOpDynamicDims(getOperation(), {getSource()},
+ getSourceDims())) ||
+ failed(verifyOpDynamicDims(getOperation(), {getResult()},
+ getResultDims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// flow.tensor.update
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 2af3102..4c10f6a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1180,6 +1180,42 @@
let hasFolder = 1;
}
+def FLOW_TensorAllocOp : FLOW_Op<"tensor.alloc", [
+ Util_ShapeAwareOp,
+ MemoryEffects<[MemAlloc]>,
+]> {
+ let summary = [{an empty tensor allocation with undefined contents}];
+ let description = [{
+ Returns a new transient tensor allocation with undefined contents.
+ Subsequent writes must populate any ranges of the tensor that are later
+ read. The resulting tensor may be long-lived and allocated as part of a
+ dedicated allocation. Prefer using `flow.tensor.empty` whenever possible as
+ this op disables nearly all allocation-related optimizations performed by
+ the compiler. The presence of this op is often an indication of an improper
+ lowering.
+ }];
+
+ let arguments = (ins
+ FLOW_ShapeDynamicDims:$result_dims
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ `:` type($result) (`{` $result_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return ValueRange{}; }
+ ValueRange getResultDynamicDims(unsigned idx) { return getResultDims(); }
+ }];
+
+ let hasVerifier = 1;
+ let hasCanonicalizer = 1;
+}
+
def FLOW_TensorEmptyOp : FLOW_PureOp<"tensor.empty", [
FLOW_StreamableOp,
Util_ShapeAwareOp,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index de53448..32d5248 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -234,9 +234,27 @@
// -----
+// CHECK-LABEL: @allocDims
+// CHECK-SAME: (%[[DIM:.+]]: index)
+func.func @allocDims(%dim: index) -> (index, index, index) {
+ // CHECK-NOT: flow.tensor.alloc
+ %0 = flow.tensor.alloc : tensor<4x?x0xf32>{%dim}
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %d0 = tensor.dim %0, %c0 : tensor<4x?x0xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<4x?x0xf32>
+ %d2 = tensor.dim %0, %c2 : tensor<4x?x0xf32>
+ // CHECK: return %c4, %[[DIM]], %c0
+ return %d0, %d1, %d2 : index, index, index
+}
+
+// -----
+
// CHECK-LABEL: @emptyDims
// CHECK-SAME: (%[[DIM:.+]]: index)
func.func @emptyDims(%dim: index) -> (index, index, index) {
+ // CHECK-NOT: flow.tensor.empty
%0 = flow.tensor.empty : tensor<4x?x0xf32>{%dim}
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
index d2adc5d..e824af7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
@@ -73,6 +73,15 @@
// -----
+// CHECK-LABEL: @tensorAlloc
+func.func @tensorAlloc(%arg0: index) -> tensor<?x0x1xf32> {
+ // CHECK-NEXT: = flow.tensor.alloc : tensor<?x0x1xf32>{%arg0}
+ %0 = flow.tensor.alloc : tensor<?x0x1xf32>{%arg0}
+ return %0 : tensor<?x0x1xf32>
+}
+
+// -----
+
// CHECK-LABEL: @tensorEmpty
func.func @tensorEmpty(%arg0: index) -> tensor<?x0x1xf32> {
// CHECK-NEXT: = flow.tensor.empty : tensor<?x0x1xf32>{%arg0}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 3c1dca0..1272523 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -346,8 +346,134 @@
// hal.device.queue.execute
//===----------------------------------------------------------------------===//
+// Returns true if |before| is always executed by the time |after| is reached.
+// NOTE: this is currently very conservative and only looks for ops in the
+// same basic block. We need an abstract interpreter to do much more as we'd
+// need to track conditionals/branching logic.
+static bool isOpAlwaysExecutedWith(Operation *before, Operation *after) {
+ if (before == after) return true;
+ if (before->getBlock() != after->getBlock()) return false;
+ return before->isBeforeInBlock(after);
+}
+
+// Returns true if |op| was hoisted before |insertBefore| without breaking
+// SSA invariants. Returns false if no IR modifications were made.
+static bool tryHoistOpBeforeUser(Operation *op, Operation *insertBefore) {
+ if (op == insertBefore) return false;
+
+ // Currently conservative - should be doing a domination check.
+ if (op->getBlock() != insertBefore->getBlock()) {
+ // Today only doing within the same block.
+ return false;
+ }
+
+ // Ensure all operands are defined above the insertion target.
+ // TODO(benvanik): hoist dependent ops too (constants are common).
+ if (!llvm::all_of(op->getOperands(), [&](Value operand) {
+ auto *definingOp = operand.getDefiningOp();
+ if (!definingOp || definingOp->getBlock() != insertBefore->getBlock()) {
+ // Function/block args or values defined outside the insertion block
+ // are ok since we are limiting to 1 block.
+ return true;
+ }
+ return definingOp->isBeforeInBlock(insertBefore);
+ })) {
+ return false;
+ }
+
+ // Should be safe to hoist the op 🤞.
+ op->moveBefore(insertBefore);
+ return true;
+}
+
namespace {
+/// Swaps a device queue barrier with an immediate host fence signal when the
+/// wait fence is immediately resolved (null).
+struct ImmediatelyResolveDeviceQueueBarrier
+ : public OpRewritePattern<DeviceQueueExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(DeviceQueueExecuteOp barrierOp,
+ PatternRewriter &rewriter) const override {
+ // Only looking for ops performing basic barriers.
+ if (!barrierOp.isBarrier()) return failure();
+
+ // Check for whether we know the wait fence is immediately resolved in the
+ // local scope. A more involved data flow analysis would let us handle more
+ // cases (function calls, block edges, etc) that commonly arise.
+ if (!isa_and_nonnull<IREE::Util::NullOp>(
+ barrierOp.getWaitFence().getDefiningOp())) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::HAL::FenceSignalOp>(
+ barrierOp, barrierOp.getSignalFence());
+ return success();
+ }
+};
+
+/// Aliases a signal fence to a wait fence when there's a direct execution
+/// dependency through the barrier. This only checks the local scope but could
+/// be extended across CFG boundaries.
+///
+/// Example:
+/// %fence0 = hal.fence.create
+/// hal.device.queue.execute signal(%fence0)
+/// hal.device.queue.execute wait(%fence0) signal(%fence1)
+/// ->
+/// hal.device.queue.execute signal(%fence1)
+struct HoistDeviceQueueBarrierChain
+ : public OpRewritePattern<DeviceQueueExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(DeviceQueueExecuteOp barrierOp,
+ PatternRewriter &rewriter) const override {
+ // Only looking for ops performing basic barriers.
+ if (!barrierOp.isBarrier()) return failure();
+
+ // See if we can observe the original fence creation in the local scope.
+ auto waitFence = barrierOp.getWaitFence();
+ auto createOp =
+ dyn_cast_or_null<IREE::HAL::FenceCreateOp>(waitFence.getDefiningOp());
+ if (!createOp) {
+ return rewriter.notifyMatchFailure(barrierOp,
+ "cannot analyze wait fence creation");
+ }
+
+ // Today this simple pattern only deals with the local block. We should
+ // extend this to support a must-be-executed context such that we can deal
+ // with the common case of some basic control flow handling errors/etc.
+ if (createOp->getBlock() != barrierOp->getBlock()) {
+ return rewriter.notifyMatchFailure(
+ barrierOp,
+ "create and barrier are in different blocks; analysis TBD");
+ }
+
+ // To ensure we don't break SSA invariants we need to only hoist if the
+ // signal fence is or can be defined before all users of the waitFence we
+ // are replacing. Note that because we are only matching on ops within
+ // the same block if we don't have the defining op it means it's a block
+ // argument and is always available.
+ auto signalFence = barrierOp.getSignalFence();
+ auto signalDefiningOp = signalFence.getDefiningOp();
+ if (signalDefiningOp) {
+ // Try to hoist up to the defining op.
+ if (!tryHoistOpBeforeUser(signalDefiningOp, createOp)) {
+ return rewriter.notifyMatchFailure(
+ barrierOp, "signal defining op cannot be hoisted");
+ }
+ }
+
+ // Replace the original fence with the new one and drop the create.
+ rewriter.replaceAllUsesWith(waitFence, signalFence);
+ rewriter.eraseOp(createOp);
+
+ // Drop the barrier now that it is a no-op.
+ rewriter.eraseOp(barrierOp);
+
+ return success();
+ }
+};
+
/// Elides queue barriers that are used for sequencing fences when the operation
/// could be performed by way of the originating queue operation.
///
@@ -361,9 +487,8 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DeviceQueueExecuteOp barrierOp,
PatternRewriter &rewriter) const override {
- // Only looking for ops performing basic barriers - ones executing commands
- // are only ever fixed up.
- if (!barrierOp.getCommandBuffers().empty()) return failure();
+ // Only looking for ops performing basic barriers.
+ if (!barrierOp.isBarrier()) return failure();
// We're looking at the wait fence on the barrier back up to the signal
// operation on that fence.
@@ -437,16 +562,6 @@
.Default([](Operation *op) { return false; });
}
- // Returns true if |before| is always executed by the time |after| is reached.
- // NOTE: this is currently very conservative and only looks for ops in the
- // same basic block. We need an abstract interpreter to do much more as we'd
- // need to track conditionals/branching logic.
- static bool isOpAlwaysExecutedWith(Operation *before, Operation *after) {
- if (before == after) return true;
- if (before->getBlock() != after->getBlock()) return false;
- return before->isBeforeInBlock(after);
- }
-
// Updates |op| to signal |fence|.
static LogicalResult updateOpToSignalFence(Operation *op, Value fence) {
// For now we have a limited set of these ops but we should add an interface
@@ -472,6 +587,8 @@
void DeviceQueueExecuteOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
+ results.insert<ImmediatelyResolveDeviceQueueBarrier>(context);
+ results.insert<HoistDeviceQueueBarrierChain>(context);
results.insert<ElideDeviceQueueBarrierOp>(context);
}
@@ -757,6 +874,11 @@
// hal.fence.join
//===----------------------------------------------------------------------===//
+OpFoldResult FenceJoinOp::fold(FoldAdaptor operands) {
+ if (getFences().size() == 1) return getFences().front();
+ return {};
+}
+
namespace {
/// Replaces a fence join with no operands with a null value.
@@ -811,6 +933,70 @@
}
//===----------------------------------------------------------------------===//
+// hal.fence.signal
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Replaces a fence that is immediately signaled on the host with a null fence.
+/// This is only safe if there are no users of the fence between where it is
+/// created and where it is signaled. We keep things in the local block scope
+/// but a larger data flow analysis pass would be useful for propagating across
+/// block/function boundaries (common in larger loops/call trees where signal
+/// fences are passed as arguments).
+///
+/// Example:
+/// %fence = hal.fence.create
+/// hal.fence.signal<%fence : !hal.fence>
+/// ->
+/// %fence = util.null : !hal.fence
+struct ElideSignaledFence : public OpRewritePattern<FenceSignalOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(FenceSignalOp signalOp,
+ PatternRewriter &rewriter) const override {
+ auto fence = signalOp.getFence();
+ auto createOp =
+ dyn_cast_or_null<IREE::HAL::FenceCreateOp>(fence.getDefiningOp());
+ if (!createOp) return failure();
+
+ // TODO(benvanik): broader analysis - likely in a dedicated fence elision
+ // pass so we can do IPO. For now block-only.
+ if (createOp->getBlock() != signalOp->getBlock()) {
+ return rewriter.notifyMatchFailure(
+ signalOp,
+ "fence create and signal are in different blocks; analysis TBD");
+ }
+
+ // Ensure there are no uses between the create and the signal.
+ // There are probably some uses we could allow (selects, etc) but we'll
+ // reserve that for a larger analysis.
+ for (auto userOp : fence.getUsers()) {
+ if (userOp->getBlock() == signalOp->getBlock() &&
+ userOp->isBeforeInBlock(signalOp)) {
+ return rewriter.notifyMatchFailure(
+ signalOp, "interleaved fence usage; cannot elide");
+ }
+ }
+
+ // Safe to elide.
+ Value nullFence = rewriter.create<IREE::Util::NullOp>(
+ rewriter.getFusedLoc({createOp.getLoc(), signalOp.getLoc()}),
+ fence.getType());
+ rewriter.replaceAllUsesWith(fence, nullFence);
+ rewriter.eraseOp(signalOp);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void FenceSignalOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<ElideSignaledFence>(context);
+}
+
+//===----------------------------------------------------------------------===//
// hal.fence.await
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index eb8dd55..fd295b4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -640,6 +640,28 @@
return getResultSize();
}
+static LogicalResult verifyDeviceQueueFences(Operation *queueOp,
+ Value waitFence,
+ Value signalFence) {
+ if (waitFence == signalFence) {
+ return queueOp->emitOpError() << "device queue operations cannot wait and "
+ "signal on the same fence";
+ }
+ return success();
+}
+
+LogicalResult DeviceQueueAllocaOp::verify() {
+ return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence());
+}
+
+LogicalResult DeviceQueueDeallocaOp::verify() {
+ return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence());
+}
+
+LogicalResult DeviceQueueExecuteOp::verify() {
+ return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence());
+}
+
//===----------------------------------------------------------------------===//
// hal.executable
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 9fada78..c1d5480 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1465,6 +1465,8 @@
`:` custom<SizeAwareType>(type($result), $result_size)
attr-dict-with-keyword
}];
+
+ let hasVerifier = 1;
}
def HAL_DeviceQueueDeallocaOp : HAL_Op<"device.queue.dealloca"> {
@@ -1495,6 +1497,8 @@
`buffer` `(` $buffer `:` type($buffer) `)`
attr-dict-with-keyword
}];
+
+ let hasVerifier = 1;
}
def HAL_DeviceQueueExecuteOp : HAL_Op<"device.queue.execute"> {
@@ -1524,7 +1528,13 @@
attr-dict-with-keyword
}];
+ let extraClassDeclaration = [{
+ // Returns true if the execution represents a barrier.
+ bool isBarrier() { return getCommandBuffers().empty(); }
+ }];
+
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
def HAL_DeviceQueueFlushOp : HAL_Op<"device.queue.flush"> {
@@ -2524,6 +2534,7 @@
attr-dict-with-keyword
}];
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -2564,6 +2575,8 @@
`<` $fence `:` type($fence) `>`
attr-dict-with-keyword
}];
+
+ let hasCanonicalizer = 1;
}
def HAL_FenceFailOp : HAL_Op<"fence.fail"> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_folding.mlir
index e40e525..878faab 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_folding.mlir
@@ -1,5 +1,69 @@
// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
+// CHECK-LABEL: @ImmediatelyResolveDeviceQueueBarrier
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[SIGNAL_FENCE:.+]]: !hal.fence)
+func.func @ImmediatelyResolveDeviceQueueBarrier(%device: !hal.device, %signal_fence: !hal.fence) {
+ %c-1_i64 = arith.constant -1 : i64
+ // CHECK-NOT: util.null
+ %wait_fence = util.null : !hal.fence
+ // CHECK-NOT: hal.device.queue.execute
+ // CHECK: hal.fence.signal<%[[SIGNAL_FENCE]] : !hal.fence>
+ hal.device.queue.execute<%device : !hal.device>
+ affinity(%c-1_i64)
+ wait(%wait_fence)
+ signal(%signal_fence)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @HoistDeviceQueueBarrierChain
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[SIGNAL_FENCE:.+]]: !hal.fence)
+func.func @HoistDeviceQueueBarrierChain(%device: !hal.device, %signal_fence: !hal.fence) {
+ %c-1_i64 = arith.constant -1 : i64
+ // CHECK-NOT: hal.fence.create
+ %temp_fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
+ // CHECK: call @external_async_fn(%[[SIGNAL_FENCE]])
+ call @external_async_fn(%temp_fence) : (!hal.fence) -> ()
+ // CHECK-NOT: hal.device.queue.execute
+ hal.device.queue.execute<%device : !hal.device>
+ affinity(%c-1_i64)
+ wait(%temp_fence)
+ signal(%signal_fence)
+ return
+}
+func.func private @external_async_fn(!hal.fence)
+
+// -----
+
+// Tests that chains of locally defined fences are handled by hoisting the fence
+// create op (when possible).
+
+// CHECK-LABEL: @HoistDeviceQueueBarrierChainOutOfOrder
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[CMD:.+]]: !hal.command_buffer, %[[WAIT_FENCE:.+]]: !hal.fence)
+func.func @HoistDeviceQueueBarrierChainOutOfOrder(%device: !hal.device, %cmd: !hal.command_buffer, %wait_fence: !hal.fence) -> !hal.fence {
+ %c-1_i64 = arith.constant -1 : i64
+ // CHECK: %[[FENCE1:.+]] = hal.fence.create {{.+}} {test.fence1}
+ %fence0 = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence attributes {test.fence0}
+ // CHECK: hal.device.queue.execute{{.+}} wait(%[[WAIT_FENCE]]) signal(%[[FENCE1]]) commands([%[[CMD]]])
+ hal.device.queue.execute<%device : !hal.device>
+ affinity(%c-1_i64)
+ wait(%wait_fence)
+ signal(%fence0)
+ commands([%cmd])
+ // CHECK-NOT: hal.fence.create
+ %fence1 = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence attributes {test.fence1}
+ // CHECK-NOT: hal.device.queue.execute
+ hal.device.queue.execute<%device : !hal.device>
+ affinity(%c-1_i64)
+ wait(%fence0)
+ signal(%fence1)
+ // CHECK: return %[[FENCE1]]
+ return %fence1 : !hal.fence
+}
+
+// -----
+
// CHECK-LABEL: @ElideDeviceQueueBarrierOp
// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device,
// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir
index ec9487e..2127d89 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir
@@ -13,6 +13,18 @@
// -----
+// Tests that a fence join with 1 operand folds into that operand.
+
+// CHECK-LABEL: @fence_join_one
+// CHECK-SAME: %[[ARG:.+]]: !hal.fence
+func.func @fence_join_one(%arg: !hal.fence) -> !hal.fence {
+ %join = hal.fence.join at([%arg]) -> !hal.fence
+ // CHECK: return %[[ARG]]
+ return %join : !hal.fence
+}
+
+// -----
+
// Tests that a fence join with no operands folds into a util.null.
// CHECK-LABEL: @fence_join_empty
@@ -28,12 +40,12 @@
// Tests that known null fences are dropped from joins.
// CHECK-LABEL: @fence_join_null
-// CHECK-SAME: %[[ARG:.+]]: !hal.fence
-func.func @fence_join_null(%arg: !hal.fence) -> !hal.fence {
+// CHECK-SAME: (%[[ARG0:.+]]: !hal.fence, %[[ARG1:.+]]: !hal.fence)
+func.func @fence_join_null(%arg0: !hal.fence, %arg1: !hal.fence) -> !hal.fence {
// CHECK-NOT: util.null
%null = util.null : !hal.fence
- // CHECK: %[[JOIN:.+]] = hal.fence.join at([%[[ARG]]]) -> !hal.fence
- %join = hal.fence.join at([%arg, %null]) -> !hal.fence
+ // CHECK: %[[JOIN:.+]] = hal.fence.join at([%[[ARG0]], %[[ARG1]]]) -> !hal.fence
+ %join = hal.fence.join at([%arg0, %null, %arg1]) -> !hal.fence
// CHECK: return %[[JOIN]]
return %join : !hal.fence
}
@@ -53,6 +65,46 @@
// -----
+// Elides fences that are immediately signaled on the host.
+// This requires that there are no ops using the fence value between the time it
+// is created and the time it is signaled.
+
+// CHECK-LABEL: @fence_elide_signaled
+func.func @fence_elide_signaled(%device: !hal.device) -> !hal.fence {
+ // CHECK-NOT: hal.fence.create
+ %fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
+  // OK to have other things in between so long as they don't touch the fence.
+ // CHECK: call @external_nop_call
+ call @external_nop_call() : () -> ()
+ // CHECK-NOT: hal.fence.signal
+ hal.fence.signal<%fence : !hal.fence>
+ // CHECK: %[[FENCE:.+]] = util.null : !hal.fence
+ // CHECK: return %[[FENCE]]
+ return %fence : !hal.fence
+}
+func.func private @external_nop_call()
+
+// -----
+
+// Ensures that immediate fence signals aren't elided if the fence may be waited
+// on between when it is created and when it is signaled.
+
+// CHECK-LABEL: @fence_cannot_elide_signaled
+func.func @fence_cannot_elide_signaled(%device: !hal.device) -> !hal.fence {
+ // CHECK: hal.fence.create
+ %fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
+ // Block the elision as the external call may wait on the fence.
+ // CHECK: call @external_wait_call
+ call @external_wait_call(%fence) : (!hal.fence) -> ()
+ // CHECK: hal.fence.signal
+ hal.fence.signal<%fence : !hal.fence>
+ // CHECK: return
+ return %fence : !hal.fence
+}
+func.func private @external_wait_call(!hal.fence)
+
+// -----
+
// Tests that awaits with no fences are elided.
// CHECK-LABEL: @fence_await_none
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 3b555f0..11c5055 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -121,6 +121,7 @@
static void addCleanupPatterns(OpPassManager &passManager) {
// Standard MLIR cleanup.
+ passManager.addPass(mlir::createCSEPass());
passManager.addPass(mlir::createCanonicalizerPass());
passManager.addPass(mlir::createCSEPass());
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 3ba27af..b1310ee 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -67,6 +67,22 @@
}
};
+struct ConvertTensorAllocOp
+ : public OpConversionPattern<IREE::Flow::TensorAllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorAllocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type unknownType = IREE::Stream::ResourceType::get(getContext());
+ auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(),
+ op.getResultDims(), rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Stream::ResourceAllocOp>(
+ op, unknownType, resultSize,
+        /*uninitialized=*/true, getAffinityFor(op));
+ return success();
+ }
+};
+
struct ConvertTensorEmptyOp
: public OpConversionPattern<IREE::Flow::TensorEmptyOp> {
using OpConversionPattern::OpConversionPattern;
@@ -733,11 +749,12 @@
void populateFlowToStreamConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.insert<
- ConvertTensorReshapeOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
- ConvertTensorCloneOp, ConvertTensorSliceOp, ConvertTensorUpdateOp,
- ConvertTensorLoadOp, ConvertTensorStoreOp, ConvertTensorTraceOp>(
- typeConverter, context);
+ patterns
+ .insert<ConvertTensorReshapeOp, ConvertTensorAllocOp,
+ ConvertTensorEmptyOp, ConvertTensorSplatOp, ConvertTensorCloneOp,
+ ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
+ ConvertTensorStoreOp, ConvertTensorTraceOp>(typeConverter,
+ context);
patterns.insert<ConvertChannelCountOp, ConvertChannelDefaultOp,
ConvertChannelRankOp>(typeConverter, context);
patterns
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
index 6d2946a..166b52b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -46,6 +46,18 @@
// -----
+// CHECK-LABEL: @tensorAlloc
+// CHECK-SAME: (%[[DIM0:.+]]: index)
+func.func @tensorAlloc(%dim0: index) -> tensor<?x0xf32> {
+ // CHECK: %[[ALLOC_SIZE:.+]] = stream.tensor.sizeof tensor<?x0xf32>{%[[DIM0]]}
+ // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource<*>{%[[ALLOC_SIZE]]}
+ %0 = flow.tensor.alloc : tensor<?x0xf32>{%dim0}
+ // CHECK: return %[[ALLOC]]
+ return %0 : tensor<?x0xf32>
+}
+
+// -----
+
// CHECK-LABEL: @tensorEmpty
// CHECK-SAME: (%[[DIM0:.+]]: index)
func.func @tensorEmpty(%dim0: index) -> tensor<?x0xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 6e44cbe..ef44afc 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -896,7 +896,7 @@
constantOp.getLoc(), rewriter.getIndexType(),
TypeAttr::get(constantOp.getResultEncoding()),
constantOp.getResultEncodingDims(), constantOp.getAffinityAttr());
- rewriter.replaceOpWithNewOp<TensorEmptyOp>(
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorEmptyOp>(
constantOp, constantOp.getResult().getType(),
constantOp.getResultEncoding(), constantOp.getResultEncodingDims(),
resultSize, constantOp.getAffinityAttr());
@@ -2375,6 +2375,82 @@
}
//===----------------------------------------------------------------------===//
+// stream.timepoint.chain_external
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Elides a timepoint chaining operation when the chained timepoint is directly
+// usable from imported external values. This covers the common case where an
+// imported fence is chained with a new fence - since fences are single-shot the
+// new fence can be replaced with the imported fence. We rely on the MemAlloc
+// (MemoryEffects::Allocate) effect to detect when the external fence was
+// locally created for chaining vs. an argument/global/etc. we cannot elide.
+//
+// Example:
+// %timepoint = stream.timepoint.import %arg_fence
+// %chained_fence = hal.fence.create
+// stream.timepoint.chain_external %timepoint => (%chained_fence : !hal.fence)
+// ->
+// %chained_fence = %arg_fence
+struct PassThroughChainExternal
+ : public OpRewritePattern<TimepointChainExternalOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointChainExternalOp op,
+ PatternRewriter &rewriter) const override {
+ // Try to get the original external values that we want to chain.
+ auto importOp = dyn_cast_or_null<IREE::Stream::TimepointImportOp>(
+ op.getAwaitTimepoint().getDefiningOp());
+ if (!importOp) {
+ return rewriter.notifyMatchFailure(
+ op, "timepoint not imported from an external value");
+ }
+
+ // The imported external values must match the types of the chained external
+ // values as we'll be doing a SSA value replacement and can't change types.
+ if (!llvm::all_of_zip(importOp.getOperands(), op.getExternalValues(),
+ [](Value importValue, Value chainValue) {
+ return importValue.getType() ==
+ chainValue.getType();
+ })) {
+ return rewriter.notifyMatchFailure(
+ op,
+ "can only chain when external value types match between the import "
+ "and chain op");
+ }
+
+ // We can only replace external values that are locally allocated as this
+ // pattern is effectively just killing the allocation - if it comes from
+ // above/globals/external functions then we can't change things.
+ //
+ // TODO(benvanik): improve this to handle more external value types; for now
+ // only !hal.fence is used in practice and that is MemAlloc.
+ for (auto externalValue : op.getExternalValues()) {
+ auto definingOp = dyn_cast_or_null<MemoryEffectOpInterface>(
+ externalValue.getDefiningOp());
+ if (!definingOp || !definingOp.hasEffect<MemoryEffects::Allocate>()) {
+ return rewriter.notifyMatchFailure(
+ op, "external chained value is not locally allocated");
+ }
+ }
+
+ // Should be safe to now replace the allocated external values with the
+ // original imported ones.
+ rewriter.replaceAllUsesWith(op.getExternalValues(), importOp.getOperands());
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void TimepointChainExternalOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.insert<PassThroughChainExternal>(context);
+}
+
+//===----------------------------------------------------------------------===//
// stream.timepoint.join
//===----------------------------------------------------------------------===//
@@ -2483,6 +2559,45 @@
namespace {
+// Extremely basic check for whether a source |resource| is immediately resolved
+// or may be part of a timeline sequence.
+static bool isSourceImmediatelyResolved(Value resource) {
+ // TODO(benvanik): data flow analysis/at least walk up tied ops. For now we
+ // err on the conservative side and only check for a few common scenarios.
+ auto *definingOp = resource.getDefiningOp();
+ if (!definingOp) return false;
+ return TypeSwitch<Operation *, bool>(definingOp)
+ .Case<IREE::Stream::ResourceAllocOp, IREE::Stream::TensorImportOp>(
+ [](auto op) { return true; })
+ .Default([](auto op) { return false; });
+}
+
+// Elides barriers that source their operands from immediate operations.
+// These barriers are implicitly resolved and need not be modeled.
+//
+// Example:
+// %r0a = stream.resource.alloc
+// %r0b, %r0ready = stream.timepoint.barrier %r0a
+// ->
+// %r0a = stream.resource.alloc
+// %r0b = %r0a
+// %r0ready = stream.timepoint.immediate
+struct ElideImmediateBarrier : public OpRewritePattern<TimepointBarrierOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointBarrierOp barrierOp,
+ PatternRewriter &rewriter) const override {
+ if (!isSourceImmediatelyResolved(barrierOp.getResource())) {
+ // Could not analyze or found to be a timeline op.
+ return failure();
+ }
+ auto immediateTimepoint =
+ rewriter.create<IREE::Stream::TimepointImmediateOp>(barrierOp.getLoc());
+ rewriter.replaceOp(barrierOp,
+ {barrierOp.getResource(), immediateTimepoint});
+ return success();
+ }
+};
+
// Walks up the tied op SSA def chain to find a stream.timepoint.await op that
// produces the resource. Returns nullptr if no await op is found or local
// analysis cannot determine the source (spans across a branch, etc).
@@ -2543,6 +2658,7 @@
void TimepointBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
+ results.insert<ElideImmediateBarrier>(context);
results.insert<ChainTimepoints>(context);
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index aba0a5c..d7e570f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -3147,6 +3147,8 @@
`(` $external_values `:` type($external_values) `)`
attr-dict-with-keyword
}];
+
+ let hasCanonicalizer = 1;
}
def Stream_TimepointJoinOp : Stream_PureOp<"timepoint.join", [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel
index 4b5fc16..db17d5a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel
@@ -27,12 +27,10 @@
"resource_ops.mlir",
"tensor_folding.mlir",
"tensor_ops.mlir",
- # TODO(#11251): Enable the test.
- # "timepoint_folding.mlir",
+ "timepoint_folding.mlir",
"timepoint_ops.mlir",
],
include = ["*.mlir"],
- exclude = ["timepoint_folding.mlir"],
),
cfg = "//compiler:lit.cfg.py",
tools = [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
index a6b8498..93e102b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
@@ -25,6 +25,7 @@
"resource_ops.mlir"
"tensor_folding.mlir"
"tensor_ops.mlir"
+ "timepoint_folding.mlir"
"timepoint_ops.mlir"
TOOLS
FileCheck
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
index f8edd40..3fa662b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
@@ -23,6 +23,36 @@
// -----
+// CHECK-LABEL: @PassThroughChainExternal
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[ARG_FENCE:.+]]: !hal.fence)
+func.func @PassThroughChainExternal(%device: !hal.device, %arg_fence: !hal.fence) -> !hal.fence {
+ // CHECK-NOT: stream.timepoint.import
+ %timepoint = stream.timepoint.import %arg_fence : (!hal.fence) => !stream.timepoint
+ // CHECK-NOT: hal.fence.create
+ %chained_fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
+ // CHECK-NOT: stream.timepoint.chain_external
+ stream.timepoint.chain_external %timepoint => (%chained_fence : !hal.fence)
+ // CHECK: return %[[ARG_FENCE]]
+ return %chained_fence : !hal.fence
+}
+
+// -----
+
+// Tests that external chained values we can't analyze aren't replaced.
+
+// CHECK-LABEL: @DontPassThroughChainExternal
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[ARG_FENCE:.+]]: !hal.fence, %[[CHAINED_FENCE:.+]]: !hal.fence)
+func.func @DontPassThroughChainExternal(%device: !hal.device, %arg_fence: !hal.fence, %chained_fence: !hal.fence) -> !hal.fence {
+ // CHECK: %[[TIMEPOINT:.+]] = stream.timepoint.import %[[ARG_FENCE]]
+ %timepoint = stream.timepoint.import %arg_fence : (!hal.fence) => !stream.timepoint
+ // CHECK: stream.timepoint.chain_external %[[TIMEPOINT]] => (%[[CHAINED_FENCE]]
+ stream.timepoint.chain_external %timepoint => (%chained_fence : !hal.fence)
+ // CHECK: return %[[CHAINED_FENCE]]
+ return %chained_fence : !hal.fence
+}
+
+// -----
+
// CHECK-LABEL: @FoldTimepointJoinOp
func.func @FoldTimepointJoinOp(%arg0: !stream.timepoint) -> !stream.timepoint {
// CHECK-NOT: stream.timepoint.join
@@ -77,6 +107,20 @@
// -----
+// CHECK-LABEL: @ElideImmediateBarrier
+// CHECK-SAME: (%[[SIZE:.+]]: index)
+func.func @ElideImmediateBarrier(%size: index) -> (!stream.resource<external>, !stream.timepoint) {
+ // CHECK-DAG: %[[RESOURCE:.+]] = stream.resource.alloc
+ %r0 = stream.resource.alloc uninitialized : !stream.resource<external>{%size}
+ // CHECK-DAG: %[[FENCE:.+]] = stream.timepoint.immediate
+ // CHECK-NOT: stream.timepoint.barrier
+ %r1, %r1t = stream.timepoint.barrier %r0 : !stream.resource<external>{%size} => !stream.timepoint
+ // CHECK: return %[[RESOURCE]], %[[FENCE]]
+ return %r1, %r1t : !stream.resource<external>, !stream.timepoint
+}
+
+// -----
+
// CHECK-LABEL: @ChainTimepoints
// CHECK-SAME: (%[[FENCE:.+]]: !stream.timepoint, %[[SOURCE:.+]]: !stream.resource<external>)
func.func @ChainTimepoints(%fence: !stream.timepoint, %source: !stream.resource<external>) -> (!stream.resource<external>, !stream.timepoint) {