Adding iree-stream-verify-async-access-ranges pass. (#13876)
This does some basic verification that the access ranges of
AsyncAccessOpInterface ops (like stream.async.copy and stream.async.dispatch)
are in-bounds. Currently it only performs static verification, but in the
future it can be extended to use relational verification (start < end, etc).
Prints useful errors like:
```
../iree-tmp/gpt2-64-stream-emplace.mlir:1603:12: error: 'stream.async.dispatch' op has invalid Read access range [180224 to 376832 for 196608] of resource %177 with size 196608; end > resource size
```
Hopefully we never produce programs like this but while iterating on
passes I've done it a few times and it's really hard to track down
(usually only via painful runtime debugging of intermediate values).
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
index cc3d362..bab4434 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
@@ -39,6 +39,7 @@
"ScheduleConcurrency.cpp",
"ScheduleExecution.cpp",
"SpecializeDispatches.cpp",
+ "VerifyAsyncAccessRanges.cpp",
"VerifyLowerings.cpp",
],
hdrs = [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 9bb1af3..eb8bb19 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -41,6 +41,7 @@
"ScheduleConcurrency.cpp"
"ScheduleExecution.cpp"
"SpecializeDispatches.cpp"
+ "VerifyAsyncAccessRanges.cpp"
"VerifyLowerings.cpp"
DEPS
::PassesIncGen
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index ab85cab..3636a05 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -145,6 +145,15 @@
passManager.addPass(IREE::Stream::createRefineUsagePass());
addCleanupPatterns(passManager);
+ // Verify all stream.async.* op access ranges that we can by taking advantage
+ // of statically available information or that which we can infer from data
+ // flow analysis. Because this may require a global analysis it's best done in
+ // a pass instead of individual op verifiers. We could run the pass more
+ // frequently above or move some of the simpler checks to op verifiers if we
+ // wanted to catch errors earlier but this is mostly a guard before we go into
+ // the stream.cmd.* layer.
+ passManager.addPass(IREE::Stream::createVerifyAsyncAccessRangesPass());
+
//----------------------------------------------------------------------------
// Stream formation and scheduling
//----------------------------------------------------------------------------
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
index 4902ff2..a7cf768 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
@@ -163,6 +163,9 @@
DumpOutputFormat outputFormat = DumpOutputFormat::Pretty,
std::string outputFile = "");
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyAsyncAccessRangesPass();
+
std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyInputPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createVerifyLoweringToTensorsPass();
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
index 5b5b9f3..788771e 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -246,6 +246,14 @@
];
}
+// Declares the iree-stream-verify-async-access-ranges module pass, which is
+// constructed through createVerifyAsyncAccessRangesPass().
+def VerifyAsyncAccessRanges :
+    Pass<"iree-stream-verify-async-access-ranges", "mlir::ModuleOp"> {
+  let summary = "Verifies that stream.async.* access ranges are in bounds where possible.";
+  let constructor = [{
+    mlir::iree_compiler::IREE::Stream::createVerifyAsyncAccessRangesPass()
+  }];
+}
+
def VerifyInput :
Pass<"iree-stream-verify-input", "mlir::ModuleOp"> {
let summary = "Verifies that input dialects are supported by the streams dialect.";
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp
new file mode 100644
index 0000000..f447ed9
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp
@@ -0,0 +1,143 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+namespace {
+
+// Returns |value| as a constant integer; nullopt if null or not a constant.
+static std::optional<int64_t> matchConstant(Value value) {
+  APInt bits;
+  bool known = value && matchPattern(value, m_ConstantInt(&bits));
+  return known ? std::make_optional(bits.getSExtValue()) : std::nullopt;
+}
+
+// Verifies one access |range| of |accessOp| using only constant values.
+// Emits an op error and fails on the first violated condition found.
+static LogicalResult verifyAsyncAccessRange(
+    IREE::Stream::AsyncAccessOpInterface accessOp,
+    IREE::Stream::AsyncAccessRange &range) {
+  std::optional<int64_t> lo = matchConstant(range.start);
+  std::optional<int64_t> extent = matchConstant(range.length);
+  std::optional<int64_t> hi = matchConstant(range.end);
+  Value sizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue(
+      range.resource, accessOp->getBlock(), Block::iterator(accessOp));
+  std::optional<int64_t> capacity = matchConstant(sizeValue);
+  // Appends |v| to |diag|, or "?" when the value is not a known constant.
+  auto appendKnown = [](InFlightDiagnostic &diag, std::optional<int64_t> v) {
+    if (v) {
+      diag << *v;
+    } else {
+      diag << "?";
+    }
+  };
+  // Starts the error message shared by all of the checks below.
+  auto emitRangeError = [&]() {
+    InFlightDiagnostic diagnostic = accessOp.emitOpError();
+    diagnostic << "has invalid "
+               << IREE::Stream::stringifyResourceAccessBitfield(range.access)
+               << " access range [";
+    appendKnown(diagnostic, lo);
+    diagnostic << " to ";
+    appendKnown(diagnostic, hi);
+    diagnostic << " for ";
+    appendKnown(diagnostic, extent);
+    diagnostic << "] of resource ";
+    std::string operandName;
+    llvm::raw_string_ostream os(operandName);
+    range.resource.printAsOperand(os, OpPrintingFlags());
+    diagnostic << operandName << " with size ";
+    appendKnown(diagnostic, capacity);
+    return diagnostic;
+  };
+  // Relations between the range values themselves.
+  if (lo && hi && *lo > *hi) {
+    return emitRangeError() << "; start > end";
+  }
+  if (extent && hi && *extent > *hi) {
+    return emitRangeError() << "; length > end";
+  }
+  if (lo && extent && hi && *lo + *extent != *hi) {
+    return emitRangeError() << "; start + length != end";
+  }
+  // Relations against the resource size, when it is a known constant.
+  if (!capacity) return success();
+  if (lo && *lo > *capacity) {
+    return emitRangeError() << "; start > resource size";
+  }
+  if (extent && *extent > *capacity) {
+    return emitRangeError() << "; length > resource size";
+  }
+  if (hi && *hi > *capacity) {
+    return emitRangeError() << "; end > resource size";
+  }
+  return success();
+}
+
+// Verifies every access range declared by |accessOp| using static values.
+// All ranges are checked, even after a failure, so that each invalid range
+// gets its own error; the result is a failure if any range was invalid.
+static LogicalResult verifyAsyncAccessOp(
+    IREE::Stream::AsyncAccessOpInterface accessOp) {
+  SmallVector<AsyncAccessRange> accessRanges;
+  accessOp.getAsyncAccessRanges(accessRanges);
+  unsigned invalidCount = 0;
+  for (AsyncAccessRange &accessRange : accessRanges) {
+    LogicalResult result = verifyAsyncAccessRange(accessOp, accessRange);
+    if (failed(result)) ++invalidCount;
+  }
+  return success(invalidCount == 0);
+}
+
+// Module pass that walks all AsyncAccessOpInterface ops and signals failure at
+// the first op with an access range that is statically known to be invalid.
+class VerifyAsyncAccessRangesPass
+    : public VerifyAsyncAccessRangesBase<VerifyAsyncAccessRangesPass> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Stream::StreamDialect>();
+  }
+
+  void runOnOperation() override {
+    // TODO(benvanik): do whole-program data flow analysis to get bounded sizes
+    // for range checking. Today we just do static checks.
+    auto walkResult = getOperation().walk(
+        [&](IREE::Stream::AsyncAccessOpInterface accessOp) {
+          if (failed(verifyAsyncAccessOp(accessOp))) {
+            return WalkResult::interrupt();
+          }
+          return WalkResult::advance();
+        });
+    if (walkResult.wasInterrupted()) signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyAsyncAccessRangesPass() {  // Public factory declared in Passes.h.
+  return std::make_unique<VerifyAsyncAccessRangesPass>();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
index 29f2001..713f6ac 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
@@ -45,6 +45,7 @@
"schedule_concurrency.mlir",
"schedule_execution.mlir",
"specialize_dispatches.mlir",
+ "verify_async_access_ranges.mlir",
],
include = ["*.mlir"],
),
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index 0d996f2..3b3448c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -43,6 +43,7 @@
"schedule_concurrency.mlir"
"schedule_execution.mlir"
"specialize_dispatches.mlir"
+ "verify_async_access_ranges.mlir"
TOOLS
FileCheck
iree-opt
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_async_access_ranges.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_async_access_ranges.mlir
new file mode 100644
index 0000000..1a4d361
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_async_access_ranges.mlir
@@ -0,0 +1,46 @@
+// RUN: iree-opt --iree-stream-verify-async-access-ranges --split-input-file %s --verify-diagnostics | FileCheck %s
+
+// Tests that statically-known valid ranges pass verification.
+
+// CHECK-LABEL: @inRangeCopy
+func.func @inRangeCopy(%source: !stream.resource<*>, %target: !stream.resource<*>) -> !stream.resource<*> {
+  %source_size = arith.constant 256 : index
+  %target_size = arith.constant 256 : index
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  // CHECK: = stream.async.copy
+  %0 = stream.async.copy %source[%c128 to %c256], %target[%c128 to %c256], %c128 : !stream.resource<*>{%source_size} -> %target as !stream.resource<*>{%target_size}
+  return %0 : !stream.resource<*>
+}
+
+// -----
+
+// Tests that statically-known invalid ranges emit errors. All errors on an op
+// are reported: 1 for the source read and 2 for the target/result writes.
+func.func @outOfRangeCopy(%source: !stream.resource<*>, %target: !stream.resource<*>) -> !stream.resource<*> {
+  %source_size = arith.constant 256 : index
+  %target_size = arith.constant 255 : index // NOTE: too small!
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  %c512 = arith.constant 512 : index // NOTE: beyond both resource sizes!
+  // expected-error @+3 {{invalid Read access range [128 to 512 for 128] of resource %arg0 with size 256}}
+  // expected-error @+2 {{invalid Write access range [256 to 512 for 128] of resource %arg1 with size 255}}
+  // expected-error @+1 {{invalid Write access range [256 to 512 for 128] of resource %0 with size 255}}
+  %0 = stream.async.copy %source[%c128 to %c512], %target[%c256 to %c512], %c128 : !stream.resource<*>{%source_size} -> %target as !stream.resource<*>{%target_size}
+  return %0 : !stream.resource<*>
+}
+
+// -----
+
+// Tests that static ranges don't get checked against dynamic sizes.
+// In the future we could use data flow analysis to try to bound dynamic values
+// and this pass could verify the conditions (size of A < size of B, etc).
+
+// CHECK-LABEL: @dynamicSizes
+func.func @dynamicSizes(%source: !stream.resource<*>, %source_size: index, %target: !stream.resource<*>, %target_size: index) -> !stream.resource<*> {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index // static bounds; the sizes stay dynamic
+  // CHECK: = stream.async.copy
+  %0 = stream.async.copy %source[%c0 to %c128], %target[%c0 to %c128], %c128 : !stream.resource<*>{%source_size} -> %target as !stream.resource<*>{%target_size}
+  return %0 : !stream.resource<*>
+}