Adding iree-stream-verify-async-access-ranges pass. (#13876)

This does some basic verification that the access ranges of
AsyncAccessOpInterface ops (like stream.async.copy and stream.async.dispatch)
are in-bounds. Currently it only performs static verification, but in the
future it can be extended to use relational verification (start < end, etc).

Prints useful errors like:
```
../iree-tmp/gpt2-64-stream-emplace.mlir:1603:12: error: 'stream.async.dispatch' op has invalid Read access range [180224 to 376832 for 196608] of resource %177 with size 196608; end > resource size
```

Hopefully we never produce programs like this but while iterating on
passes I've done it a few times and it's really hard to track down
(usually only via painful runtime debugging of intermediate values).
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
index cc3d362..bab4434 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
@@ -39,6 +39,7 @@
         "ScheduleConcurrency.cpp",
         "ScheduleExecution.cpp",
         "SpecializeDispatches.cpp",
+        "VerifyAsyncAccessRanges.cpp",
         "VerifyLowerings.cpp",
     ],
     hdrs = [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 9bb1af3..eb8bb19 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -41,6 +41,7 @@
     "ScheduleConcurrency.cpp"
     "ScheduleExecution.cpp"
     "SpecializeDispatches.cpp"
+    "VerifyAsyncAccessRanges.cpp"
     "VerifyLowerings.cpp"
   DEPS
     ::PassesIncGen
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index ab85cab..3636a05 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -145,6 +145,15 @@
   passManager.addPass(IREE::Stream::createRefineUsagePass());
   addCleanupPatterns(passManager);
 
+  // Verify all stream.async.* op access ranges that we can by taking advantage
+  // of statically available information or that which we can infer from data
+  // flow analysis. Because this may require a global analysis it's best done in
+  // a pass instead of individual op verifiers. We could run the pass more
+  // frequently above or move some of the simpler checks to op verifiers if we
+  // wanted to catch errors earlier but this is mostly a guard before we go into
+  // the stream.cmd.* layer.
+  passManager.addPass(IREE::Stream::createVerifyAsyncAccessRangesPass());
+
   //----------------------------------------------------------------------------
   // Stream formation and scheduling
   //----------------------------------------------------------------------------
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
index 4902ff2..a7cf768 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h
@@ -163,6 +163,9 @@
     DumpOutputFormat outputFormat = DumpOutputFormat::Pretty,
     std::string outputFile = "");
 
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyAsyncAccessRangesPass();
+
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyInputPass();
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createVerifyLoweringToTensorsPass();
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
index 5b5b9f3..788771e 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -246,6 +246,14 @@
   ];
 }
 
