Making the ElideAsyncCopiesPass support stream.async.slice. (#16667)
This is only looking for some very specific patterns that we can
(mostly) verify are safe. This can definitely be improved and widened to
more cases but this is better than the nothing we do today for slice
ops. It does have some gotchas around potentially increasing lifetime of
source tensors but we don't have a good way of measuring that today so
🤷.
Fixes #16647.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceHazards.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceHazards.cpp
index c035eab..dcfc6c9 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceHazards.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceHazards.cpp
@@ -25,29 +25,6 @@
namespace mlir::iree_compiler::IREE::Stream {
//===----------------------------------------------------------------------===//
-// Access utilities
-//===----------------------------------------------------------------------===//
-
-// TODO(#6972): move to StreamTypes.h.
-
-static bool isReadOnly(ResourceAccessBitfield access) {
- return access == ResourceAccessBitfield::Read;
-}
-
-static bool doesRangeOverlap(AsyncAccessRange &lhs, AsyncAccessRange &rhs) {
- if (lhs.resource != rhs.resource)
- return false;
-
- if (lhs.end == rhs.start || lhs.start == rhs.end) {
- // Adjacent but not overlapping.
- return false;
- }
-
- // TODO(#6972): use adjacency tracking sets to handle out-of-order ranges.
- return true;
-}
-
-//===----------------------------------------------------------------------===//
// Hazard analysis
//===----------------------------------------------------------------------===//
@@ -87,16 +64,8 @@
llvm::interleave(
allProducerRanges, llvm::dbgs(),
[&](auto range) {
- llvm::dbgs() << " " << stringifyResourceAccessBitfield(range.access)
- << " ";
- range.resource.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << "[";
- range.start.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << " to ";
- range.end.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << " for ";
- range.length.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << "]";
+ llvm::dbgs() << " ";
+ range.print(llvm::dbgs(), *asmState);
},
"\n");
llvm::dbgs() << "\n";
@@ -106,16 +75,8 @@
llvm::interleave(
allConsumerRanges, llvm::dbgs(),
[&](auto range) {
- llvm::dbgs() << " " << stringifyResourceAccessBitfield(range.access)
- << " ";
- range.resource.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << "[";
- range.start.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << " to ";
- range.end.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << " for ";
- range.length.printAsOperand(llvm::dbgs(), *asmState);
- llvm::dbgs() << "]";
+ llvm::dbgs() << " ";
+ range.print(llvm::dbgs(), *asmState);
},
"\n");
llvm::dbgs() << "\n";
@@ -124,12 +85,15 @@
for (auto &producerRange : allProducerRanges) {
for (auto &consumerRange : allConsumerRanges) {
if (producerRange.resource == consumerRange.resource) {
- if (!doesRangeOverlap(producerRange, consumerRange)) {
+ // TODO(#6972): use adjacency tracking sets to handle out-of-order
+ // ranges. The basic overlap check only handles perfectly adjacent
+ // ranges.
+ if (!IREE::Stream::AsyncAccessRange::mayOverlap(producerRange,
+ consumerRange)) {
// No overlap - no hazard.
continue;
}
- if (isReadOnly(producerRange.access) &&
- isReadOnly(consumerRange.access)) {
+ if (producerRange.isReadOnly() && consumerRange.isReadOnly()) {
// Read-read is not a hazard.
continue;
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index 94a8414..93d0b8f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -80,6 +80,39 @@
llvm::cl::init(IREE::Stream::MemoryModel::Unified));
//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+void AsyncAccessRange::print(llvm::raw_ostream &os, AsmState &asmState) {
+ os << stringifyResourceAccessBitfield(access) << " ";
+ resource.printAsOperand(os, asmState);
+ os << "[";
+ start.printAsOperand(os, asmState);
+ os << " to ";
+ end.printAsOperand(os, asmState);
+ os << " for ";
+ length.printAsOperand(os, asmState);
+ os << "]";
+}
+
+// static
+bool AsyncAccessRange::mayOverlap(const AsyncAccessRange &lhs,
+ const AsyncAccessRange &rhs) {
+ // Different resources do not overlap for this purpose. They may still alias
+ // at various points but that's beyond the analysis we can do here.
+ if (lhs.resource != rhs.resource)
+ return false;
+
+ // Check for adjacent but not overlapping.
+ if (lhs.end == rhs.start || lhs.start == rhs.end) {
+ return false;
+ }
+
+ // _May_ overlap. More analysis required.
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
// custom<ParameterReference>($scope, $key)
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
index 9835f9a..0ca39e4 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
@@ -86,6 +86,17 @@
Value start; // may be nullptr to indicate 0
Value end;
Value length;
+
+ // Returns true if the access is read-only.
+ bool isReadOnly() const { return access == ResourceAccessBitfield::Read; }
+
+ // Prints a textual representation of the range.
+ void print(llvm::raw_ostream &os, AsmState &asmState);
+
+ // Returns true if |lhs| and |rhs| may overlap and false only if it can be
+ // locally proven that they do not.
+ static bool mayOverlap(const AsyncAccessRange &lhs,
+ const AsyncAccessRange &rhs);
};
#include "iree/compiler/Dialect/Stream/IR/StreamOpInterfaces.h.inc" // IWYU pragma: export
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
index 2d477f6..7e9c231 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
@@ -95,13 +95,12 @@
indicateOptimisticFixpoint();
LLVM_DEBUG({
- AsmState asmState(value.getParentBlock()->getParentOp());
llvm::dbgs() << "[elide-copies] initialized value last users for ";
- value.printAsOperand(llvm::dbgs(), asmState);
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
llvm::dbgs() << ": " << getAssumedSet().size() << "\n";
for (auto user : getAssumedSet()) {
llvm::dbgs() << " ";
- user->print(llvm::dbgs(), OpPrintingFlags().elideLargeElementsAttrs());
+ user->print(llvm::dbgs(), solver.getAsmState());
llvm::dbgs() << "\n";
}
});
@@ -289,9 +288,9 @@
// TODO(benvanik): change into something we can use for ref counting. We need
// that to insert stream-ordered deallocs and know when timepoints have been
// discard as they go out of scope. For now this strictly checks last use.
-class LastUseAnalysis {
+class ElisionAnalysis {
public:
- explicit LastUseAnalysis(Operation *rootOp)
+ explicit ElisionAnalysis(Operation *rootOp)
: explorer(rootOp, TraversalAction::SHALLOW),
solver(explorer, allocator) {
explorer.setOpInterfaceAction<mlir::FunctionOpInterface>(
@@ -307,6 +306,8 @@
rootOp->getRegions().front().getOps<mlir::CallableOpInterface>());
}
+ AsmState &getAsmState() { return solver.getAsmState(); }
+
// Runs analysis and populates the state cache.
// May fail if analysis cannot be completed due to unsupported or unknown IR.
LogicalResult run() {
@@ -359,6 +360,10 @@
SmallVector<mlir::CallableOpInterface> topLevelOps;
};
+//===----------------------------------------------------------------------===//
+// IREE::Stream::AsyncCloneOp elision
+//===----------------------------------------------------------------------===//
+
// Returns true if the given |operand| value does not need a copy on write.
// This is a conservative check and will return false ("not safe to elide") in
// many cases that otherwise don't need a copy. The
@@ -376,12 +381,11 @@
// %0 ---> %1 = clone(%0) ---> use(%1)
// \--> %2 = clone(%0) ---> use(%2) // last use of %0
static bool isSafeToElideCloneOp(IREE::Stream::AsyncCloneOp cloneOp,
- LastUseAnalysis &analysis) {
+ ElisionAnalysis &analysis) {
LLVM_DEBUG({
llvm::dbgs() << "isSafeToElideCloneOp:\n";
llvm::dbgs() << " ";
- cloneOp.print(llvm::dbgs(),
- OpPrintingFlags().elideLargeElementsAttrs().assumeVerified());
+ cloneOp.print(llvm::dbgs(), analysis.getAsmState());
llvm::dbgs() << "\n";
});
@@ -395,7 +399,7 @@
if (sourceType != targetType &&
sourceType.getLifetime() == IREE::Stream::Lifetime::Constant) {
LLVM_DEBUG(llvm::dbgs()
- << " + clone source is a constant; cannot elide\n");
+ << " - clone source is a constant; cannot elide\n");
return false;
}
@@ -433,28 +437,228 @@
return false;
}
-// Tries to elide copies nested within |region| when safe.
-// Returns true if any ops were elided.
-static bool tryElideAsyncCopiesInRegion(Region ®ion,
- LastUseAnalysis &analysis) {
- bool didChange = false;
- for (auto &block : region) {
- for (auto cloneOp : llvm::make_early_inc_range(
- block.getOps<IREE::Stream::AsyncCloneOp>())) {
- if (!isSafeToElideCloneOp(cloneOp, analysis))
- continue;
- cloneOp.replaceAllUsesWith(cloneOp.getSource());
- cloneOp.erase();
- didChange = true;
+// Elides a stream.async.clone op by replacing all uses with the cloned source.
+static void elideCloneOp(IREE::Stream::AsyncCloneOp cloneOp) {
+ cloneOp.replaceAllUsesWith(cloneOp.getSource());
+ cloneOp.erase();
+}
+
+//===----------------------------------------------------------------------===//
+// IREE::Stream::AsyncSliceOp elision
+//===----------------------------------------------------------------------===//
+
+// Filter to slices that are supported by the folding code.
+static bool areSliceUsesSupported(IREE::Stream::AsyncSliceOp sliceOp) {
+ for (auto &use : sliceOp.getResult().getUses()) {
+ if (!TypeSwitch<Operation *, bool>(use.getOwner())
+ .Case<IREE::Stream::AsyncCopyOp>([&](auto copyOp) {
+ // Only support folding into source today.
+ return !copyOp.isOperandTied(use.getOperandNumber());
+ })
+ .Case<IREE::Stream::AsyncDispatchOp>([&](auto dispatchOp) {
+ // Only support folding into reads today.
+ return !dispatchOp.isOperandTied(use.getOperandNumber());
+ })
+ .Default([](auto *op) { return false; })) {
+ return false;
}
}
- return didChange;
+ return true;
+}
+
+// Returns true if |sliceOp| is safe to elide.
+// This is only the case if the users are all supported ops.
+static bool isSafeToElideSliceOp(IREE::Stream::AsyncSliceOp sliceOp,
+ ElisionAnalysis &analysis) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "isSafeToElideSliceOp:\n";
+ llvm::dbgs() << " ";
+ sliceOp.print(llvm::dbgs(), analysis.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+
+ // Ensure all uses are ones we can support.
+ if (!areSliceUsesSupported(sliceOp)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " - slice consumers not supported; cannot elide\n");
+ return false;
+ }
+
+ // Currently we don't analyze up a tied op chain and require the defining op
+ // to be the producer.
+ Value source = sliceOp.getSource();
+ Value sourceBase = IREE::Util::TiedOpInterface::findTiedBaseValue(source);
+ if (source != sourceBase) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " - source is tied; cannot be elided (today)\n");
+ return false;
+ }
+
+ AsyncAccessRange sliceRange;
+ sliceRange.access = ResourceAccessBitfield::Read;
+ sliceRange.resource = source;
+ sliceRange.start = sliceOp.getSourceOffset();
+ sliceRange.end = sliceOp.getSourceEnd();
+ sliceRange.length = sliceOp.getResultSize();
+
+ // Gather all accesses of the source by all other ops (not the slice being
+ // inspected).
+ SmallVector<AsyncAccessRange> consumerRanges;
+ SmallVector<AsyncAccessRange> queryRanges;
+ for (auto user : source.getUsers()) {
+ if (user == sliceOp)
+ continue;
+ if (auto accessOp = dyn_cast<IREE::Stream::AsyncAccessOpInterface>(user)) {
+ // Async op consuming part of the resource. We can query it to see what
+ // it's doing to its operands/results and filter to just the accesses of
+ // the source value.
+ accessOp.getAsyncAccessRanges(queryRanges);
+ for (auto range : queryRanges) {
+ if (range.resource == source)
+ consumerRanges.push_back(range);
+ }
+ queryRanges.clear();
+ } else {
+ // Unknown user - for now we skip analysis. If we made the access range
+ // things elements in the solver we could traverse further.
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << " - analysis failure on unhandled user of slice source:\n";
+ user->print(llvm::dbgs(), analysis.getAsmState());
+ });
+ return false;
+ }
+ }
+
+ // If all other users don't overlap with the slice we can directly use the
+ // source resource.
+ for (auto &otherRange : consumerRanges) {
+ if (IREE::Stream::AsyncAccessRange::mayOverlap(sliceRange, otherRange)) {
+ // Potential overlap detected (or analysis failed) - if both are reads
+ // then we allow the elision (today) as there should be no hazard.
+ if (!otherRange.isReadOnly()) {
+ LLVM_DEBUG({
+ llvm::dbgs() << " - consumer overlap, skipping elision today\n";
+ llvm::dbgs() << " v slice ";
+ sliceRange.print(llvm::dbgs(), analysis.getAsmState());
+ llvm::dbgs() << "\n";
+ llvm::dbgs() << " ^ conflict ";
+ otherRange.print(llvm::dbgs(), analysis.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ return false;
+ }
+ }
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " + slice can (probably) be elided\n");
+ return true;
+}
+
+// arith.addi folders are terrible and don't handle adds of 0 so we handle that
+// here and then avoid doing the folding.
+static Value addOffset(Value lhs, Value rhs, OpBuilder &builder) {
+ if (matchPattern(lhs, m_Zero()))
+ return rhs;
+ if (matchPattern(rhs, m_Zero()))
+ return lhs;
+ return builder.createOrFold<arith::AddIOp>(
+ builder.getFusedLoc(lhs.getLoc(), rhs.getLoc()), lhs, rhs);
+}
+
+// TODO(benvanik): move these into patterns and use subview ops, maybe.
+// That would allow us to support a lot more op types but if we can't guarantee
+// a fold then we'd be left with unanalyzable subview ops. For now we handle the
+// cases we care about here.
+
+// Folds a stream.async.slice into a stream.async.copy source.
+static void foldSliceIntoCopy(IREE::Stream::AsyncSliceOp sliceOp,
+ IREE::Stream::AsyncCopyOp copyOp,
+ unsigned operandNumber) {
+ copyOp.getSourceMutable().set(sliceOp.getSource());
+ OpBuilder builder(copyOp);
+ copyOp.getSourceOffsetMutable().set(
+ addOffset(sliceOp.getSourceOffset(), copyOp.getSourceOffset(), builder));
+ copyOp.getSourceEndMutable().set(
+ addOffset(sliceOp.getSourceOffset(), copyOp.getSourceEnd(), builder));
+ copyOp.getSourceSizeMutable().set(sliceOp.getSourceSize());
+}
+
+// Folds a stream.async.slice into a stream.async.dispatch operand.
+static void foldSliceIntoDispatch(IREE::Stream::AsyncSliceOp sliceOp,
+ IREE::Stream::AsyncDispatchOp dispatchOp,
+ unsigned operandNumber) {
+ unsigned operandIndex =
+ operandNumber - dispatchOp.getTiedOperandsIndexAndLength().first;
+ dispatchOp.getResourceOperandsMutable()[operandIndex].set(
+ sliceOp.getSource());
+ unsigned resourceIndex = llvm::count_if(
+ dispatchOp.getResourceOperands().slice(0, operandIndex),
+ [](Value operand) {
+ return llvm::isa<IREE::Stream::ResourceType>(operand.getType());
+ });
+ OpBuilder builder(dispatchOp);
+ dispatchOp.getResourceOperandOffsetsMutable()[resourceIndex].set(addOffset(
+ sliceOp.getSourceOffset(),
+ dispatchOp.getResourceOperandOffsets()[resourceIndex], builder));
+ dispatchOp.getResourceOperandEndsMutable()[resourceIndex].set(
+ addOffset(sliceOp.getSourceOffset(),
+ dispatchOp.getResourceOperandEnds()[resourceIndex], builder));
+ dispatchOp.getResourceOperandSizesMutable()[resourceIndex].set(
+ sliceOp.getSourceSize());
+}
+
+// Elides a stream.async.slice op (assuming able) by folding it into consumers.
+static void elideSliceOp(IREE::Stream::AsyncSliceOp sliceOp) {
+ SmallVector<std::pair<Operation *, unsigned>> consumers;
+ for (auto &use : sliceOp.getResult().getUses())
+ consumers.push_back(std::make_pair(use.getOwner(), use.getOperandNumber()));
+ for (auto [owner, operandNumberIt] : consumers) {
+ unsigned operandNumber = operandNumberIt; // need C++20 to avoid this :|
+ TypeSwitch<Operation *>(owner)
+ .Case<IREE::Stream::AsyncCopyOp>([=](auto copyOp) {
+ foldSliceIntoCopy(sliceOp, copyOp, operandNumber);
+ })
+ .Case<IREE::Stream::AsyncDispatchOp>([=](auto dispatchOp) {
+ foldSliceIntoDispatch(sliceOp, dispatchOp, operandNumber);
+ })
+ .Default([](auto *op) {});
+ }
+ sliceOp.erase();
}
//===----------------------------------------------------------------------===//
// --iree-stream-elide-async-copies
//===----------------------------------------------------------------------===//
+// Tries to elide copies nested within |region| when safe.
+// Returns true if any ops were elided.
+static bool tryElideAsyncCopiesInRegion(Region ®ion,
+ ElisionAnalysis &analysis) {
+ bool didChange = false;
+ for (auto &block : region) {
+ block.walk([&](Operation *op) {
+ return TypeSwitch<Operation *, WalkResult>(op)
+ .Case<IREE::Stream::AsyncCloneOp>([&](auto cloneOp) {
+ if (isSafeToElideCloneOp(cloneOp, analysis)) {
+ elideCloneOp(cloneOp);
+ didChange = true;
+ }
+ return WalkResult::advance();
+ })
+ .Case<IREE::Stream::AsyncSliceOp>([&](auto sliceOp) {
+ if (isSafeToElideSliceOp(sliceOp, analysis)) {
+ elideSliceOp(sliceOp);
+ didChange = true;
+ }
+ return WalkResult::advance();
+ })
+ .Default([&](auto *op) { return WalkResult::advance(); });
+ });
+ }
+ return didChange;
+}
+
// Elides async copies that perform no meaningful work - such as clones of the
// last use of a value. This is designed to be run after
// --iree-stream-materialize-copy-on-write to clean up the copies it introduces
@@ -485,7 +689,7 @@
for (; iterationCount < maxIterationCount; ++iterationCount) {
// Perform whole-program analysis.
// TODO(benvanik): reuse allocator across iterations.
- LastUseAnalysis analysis(moduleOp);
+ ElisionAnalysis analysis(moduleOp);
if (failed(analysis.run())) {
moduleOp.emitError() << "failed to solve for last users";
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_async_copies.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_async_copies.mlir
index c586855..6a6f935 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_async_copies.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/elide_async_copies.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-stream-elide-async-copies %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-stream-elide-async-copies --cse %s | FileCheck %s
// Tests that a normal clone-on-multiple-uses pattern has the last clone elided.
// This is what the --iree-stream-materialize-copy-on-write pass generates and
@@ -120,3 +120,140 @@
^bb2(%bb2_0: !stream.resource<*>, %bb2_1: !stream.resource<*>):
util.return %bb2_0, %bb2_1 : !stream.resource<*>, !stream.resource<*>
}
+
+// -----
+
+// Tests that slices aren't elided when there are ops our folding doesn't (yet)
+// support.
+
+// CHECK-LABEL: @slice_unsupported_fold
+util.func private @slice_unsupported_fold(%producer: !stream.resource<*>) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: stream.async.slice
+ %slice = stream.async.slice %producer[%c100 to %c200] : !stream.resource<*>{%c300} -> !stream.resource<*>{%c100}
+ // CHECK: stream.async.fill
+ %consumer = stream.async.fill %c123_i32, %slice[%c0 to %c100 for %c100] : i32 -> !stream.resource<*>{%c100}
+ util.return %consumer : !stream.resource<*>
+}
+
+// -----
+
+// Tests that slices of tied values don't get folded as our analysis doesn't
+// (yet) walk up the use-def chain.
+
+// CHECK-LABEL: @slice_unsupported_tied
+util.func private @slice_unsupported_tied(%input: !stream.resource<*>) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+  %producer_storage = stream.async.alloca : !stream.resource<*>{%c300}
+ // CHECK: stream.async.copy
+ %producer = stream.async.copy %input[%c0 to %c300], %producer_storage[%c0 to %c300], %c300 : !stream.resource<*>{%c300} -> %producer_storage as !stream.resource<*>{%c300}
+ // CHECK: stream.async.slice
+ %slice = stream.async.slice %producer[%c100 to %c200] : !stream.resource<*>{%c300} -> !stream.resource<*>{%c100}
+ %consumer_storage = stream.async.alloca : !stream.resource<*>{%c100}
+ // CHECK: stream.async.copy
+ %consumer = stream.async.copy %slice[%c0 to %c100], %consumer_storage[%c0 to %c100], %c300 : !stream.resource<*>{%c100} -> %consumer_storage as !stream.resource<*>{%c100}
+ util.return %consumer : !stream.resource<*>
+}
+
+// -----
+
+// Tests that sliced ranges that overlap other used ranges don't fold if there
+// are writes as the copy is required for correctness.
+
+// CHECK-LABEL: @slice_overlap_preventing
+util.func private @slice_overlap_preventing(%producer: !stream.resource<*>) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c0 = arith.constant 0 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: stream.async.slice
+ %slice = stream.async.slice %producer[%c100 to %c200] : !stream.resource<*>{%c300} -> !stream.resource<*>{%c100}
+ %consumer_storage = stream.async.alloca : !stream.resource<*>{%c100}
+ // CHECK: stream.async.copy
+ %consumer = stream.async.copy %slice[%c0 to %c100], %consumer_storage[%c0 to %c100], %c300 : !stream.resource<*>{%c100} -> %consumer_storage as !stream.resource<*>{%c100}
+ // This fill overlaps the sliced range and should block the fold.
+ // CHECK: stream.async.fill
+ %fill = stream.async.fill %c123_i32, %producer[%c0 to %c200 for %c200] : i32 -> !stream.resource<*>{%c300}
+ util.return %consumer, %fill : !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+// Tests that sliced ranges that don't overlap other used ranges fold.
+
+// CHECK-LABEL: @slice_overlap_exclusive
+// CHECK-SAME: (%[[PRODUCER:.+]]: !stream.resource<*>)
+util.func private @slice_overlap_exclusive(%producer: !stream.resource<*>) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c0 = arith.constant 0 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK-NOT: stream.async.slice
+ %slice = stream.async.slice %producer[%c100 to %c200] : !stream.resource<*>{%c300} -> !stream.resource<*>{%c100}
+ %consumer_storage = stream.async.alloca : !stream.resource<*>{%c100}
+ // CHECK: stream.async.copy %[[PRODUCER]][%c100 to %c200]
+ %consumer = stream.async.copy %slice[%c0 to %c100], %consumer_storage[%c0 to %c100], %c300 : !stream.resource<*>{%c100} -> %consumer_storage as !stream.resource<*>{%c100}
+ // CHECK: stream.async.fill
+ %fill = stream.async.fill %c123_i32, %producer[%c200 to %c300 for %c100] : i32 -> !stream.resource<*>{%c300}
+ util.return %consumer, %fill : !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+// Tests that sliced ranges that overlap other accessed ranges still fold when
+
+// CHECK-LABEL: @slice_overlap_readonly
+// CHECK-SAME: (%[[PRODUCER:.+]]: !stream.resource<*>)
+util.func private @slice_overlap_readonly(%producer: !stream.resource<*>) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-NOT: stream.async.slice
+ %slice = stream.async.slice %producer[%c100 to %c200] : !stream.resource<*>{%c300} -> !stream.resource<*>{%c100}
+ %consumer_storage_0 = stream.async.alloca : !stream.resource<*>{%c100}
+ // CHECK: stream.async.copy %[[PRODUCER]][%c100 to %c200]
+  %consumer_0 = stream.async.copy %slice[%c0 to %c100], %consumer_storage_0[%c0 to %c100], %c300 : !stream.resource<*>{%c100} -> %consumer_storage_0 as !stream.resource<*>{%c100}
+ %consumer_storage_1 = stream.async.alloca : !stream.resource<*>{%c100}
+ // CHECK: stream.async.copy %[[PRODUCER]][%c101 to %c201]
+  %consumer_1 = stream.async.copy %slice[%c1 to %c101], %consumer_storage_1[%c0 to %c100], %c300 : !stream.resource<*>{%c100} -> %consumer_storage_1 as !stream.resource<*>{%c100}
+ util.return %consumer_0, %consumer_1 : !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+stream.executable private @ex {
+ stream.executable.export public @dispatch workgroups() -> (index, index, index) {
+ %c1 = arith.constant 1 : index
+ stream.return %c1, %c1, %c1 : index, index, index
+ }
+}
+
+// CHECK-LABEL: @slice_dispatch_fold
+// CHECK-SAME: (%[[PRODUCER:.+]]: !stream.resource<*>)
+util.func private @slice_dispatch_fold(%producer: !stream.resource<*>) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %c20 = arith.constant 20 : index
+ %c30 = arith.constant 30 : index
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK-NOT: stream.async.slice
+ %slice = stream.async.slice %producer[%c100 to %c200] : !stream.resource<*>{%c300} -> !stream.resource<*>{%c100}
+ // CHECK: stream.async.dispatch @ex::@dispatch(%c123_i32, %[[PRODUCER]][%c110 to %c130 for %c20]) : (i32, !stream.resource<*>{%c300}) -> !stream.resource<*>{%c100}
+ %consumer = stream.async.dispatch @ex::@dispatch(%c123_i32, %slice[%c10 to %c30 for %c20]) : (i32, !stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
+ util.return %consumer : !stream.resource<*>
+}