+def VerifyAsyncAccessRanges :
+    Pass<"iree-stream-verify-async-access-ranges", "mlir::ModuleOp"> {
+  let summary = "Verifies that stream.async.* access ranges are in bounds where possible.";
+  let constructor = [{
+    mlir::iree_compiler::IREE::Stream::createVerifyAsyncAccessRangesPass()
+  }];
+}
+
 def VerifyInput :
     Pass<"iree-stream-verify-input", "mlir::ModuleOp"> {
   let summary = "Verifies that input dialects are supported by the streams dialect.";
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp
new file mode 100644
index 0000000..f447ed9
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp
@@ -0,0 +1,143 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+namespace {
+
+// Returns |value| as a signed integer if it is a constant, else std::nullopt.
+static std::optional<int64_t> matchConstant(Value value) {
+  APInt matched;
+  const bool known = value && matchPattern(value, m_ConstantInt(&matched));
+  return known ? std::optional<int64_t>(matched.getSExtValue()) : std::nullopt;
+}
+
+// Verifies the statically-known parts of |range| as accessed by |accessOp|.
+// Non-constant values are skipped; at most one error is emitted per range.
+static LogicalResult verifyAsyncAccessRange(
+    IREE::Stream::AsyncAccessOpInterface accessOp,
+    IREE::Stream::AsyncAccessRange &range) {
+  const auto start = matchConstant(range.start);
+  const auto length = matchConstant(range.length);
+  const auto end = matchConstant(range.end);
+  Value sizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue(
+      range.resource, accessOp->getBlock(), Block::iterator(accessOp));
+  const auto resourceSize = matchConstant(sizeValue);
+  // Appends |constant| if it is known and a `?` placeholder otherwise.
+  auto appendOrUnknown = [](InFlightDiagnostic &diagnostic,
+                            std::optional<int64_t> constant) {
+    if (constant)
+      diagnostic << *constant;
+    else
+      diagnostic << "?";
+  };
+  // Begins an op error describing the whole range; callers append the reason.
+  auto emitRangeError = [&]() {
+    auto diagnostic = accessOp.emitOpError();
+    diagnostic << "has invalid "
+               << IREE::Stream::stringifyResourceAccessBitfield(range.access)
+               << " access range [";
+    appendOrUnknown(diagnostic, start);
+    diagnostic << " to ";
+    appendOrUnknown(diagnostic, end);
+    diagnostic << " for ";
+    appendOrUnknown(diagnostic, length);
+    diagnostic << "] of resource ";
+    std::string resourceName;
+    llvm::raw_string_ostream os(resourceName);
+    range.resource.printAsOperand(os, OpPrintingFlags());
+    diagnostic << resourceName << " with size ";
+    appendOrUnknown(diagnostic, resourceSize);
+    return diagnostic;
+  };
+
+  // Consistency of the range values with one another.
+  if (start && end && *start > *end) {
+    return emitRangeError() << "; start > end";
+  }
+  if (length && end && *length > *end) {
+    return emitRangeError() << "; length > end";
+  }
+  if (start && length && end && *start + *length != *end) {
+    return emitRangeError() << "; start + length != end";
+  }
+  // Range values against the resource size, when that is a constant.
+  if (resourceSize && start && *start > *resourceSize) {
+    return emitRangeError() << "; start > resource size";
+  }
+  if (resourceSize && length && *length > *resourceSize) {
+    return emitRangeError() << "; length > resource size";
+  }
+  if (resourceSize && end && *end > *resourceSize) {
+    return emitRangeError() << "; end > resource size";
+  }
+  return success();
+}
+
+// Checks every access range reported by |accessOp| against what is statically
+// known. All ranges are visited even after a failure so each gets its error.
+static LogicalResult verifyAsyncAccessOp(
+    IREE::Stream::AsyncAccessOpInterface accessOp) {
+  SmallVector<AsyncAccessRange> accessRanges;
+  accessOp.getAsyncAccessRanges(accessRanges);
+  unsigned failureCount = 0;
+  for (auto &accessRange : accessRanges) {
+    // Do not early-exit: keep going so later ranges are still diagnosed.
+    const auto result = verifyAsyncAccessRange(accessOp, accessRange);
+    if (failed(result)) ++failureCount;
+  }
+  return success(failureCount == 0);
+}
+
+// Module pass failing when an async access op has a statically invalid range.
+class VerifyAsyncAccessRangesPass
+    : public VerifyAsyncAccessRangesBase<VerifyAsyncAccessRangesPass> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Stream::StreamDialect>();
+  }
+
+  void runOnOperation() override {
+    // TODO(benvanik): do whole-program data flow analysis to get bounded sizes
+    // for range checking. Today we just do static checks.
+    // The walk is interrupted by the first op that fails verification.
+    const WalkResult walkResult = getOperation().walk(
+        [&](IREE::Stream::AsyncAccessOpInterface accessOp) {
+          if (failed(verifyAsyncAccessOp(accessOp))) {
+            return WalkResult::interrupt();
+          }
+          return WalkResult::advance();
+        });
+    if (walkResult.wasInterrupted()) signalPassFailure();
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyAsyncAccessRangesPass() {  // Factory declared in Passes.h.
+  return std::make_unique<VerifyAsyncAccessRangesPass>();
+}
+
+}  // namespace Stream
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
index 29f2001..713f6ac 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
@@ -45,6 +45,7 @@
             "schedule_concurrency.mlir",
             "schedule_execution.mlir",
             "specialize_dispatches.mlir",
+            "verify_async_access_ranges.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index 0d996f2..3b3448c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -43,6 +43,7 @@
     "schedule_concurrency.mlir"
     "schedule_execution.mlir"
     "specialize_dispatches.mlir"
+    "verify_async_access_ranges.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_async_access_ranges.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_async_access_ranges.mlir
new file mode 100644
index 0000000..1a4d361
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_async_access_ranges.mlir
@@ -0,0 +1,46 @@
+// RUN: iree-opt --iree-stream-verify-async-access-ranges --split-input-file %s --verify-diagnostics | FileCheck %s
+
+// Tests that statically-known valid ranges pass verification.
+
+// CHECK: @inRangeCopy
+func.func @inRangeCopy(%source: !stream.resource<*>, %target: !stream.resource<*>) -> !stream.resource<*> {
+  %source_size = arith.constant 256 : index
+  %target_size = arith.constant 256 : index
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  // CHECK: = stream.async.copy
+  %0 = stream.async.copy %source[%c128 to %c256], %target[%c128 to %c256], %c128 : !stream.resource<*>{%source_size} -> %target as !stream.resource<*>{%target_size}
+  return %0 : !stream.resource<*>
+}
+
+// -----
+
+// Tests that statically-known invalid ranges emit errors.
+// For more useful reporting every invalid range on the op gets its own error.
+func.func @outOfRangeCopy(%source: !stream.resource<*>, %target: !stream.resource<*>) -> !stream.resource<*> {
+  %source_size = arith.constant 256 : index
+  %target_size = arith.constant 255 : index  // NOTE: too small!
+  %c128 = arith.constant 128 : index
+  %c256 = arith.constant 256 : index
+  %c512 = arith.constant 512 : index
+  // expected-error @+3 {{invalid Read access range [128 to 512 for 128] of resource %arg0 with size 256}}
+  // expected-error @+2 {{invalid Write access range [256 to 512 for 128] of resource %arg1 with size 255}}
+  // expected-error @+1 {{invalid Write access range [256 to 512 for 128] of resource %0 with size 255}}
+  %0 = stream.async.copy %source[%c128 to %c512], %target[%c256 to %c512], %c128 : !stream.resource<*>{%source_size} -> %target as !stream.resource<*>{%target_size}
+  return %0 : !stream.resource<*>
+}
+
+// -----
+
+// Tests that static ranges don't get checked against dynamic sizes.
+// In the future we could use data flow analysis to try to bound dynamic values
+// and this pass could verify the conditions (size of A < size of B, etc).
+
+// CHECK-LABEL: @dynamicSizes
+func.func @dynamicSizes(%source: !stream.resource<*>, %source_size: index, %target: !stream.resource<*>, %target_size: index) -> !stream.resource<*> {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  // CHECK: = stream.async.copy
+  %0 = stream.async.copy %source[%c0 to %c128], %target[%c0 to %c128], %c128 : !stream.resource<*>{%source_size} -> %target as !stream.resource<*>{%target_size}
+  return %0 : !stream.resource<*>
+}