Adding `stream` dialect. (#7398)
This adds the `stream` dialect, ops, types, interfaces, and canonicalizers.
See `iree/compiler/Dialect/Stream/IR/StreamBase.td` for an overview of the dialect.
Future changes will add conversions, analyses, and passes using the dialect. Some aspects are
still a work in progress, particularly copy-on-write materialization, but the dialect is largely
functionally complete (though it is still missing many potential canonicalizations).
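As a rough illustration of the level the dialect operates at, the sketch below mirrors the `stream.async.update` example from `StreamBase.td`: an async-phase op acting on opaque resources whose sizes are carried explicitly as index values. The exact op syntax is defined in `StreamOps.td`, so treat this as approximate:

```mlir
// Sketch: insert the 40-byte %update into bytes [256, 296) of the 296-byte
// transient %target. Every resource operand and result carries its size
// ({%c40}, {%c296}), and the result is tied to %target ("%target as ...").
%updated = stream.async.update %update, %target[%c256 to %c296] :
    !stream.resource<transient>{%c40} ->
    %target as !stream.resource<transient>{%c296}
```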
diff --git a/iree/compiler/Dialect/Stream/BUILD b/iree/compiler/Dialect/Stream/BUILD
new file mode 100644
index 0000000..f27d209
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/BUILD
@@ -0,0 +1,11 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/iree/compiler/Dialect/Stream/CMakeLists.txt b/iree/compiler/Dialect/Stream/CMakeLists.txt
new file mode 100644
index 0000000..de2c66b
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/CMakeLists.txt
@@ -0,0 +1,13 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/IR/BUILD b/iree/compiler/Dialect/Stream/IR/BUILD
new file mode 100644
index 0000000..39655c5
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/BUILD
@@ -0,0 +1,179 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc")
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+td_library(
+ name = "td_files",
+ srcs = enforce_glob(
+ [
+ "StreamBase.td",
+ "StreamInterfaces.td",
+ "StreamOps.td",
+ ],
+ include = ["*.td"],
+ ),
+ deps = [
+ "//iree/compiler/Dialect/Shape/IR:td_files",
+ "//iree/compiler/Dialect/Util/IR:td_files",
+ "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:SideEffectTdFiles",
+ "@llvm-project//mlir:StdOpsTdFiles",
+ "@llvm-project//mlir:SubElementInterfacesTdFiles",
+ "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
+ ],
+)
+
+cc_library(
+ name = "IR",
+ srcs = [
+ "StreamDialect.cpp",
+ "StreamEnums.cpp.inc",
+ "StreamOpFolders.cpp",
+ "StreamOpInterfaces.cpp.inc",
+ "StreamOps.cpp",
+ "StreamOps.cpp.inc",
+ "StreamTypeInterfaces.cpp.inc",
+ "StreamTypes.cpp",
+ "StreamTypes.cpp.inc",
+ ],
+ hdrs = [
+ "StreamDialect.h",
+ "StreamEnums.h.inc",
+ "StreamOpInterfaces.h.inc",
+ "StreamOps.h",
+ "StreamOps.h.inc",
+ "StreamTraits.h",
+ "StreamTypeInterfaces.h.inc",
+ "StreamTypes.h",
+ "StreamTypes.h.inc",
+ ],
+ deps = [
+ ":StreamEnumsGen",
+ ":StreamInterfacesGen",
+ ":StreamOpsGen",
+ ":StreamTypesGen",
+ "//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/Util/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:InferTypeOpInterface",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:SideEffects",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformUtils",
+ ],
+)
+
+gentbl_cc_library(
+ name = "StreamEnumsGen",
+ tbl_outs = [
+ (
+ ["-gen-enum-decls"],
+ "StreamEnums.h.inc",
+ ),
+ (
+ ["-gen-enum-defs"],
+ "StreamEnums.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "StreamBase.td",
+ deps = [":td_files"],
+)
+
+gentbl_cc_library(
+ name = "StreamInterfacesGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "StreamOpInterfaces.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "StreamOpInterfaces.cpp.inc",
+ ),
+ (
+ ["-gen-type-interface-decls"],
+ "StreamTypeInterfaces.h.inc",
+ ),
+ (
+ ["-gen-type-interface-defs"],
+ "StreamTypeInterfaces.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "StreamInterfaces.td",
+ deps = [":td_files"],
+)
+
+gentbl_cc_library(
+ name = "StreamOpsGen",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "StreamOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "StreamOps.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "StreamOps.td",
+ deps = [":td_files"],
+)
+
+gentbl_cc_library(
+ name = "StreamTypesGen",
+ tbl_outs = [
+ (
+ ["-gen-attrdef-decls"],
+ "StreamAttrs.h.inc",
+ ),
+ (
+ ["-gen-attrdef-defs"],
+ "StreamAttrs.cpp.inc",
+ ),
+ (
+ ["-gen-typedef-decls"],
+ "StreamTypes.h.inc",
+ ),
+ (
+ ["-gen-typedef-defs"],
+ "StreamTypes.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "StreamBase.td",
+ deps = [":td_files"],
+)
+
+iree_tablegen_doc(
+ name = "StreamDialectDocGen",
+ tbl_outs = [
+ (
+ ["-gen-dialect-doc"],
+ "StreamDialect.md",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "StreamOps.td",
+ deps = [":td_files"],
+)
diff --git a/iree/compiler/Dialect/Stream/IR/CMakeLists.txt b/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
new file mode 100644
index 0000000..6b00ba4
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
@@ -0,0 +1,110 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/IR/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ IR
+ HDRS
+ "StreamDialect.h"
+ "StreamEnums.h.inc"
+ "StreamOpInterfaces.h.inc"
+ "StreamOps.h"
+ "StreamOps.h.inc"
+ "StreamTraits.h"
+ "StreamTypeInterfaces.h.inc"
+ "StreamTypes.h"
+ "StreamTypes.h.inc"
+ SRCS
+ "StreamDialect.cpp"
+ "StreamEnums.cpp.inc"
+ "StreamOpFolders.cpp"
+ "StreamOpInterfaces.cpp.inc"
+ "StreamOps.cpp"
+ "StreamOps.cpp.inc"
+ "StreamTypeInterfaces.cpp.inc"
+ "StreamTypes.cpp"
+ "StreamTypes.cpp.inc"
+ DEPS
+ ::StreamEnumsGen
+ ::StreamInterfacesGen
+ ::StreamOpsGen
+ ::StreamTypesGen
+ LLVMSupport
+ MLIRArithmetic
+ MLIRIR
+ MLIRInferTypeOpInterface
+ MLIRMemRef
+ MLIRParser
+ MLIRSideEffectInterfaces
+ MLIRStandard
+ MLIRSupport
+ MLIRTensor
+ MLIRTransformUtils
+ iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::Util::IR
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ StreamEnumsGen
+ TD_FILE
+ "StreamBase.td"
+ OUTS
+ -gen-enum-decls StreamEnums.h.inc
+ -gen-enum-defs StreamEnums.cpp.inc
+)
+
+iree_tablegen_library(
+ NAME
+ StreamInterfacesGen
+ TD_FILE
+ "StreamInterfaces.td"
+ OUTS
+ -gen-op-interface-decls StreamOpInterfaces.h.inc
+ -gen-op-interface-defs StreamOpInterfaces.cpp.inc
+ -gen-type-interface-decls StreamTypeInterfaces.h.inc
+ -gen-type-interface-defs StreamTypeInterfaces.cpp.inc
+)
+
+iree_tablegen_library(
+ NAME
+ StreamOpsGen
+ TD_FILE
+ "StreamOps.td"
+ OUTS
+ -gen-op-decls StreamOps.h.inc
+ -gen-op-defs StreamOps.cpp.inc
+)
+
+iree_tablegen_library(
+ NAME
+ StreamTypesGen
+ TD_FILE
+ "StreamBase.td"
+ OUTS
+ -gen-attrdef-decls StreamAttrs.h.inc
+ -gen-attrdef-defs StreamAttrs.cpp.inc
+ -gen-typedef-decls StreamTypes.h.inc
+ -gen-typedef-defs StreamTypes.cpp.inc
+)
+
+iree_tablegen_doc(
+ NAME
+ StreamDialectDocGen
+ TD_FILE
+ "StreamOps.td"
+ OUTS
+ -gen-dialect-doc StreamDialect.md
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/IR/StreamBase.td b/iree/compiler/Dialect/Stream/IR/StreamBase.td
new file mode 100644
index 0000000..4be92fd
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamBase.td
@@ -0,0 +1,549 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_STREAM_BASE
+#define IREE_DIALECT_STREAM_BASE
+
+include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td"
+include "iree/compiler/Dialect/Util/IR/UtilBase.td"
+include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
+include "iree/compiler/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/IR/SubElementInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// IREE stream modeling dialect
+//===----------------------------------------------------------------------===//
+
+def Stream_Dialect : Dialect {
+ let name = "stream";
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
+ let summary = [{
+ A dialect designed to model execution partitioning and scheduling.
+ }];
+ let description = [{
+ The stream dialect is designed to take tensor programs and convert them to
+ explicitly scheduled asynchronous programs. This includes placing ops on
+ specific targets, partitioning the work between the targets, scheduling the
+ work for concurrency, and encoding tensors into target-specific resources.
+
+ ```
+ +--------+ +----------+ +-------+
+ | flow.* | -> | stream.* | -> | hal.* |
+ +--------+ +----------+ +-------+
+ ```
+
+ This sits in-between the `flow` and `hal` dialects.
+
+ * `flow` models tensor programs by separating work into dispatchable
+ functions in order to isolate the main host program data flow and the
+ dense tensor compute operations.
+
+ * `stream` models explicitly scheduled asynchronous programs by partitioning
+ the dispatchable work, specifying target affinities, encoding tensors into
+ target-specific forms, and scheduling the work to run concurrently.
+
+ * `hal` models a low-level hardware abstraction layer used to manage
+ buffers and issue asynchronous work across a variety of device types. The
+ dialect is largely 1:1 with the IREE HAL C API.
+
+ Transforms in the dialect lower tensor values into opaque resources with the
+ goal of ensuring no tensors survive in the IR. At entry `stream.tensor.*`
+ ops are used to capture the source tensor encoding information (data type,
+ shapes, etc) and then lowered into `stream.async.*` ops that model the
+ asynchronous workloads on the opaque resources. The asynchronous operations
+ are then partitioned, allocated, and scheduled for execution using the
+ `stream.cmd.*` ops.
+
+ It's intended that after transformation through the stream dialect the
+ program is ready for execution on an abstract machine. At this level of
+    representation buffers have not yet been allocated and devices are not
+    yet resolved; however, the information captured in the `stream` IR allows
+ such operations to be done trivially. To this end all ops carry the symbolic
+ size of the resources on which they operate as well as the lifetime of the
+ resources they are acting upon. This manifests in the usage of the
+ `!stream.resource` type:
+
+ ```mlir
+ // Unresolved lifetime (resolved during the iree-stream-refine-usage pass):
+ !stream.resource<*>
+ // An externally managed value (passed in via the program API).
+ !stream.resource<external>
+ // A staging buffer for uploads/downloads.
+ !stream.resource<staging>
+ // A short-lived value that is used across streams.
+ !stream.resource<transient>
+ // A long-lived value that persists across streams in globals.
+ !stream.resource<variable>
+ // An immutable value that persists for the duration of the program.
+ !stream.resource<constant>
+ ```
+
+    Operations using resources carry the sizes of all operand and result resources:
+
+ ```mlir
+ // %update (40 bytes) is being inserted into %target (296 bytes).
+ // Can be dynamic values such as those originating from dynamic dimensions.
+ %13 = stream.async.update %update, %target[%c256 to %c296] :
+ !stream.resource<transient>{%c40} ->
+ %target as !stream.resource<transient>{%c296}
+ ```
+
+ Once all `stream.async.*` work is moved into executable regions (such as
+ `stream.async.execute`) `!stream.timepoint` values are used to sequence
+ the execution. These timepoints represent some point in time where all
+ execution up to that timepoint has completed and any results that were
+ produced by the execution are available for use. Attempting to use the
+ resources before their corresponding timepoint has been reached will lead
+ to undefined behavior. The benefit of this is that after timepoints are
+ established in the IR it's possible to induce aliasing of resources without
+ breaking execution correctness.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Base stream dialect op classes
+//===----------------------------------------------------------------------===//
+
+class Stream_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Stream_Dialect, mnemonic, traits> {
+ let parser = [{ return parse$cppClass(parser, &result); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Stream dialect types
+//===----------------------------------------------------------------------===//
+
+def Stream_PrimitiveType : AnyTypeOf<[Index, AnyInteger, AnyFloat]>;
+def Stream_Offset : TypeAlias<Index>;
+def Stream_Size : TypeAlias<Index>;
+
+def Stream_Tensor : TypeAlias<AnyRankedTensor>;
+def Stream_Dim : TypeAlias<Index>;
+def Stream_ShapeDynamicDims : Variadic<Stream_Dim>;
+
+def Stream_IndexAttr : Util_IndexAttrBase<"index">;
+def Stream_IndexArrayAttr : TypedArrayAttrBase<Stream_IndexAttr,
+ "index array attribute"> {
+ let constBuilderCall = "$_builder.getIndexArrayAttr($0)";
+}
+
+def Stream_ExecutableRefAttr : AliasedSymbolRefAttr;
+def Stream_GlobalRefAttr : AliasedSymbolRefAttr;
+def Stream_GlobalPtr : Util_AnyPtrOf<[Stream_Tensor, Stream_PrimitiveType]>;
+
+//===----------------------------------------------------------------------===//
+// Stream attributes
+//===----------------------------------------------------------------------===//
+
+def Stream_AffinityAttr : AttrDef<Stream_Dialect, "Affinity", []> {
+ let mnemonic = "affinity";
+
+ let summary = [{defines execution context affinity}];
+ let description = [{
+ TBD. The intent is that this can specify host, device, and queue affinity.
+    Scopes can be annotated with an affinity to ensure that execution within them
+    happens in a particular location. Arrays of affinities or wildcard specifiers will
+ allow for refinement ("do it on this device but auto select a queue"). It
+ will also allow us to indicate host affinity such that device<->device and
+ host<->device can be identified in the IR structure. Today all affinities
+ are no-op'ed and assumed to be 'current device'.
+ }];
+
+ // TODO(benvanik): affinity.
+ let parameters = (ins);
+
+ let valueType = NoneType;
+
+ let extraClassDeclaration = [{
+ // Returns an affinity active for the given operation.
+ // This will recursively walk parent operations until one with the
+ // `stream.affinity` attribute is found.
+ static AffinityAttr lookup(Operation *op);
+
+ // Returns true if |desiredAffinity| (if any) is compatible with
+ // |requiredAffinity|.
+ static bool areCompatible(AffinityAttr desiredAffinity,
+ AffinityAttr requiredAffinity);
+ }];
+}
+
+def Stream_Favor_Debug : I32EnumAttrCase<"Debug", 0, "debug">;
+def Stream_Favor_Concurrency : I32EnumAttrCase<"Concurrency", 1, "concurrency">;
+def Stream_Favor_MinPeakMemory : I32EnumAttrCase<"MinPeakMemory", 2, "min-peak-memory">;
+def Stream_FavorAttr :
+ I32EnumAttr<"Favor", "IREE partitioning bias", [
+ Stream_Favor_Debug,
+ Stream_Favor_Concurrency,
+ Stream_Favor_MinPeakMemory,
+ ]> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+}
+
+def Stream_PartitioningConfigAttr :
+ AttrDef<Stream_Dialect, "PartitioningConfig", [
+ DeclareAttrInterfaceMethods<SubElementAttrInterface>,
+ ]> {
+ let mnemonic = "partitioning_config";
+ let summary = [{defines partitioning configuration}];
+ let description = [{
+ Configures the partitioning algorithm to use and its configuration.
+ Partitioning is useful to adjust when scheduling behavior of targets is
+ radically different - such as single-threaded vs. multi-threaded CPUs or
+ bespoke ML accelerators vs. general purpose GPUs. This mechanism controls
+ the amount of concurrency, parallelism, memory consumption, and latency.
+ }];
+
+ // TODO(benvanik): partitioning config.
+ let parameters = (ins
+ "IREE::Stream::FavorAttr":$favor
+ );
+
+ let valueType = NoneType;
+
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "IREE::Stream::FavorAttr":$favor), [{
+ return $_get(favor.getContext(), favor);
+ }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ // Returns a partitioning config active for the given operation.
+ // This will recursively walk parent operations until one with the
+ // `stream.partitioning` attribute is found.
+ static PartitioningConfigAttr lookup(Operation *op);
+ }];
+}
+
+def Stream_ResourceConfigAttr :
+ AttrDef<Stream_Dialect, "ResourceConfig", []> {
+ let mnemonic = "resource_config";
+
+ let summary = [{defines resource constraints configuration}];
+ let description = [{
+ Defines resource storage constraints. These allow for packing and layout
+ algorithms to ensure they are producing usable results on target devices.
+ }];
+
+ // TODO(benvanik): this was just copied from the HAL; probably good to make it
+ // more generic such that we can classify entire device families instead of
+ // specific values like this. As-is this is a phase-ordering violation.
+ let parameters = (ins
+ // The maximum size of a memory allocation that can be created, even if
+ // there is more space available in the heap.
+ "int64_t":$maxAllocationSize,
+ // The minimum required alignment, in bytes, for offsets used in runtime
+ // resource bindings. Offset values (both dynamic and static) must be an
+ // integer multiple of this limit.
+ "int64_t":$minBufferOffsetAlignment,
+ // The maximum value that can be specified for size ranges of resource
+ // bindings. The underlying allocation may be larger than this but only
+ // up to this amount will be visible to kernels.
+ "int64_t":$maxBufferRange,
+    // The minimum required alignment, in bytes, for size ranges of resource
+    // bindings.
+ "int64_t":$minBufferRangeAlignment
+ );
+
+ let valueType = NoneType;
+
+ let extraClassDeclaration = [{
+ // Returns the intersection (most conservative) constraints |lhs| ∩ |rhs|.
+ static ResourceConfigAttr
+ intersectBufferConstraints(ResourceConfigAttr lhs, ResourceConfigAttr rhs);
+
+ // Returns a resource config compatible with the host.
+ // These must only be used with resources when it is known that the device
+ // is local or has unified memory.
+ static ResourceConfigAttr
+ getDefaultHostConstraints(MLIRContext *context);
+
+ // Returns a resource config active for the given operation.
+ // This will recursively walk parent operations until one with the
+ // `stream.resources` attribute is found, an affinity specifies a
+ // configuration, or as a fallback returns a conservative configuration.
+ static ResourceConfigAttr lookup(Operation *op);
+ }];
+}
+
+def Stream_ResourceAccess_None : BitEnumAttrCase<"None", 0x0000>;
+def Stream_ResourceAccess_Read : BitEnumAttrCase<"Read", 0x0001>;
+def Stream_ResourceAccess_Write : BitEnumAttrCase<"Write", 0x0002>;
+def Stream_ResourceAccessBitfieldAttr :
+ BitEnumAttr<"ResourceAccessBitfield", "valid ResourceAccess", [
+ Stream_ResourceAccess_None,
+ Stream_ResourceAccess_Read,
+ Stream_ResourceAccess_Write,
+ ]> {
+ let cppNamespace = "mlir::iree_compiler::IREE::Stream";
+}
+def Stream_ResourceAccessArrayAttr :
+ TypedArrayAttrBase<Stream_ResourceAccessBitfieldAttr,
+ "access array attribute"> {}
+
+//===----------------------------------------------------------------------===//
+// Stream synchronization types
+//===----------------------------------------------------------------------===//
+
+def Stream_Timepoint : TypeDef<Stream_Dialect, "Timepoint", [
+ Util_GlobalTypeInterface,
+]> {
+ let mnemonic = "timepoint";
+
+ let summary = [{a timepoint indicating execution availability}];
+ let description = [{
+ Represents a point in the execution timeline that when resolved indicates
+ that all of the execution prior to this timepoint has completed and the
+ results of the execution are available for use. This includes transitive
+ dependencies as well; if timepoint B is dependent on timepoint A then when
+ B is available so too must be A.
+ }];
+
+ // TODO(benvanik): track affinity so we know where timepoints come from.
+ let parameters = (ins);
+}
+
+def Stream_TimepointAttr : AttrDef<Stream_Dialect, "Timepoint", []> {
+ let mnemonic = "timepoint";
+ let summary = [{an immediately-resolved timepoint}];
+ let description = [{}];
+ let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+ let valueType = Stream_Timepoint;
+ let typeBuilder = "IREE::Stream::TimepointType::get($_value.getContext())";
+ let constBuilderCall = [{
+ IREE::Stream::TimepointAttr::get(
+ $_builder.getContext(),
+ IREE::Stream::TimepointType::get($_value.getContext()));
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Stream resource value types
+//===----------------------------------------------------------------------===//
+
+// Indicates a value whose lifetime is not yet analyzed.
+def Stream_Lifetime_Unknown : I32EnumAttrCase<"Unknown", 0, "*">;
+// An externally managed value.
+def Stream_Lifetime_External : I32EnumAttrCase<"External", 1, "external">;
+// A staging buffer for uploads/downloads.
+def Stream_Lifetime_Staging : I32EnumAttrCase<"Staging", 2, "staging">;
+// A short-lived value that is used across streams.
+def Stream_Lifetime_Transient : I32EnumAttrCase<"Transient", 3, "transient">;
+// A long-lived value that persists across streams.
+def Stream_Lifetime_Variable : I32EnumAttrCase<"Variable", 4, "variable">;
+// An immutable value that persists for the duration of the program.
+def Stream_Lifetime_Constant : I32EnumAttrCase<"Constant", 5, "constant">;
+def Stream_LifetimeAttr :
+ I32EnumAttr<"Lifetime", "IREE Stream value lifetime", [
+ Stream_Lifetime_Unknown,
+ Stream_Lifetime_External,
+ Stream_Lifetime_Staging,
+ Stream_Lifetime_Transient,
+ Stream_Lifetime_Variable,
+ Stream_Lifetime_Constant,
+ ]> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+}
+
+def Stream_AnyResource : Type<
+ CPred<"$_self.isa<IREE::Stream::ResourceType>()">,
+ "any stream-compatible type">;
+
+// TODO(benvanik): support other types; the interface may be enough.
+def Stream_TransferType : AnyTypeOf<[
+ Stream_AnyResource,
+ Stream_Tensor,
+ Stream_PrimitiveType,
+]>;
+
+def Stream_Resource : TypeDef<Stream_Dialect, "Resource", [
+ Util_ReferenceType,
+ Util_SizeAwareType,
+ DeclareTypeInterfaceMethods<Util_GlobalTypeInterface, [
+ "isAccessStorageCompatible",
+ ]>,
+]> {
+ let mnemonic = "resource";
+
+ let summary = [{a managed resource}];
+ let description = [{
+ Stream external values represent asynchronously-available and sequenced
+ values that are owned and managed by external code - such as those passed in
+ or out of the program entry points. Though external values are managed
+ during an invocation the same as other stream values the visibility into
+ them does not extend outside of the invocation they are provided to.
+
+ Stream values are not usable directly outside of a stream execution or
+ transfer operation. If the contents of the value are needed they must first
+ be transferred via `stream.transfer` - which may incur a copy.
+ }];
+
+ let parameters = (ins
+ "IREE::Stream::Lifetime":$lifetime
+ );
+
+ let builders = [
+ TypeBuilder<(ins), [{
+ return $_get($_ctxt, IREE::Stream::Lifetime::Unknown);
+ }]>,
+ TypeBuilderWithInferredContext<(ins
+ "IREE::Stream::LifetimeAttr":$lifetime
+ ), [{
+ return $_get(lifetime.getContext(), lifetime.getValue());
+ }]>,
+ ];
+}
+
+def Stream_ResourceLifetimeUnknown : CPred<[{
+ $_self.cast<IREE::Stream::ResourceType>().getLifetime() ==
+ IREE::Stream::Lifetime::Unknown
+}]>;
+def Stream_ResourceLifetimeExternal : CPred<[{
+ $_self.cast<IREE::Stream::ResourceType>().getLifetime() ==
+ IREE::Stream::Lifetime::External
+}]>;
+def Stream_ResourceLifetimeStaging : CPred<[{
+ $_self.cast<IREE::Stream::ResourceType>().getLifetime() ==
+ IREE::Stream::Lifetime::Staging
+}]>;
+def Stream_ResourceLifetimeTransient : CPred<[{
+ $_self.cast<IREE::Stream::ResourceType>().getLifetime() ==
+ IREE::Stream::Lifetime::Transient
+}]>;
+def Stream_ResourceLifetimeVariable : CPred<[{
+ $_self.cast<IREE::Stream::ResourceType>().getLifetime() ==
+ IREE::Stream::Lifetime::Variable
+}]>;
+def Stream_ResourceLifetimeConstant : CPred<[{
+ $_self.cast<IREE::Stream::ResourceType>().getLifetime() ==
+ IREE::Stream::Lifetime::Constant
+}]>;
+
+def Stream_UnknownResource : DialectType<Stream_Dialect, And<[
+ CPred<"$_self.isa<IREE::Stream::ResourceType>()">,
+ Stream_ResourceLifetimeUnknown,
+]>, "resource"> {
+ let description = [{
+ A stream resource that has not yet had its lifetime calculated.
+ }];
+}
+
+def Stream_ExternalResource : DialectType<Stream_Dialect, And<[
+ CPred<"$_self.isa<IREE::Stream::ResourceType>()">,
+ Stream_ResourceLifetimeExternal,
+]>, "external resource"> {
+ let description = [{
+ Stream external values represent asynchronously-available and sequenced
+ values that are owned and managed by external code - such as those passed in
+ or out of the program entry points. Though external values are managed
+ during an invocation the same as other stream values the visibility into
+ them does not extend outside of the invocation they are provided to.
+
+ Stream values are not usable directly outside of a stream execution or
+ transfer operation. If the contents of the value are needed they must first
+ be transferred via `stream.transfer` - which may incur a copy.
+ }];
+}
+
+def Stream_StagingResource : DialectType<Stream_Dialect, And<[
+ CPred<"$_self.isa<IREE::Stream::ResourceType>()">,
+ Stream_ResourceLifetimeStaging,
+]>, "staging resource"> {
+ let description = [{
+ Stream upload/download staging resource. These are used outside of streams
+ and then transferred to other stream resources such as variables or
+ transients for use inside of streams. Dispatches and several other
+ operations cannot directly operate on these resources.
+ }];
+}
+
+def Stream_TransientResource : DialectType<Stream_Dialect, And<[
+ CPred<"$_self.isa<IREE::Stream::ResourceType>()">,
+ Stream_ResourceLifetimeTransient,
+]>, "transient resource"> {
+ let description = [{
+ Stream transients represent asynchronously-available and sequenced values
+ that have a short lifetime - often only passed between stream executions.
+ It is expected that transient values are not stored in global state and have
+ minimal lifetime as they may be heavily pooled or suballocated.
+
+ Stream values are not usable directly outside of a stream execution or
+ transfer operation. If the contents of the value are needed they must first
+ be transferred via `stream.transfer` - which may incur a copy.
+ }];
+}
+
+def Stream_VariableResource : DialectType<Stream_Dialect, And<[
+ CPred<"$_self.isa<IREE::Stream::ResourceType>()">,
+ Stream_ResourceLifetimeVariable,
+]>, "variable resource"> {
+ let description = [{
+ Stream variables represent asynchronously-available and sequenced values
+ that have a long lifetime relative to the work being performed on them.
+ These variables are often stored in global state and may live for the entire
+ duration of the program.
+
+ Stream values are not usable directly outside of a stream execution or
+ transfer operation. If the contents of the value are needed they must first
+ be transferred via `stream.transfer` - which may incur a copy.
+ }];
+}
+
+def Stream_ConstantResource : DialectType<Stream_Dialect, And<[
+ CPred<"$_self.isa<IREE::Stream::ResourceType>()">,
+ Stream_ResourceLifetimeConstant,
+]>, "constant resource"> {
+ let description = [{
+ Stream constants are immutable values that are available for the lifetime of
+ the program once initialized.
+ }];
+}
+
+def Stream_AnyStreamResource : AnyTypeOf<[
+ Stream_UnknownResource,
+ Stream_ExternalResource,
+ Stream_TransientResource,
+ Stream_VariableResource,
+ Stream_ConstantResource,
+]>;
+
+//===----------------------------------------------------------------------===//
+// Executable bindings
+//===----------------------------------------------------------------------===//
+
+def Stream_Binding : TypeDef<Stream_Dialect, "Binding", []> {
+ let mnemonic = "binding";
+
+ let summary = [{a managed resource binding into an executable scope}];
+ let description = [{
+ A resource binding available within an executable dispatch function.
+ The bindings map 1:1 with the resources bound during dispatch operations.
+ }];
+
+ // TODO(benvanik): carry lifetime like resources.
+ let parameters = (ins);
+
+ let builders = [
+ TypeBuilder<(ins), [{
+ return $_get($_ctxt);
+ }]>,
+ ];
+}
+
+def Stream_AnyBinding : AnyTypeOf<[
+ Stream_Binding,
+]>;
+
+//===----------------------------------------------------------------------===//
+// Stream op traits
+//===----------------------------------------------------------------------===//
+
+def Stream_TensorPhaseOp : NativeOpTrait<"IREE::Stream::TensorPhaseOp">;
+def Stream_AsyncPhaseOp : NativeOpTrait<"IREE::Stream::AsyncPhaseOp">;
+def Stream_CmdPhaseOp : NativeOpTrait<"IREE::Stream::CmdPhaseOp">;
+
+#endif // IREE_DIALECT_STREAM_BASE
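To make the execution-region and result-tying conventions described in the dialect overview above concrete, the sketch below reuses an example that appears later in this diff (in `StreamOpFolders.cpp`); the op syntax is only as precise as that example, and the elided dispatch operands are intentional:

```mlir
// Sketch: %src is captured into the region as %arg0; the dispatch writes
// %arg0 in place and the yielded value is tied back to it, so the region
// result is declared to alias %src rather than producing a new resource.
%ret:2 = stream.async.execute with(%src as %arg0) -> %src {
  %2 = stream.async.dispatch ... (%arg0) -> %arg0
  stream.yield %2
}
```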
diff --git a/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp b/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
new file mode 100644
index 0000000..d283dbc
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
@@ -0,0 +1,90 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+namespace {
+
+// Used to control inlining behavior.
+struct StreamInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ bool isLegalToInline(Operation *call, Operation *callable,
+ bool wouldBeCloned) const final {
+ // Sure!
+ return true;
+ }
+ bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+ BlockAndValueMapping &valueMapping) const final {
+ // Sure!
+ return true;
+ }
+
+ bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
+ BlockAndValueMapping &valueMapping) const final {
+ // Sure!
+ return true;
+ }
+};
+
+struct StreamFolderInterface : public DialectFoldInterface {
+ using DialectFoldInterface::DialectFoldInterface;
+
+ bool shouldMaterializeInto(Region *region) const override {
+ // TODO(benvanik): redirect constants to the region scope when small.
+ return false;
+ }
+};
+
+} // namespace
+
+StreamDialect::StreamDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context, TypeID::get<StreamDialect>()) {
+ registerAttributes();
+ registerTypes();
+
+#define GET_OP_LIST
+ addOperations<
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.cpp.inc"
+ >();
+ addInterfaces<StreamInlinerInterface, StreamFolderInterface>();
+}
+
+Operation *StreamDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ if (mlir::ConstantOp::isBuildableWith(value, type)) {
+ return builder.create<mlir::ConstantOp>(loc, type, value);
+ } else if (arith::ConstantOp::isBuildableWith(value, type)) {
+ return builder.create<arith::ConstantOp>(loc, type, value);
+ } else if (value.isa<IREE::Stream::TimepointAttr>()) {
+ return builder.create<IREE::Stream::TimepointImmediateOp>(loc);
+ }
+ return nullptr;
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/IR/StreamDialect.h b/iree/compiler/Dialect/Stream/IR/StreamDialect.h
new file mode 100644
index 0000000..ea76dd2
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamDialect.h
@@ -0,0 +1,48 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMDIALECT_H_
+#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+class StreamDialect : public Dialect {
+ public:
+ explicit StreamDialect(MLIRContext *context);
+ static StringRef getDialectNamespace() { return "stream"; }
+
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+ Location loc) override;
+
+ Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
+ void printAttribute(Attribute attr, DialectAsmPrinter &p) const override;
+
+ Type parseType(DialectAsmParser &parser) const override;
+ void printType(Type type, DialectAsmPrinter &p) const override;
+
+ static bool isDialectOp(Operation *op) {
+ return op && op->getDialect() &&
+ op->getDialect()->getNamespace() == getDialectNamespace();
+ }
+
+ private:
+ void registerAttributes();
+ void registerTypes();
+};
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAMDIALECT_H_
diff --git a/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
new file mode 100644
index 0000000..d63d942
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
@@ -0,0 +1,118 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_STREAM_INTERFACES
+#define IREE_DIALECT_STREAM_INTERFACES
+
+include "iree/compiler/Dialect/Util/IR/UtilBase.td"
+
+//===----------------------------------------------------------------------===//
+// IREE::Stream::AffinityOpInterface
+//===----------------------------------------------------------------------===//
+
+def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> {
+ let description = [{
+    TBD. Used to denote a stream affinity for ops and specify the kind of
+    environment the ops are expected to run in.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the stream affinity for the op, indicating where it should run.
+ }],
+ /*retTy=*/"IREE::Stream::AffinityAttr",
+ /*methodName=*/"getAffinity",
+ /*args=*/(ins),
+ /*methodBody=*/[{
+ return $_self->getAttr("affinity").template dyn_cast_or_null<IREE::Stream::AffinityAttr>();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Sets the stream affinity for the op, indicating where it should run.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"setAffinity",
+ /*args=*/(ins "IREE::Stream::AffinityAttr":$value),
+ /*methodBody=*/[{
+ $_self->setAttr("affinity", value);
+ }]
+ >,
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// IREE::Stream::StreamableOpInterface
+//===----------------------------------------------------------------------===//
+
+def Stream_StreamableOp : OpInterface<"StreamableOpInterface"> {
+ let description = [{
+    Interface for ops that can be asynchronously executed in a streaming context.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+        Returns true if the op only slices memory in/out instead of doing real work.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isMetadata",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+        Returns true if the op should be cloned into consumer streams.
+        Such ops (splats, for example) should be cheaper to recompute than to
+        transfer their contents across streams.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"preferCloneToConsumers",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// IREE::Stream::SubviewEffectOpInterface
+//===----------------------------------------------------------------------===//
+
+def Stream_SubviewEffectOp : OpInterface<"SubviewEffectOpInterface"> {
+ let description = [{
+    Interface for ops that operate on subviews of resources; used to query the
+    memory effects of subviews on operands.
+ }];
+
+ let methods = [
+ // TODO(benvanik): get memory effect + range of an operand
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// IREE::Stream::AsyncOpInterface
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): add interface for async ops:
+// getWaitTimepoint
+// setWaitTimepoint
+// getWaitResources
+// getSignalTimepoint
+// setSignalTimepoint
+// getSignalResources
+// + maybe mutable resource accessors? (MutableOperandRange)
+// This would let us rework code relying on AsyncExecuteOp/CmdExecuteOp to work
+// with both, and wait elision canonicalization patterns to be shared across
+// the async resource ops and execution ops.
+
+#endif // IREE_DIALECT_STREAM_INTERFACES
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
new file mode 100644
index 0000000..b77c2ef
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -0,0 +1,2224 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <algorithm>
+#include <numeric>
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+//===----------------------------------------------------------------------===//
+// Utilities shared across patterns
+//===----------------------------------------------------------------------===//
+
+// Returns the stream.yield op in |block| if it is the only op.
+//
+// Example:
+// stream.async.concurrent ... {
+// stream.yield
+// }
+static Optional<IREE::Stream::YieldOp> getYieldIfOnlyOp(Block &block) {
+ if (block.empty()) return llvm::None;
+ if (&block.front() != &block.back()) return llvm::None;
+ auto yieldOp = dyn_cast<IREE::Stream::YieldOp>(block.back());
+ if (yieldOp) return yieldOp;
+ return llvm::None;
+}
+
+// Finds the insertion point before |targetOp| and after |earliestOp| that would
+// not oscillate if an op was moved there. Oscillations can occur if there are
+// multiple ops inserted before a single op as insertion order based on
+// canonicalization is undefined.
+//
+// Example:
+// %0 = op.a
+// %1 = op.b
+// %2 = op.c %0, %1
+// If %0 and %1 are sunk to %2 the ordering will depend on which sink pattern
+// runs first and each of the patterns will fight trying to sink lower than the
+// other.
+static Block::iterator findInsertionPointBefore(Operation *earliestOp,
+ Operation *targetOp) {
+ // Check if ops between this and the target are all used by the target.
+ // If they are, we skip sinking so that we don't get stuck in an infinite loop
+ // if there are two splats used by the same op (or another pattern sinking).
+ if (earliestOp->getBlock() == targetOp->getBlock()) {
+ SmallPtrSet<Operation *, 4> producerOps;
+ for (auto operand : targetOp->getOperands()) {
+ if (operand.getDefiningOp()) {
+ producerOps.insert(operand.getDefiningOp());
+ }
+ }
+ bool allUsed = true;
+ for (auto it = Block::iterator(earliestOp); it != Block::iterator(targetOp);
+ ++it) {
+ if (!producerOps.contains(&*it)) {
+ allUsed = false;
+ break;
+ }
+ }
+ if (allUsed) return Block::iterator(earliestOp);
+ }
+ return Block::iterator(targetOp);
+}
+
+// Sinks |op| down to |targetOp|, ensuring that we don't oscillate.
+// Returns success if the op was sunk and failure if sinking was not needed.
+static LogicalResult sinkOp(Operation *op, Operation *targetOp) {
+ auto ip = findInsertionPointBefore(op, targetOp);
+ if (ip == Block::iterator(op)) return failure();
+ op->moveBefore(targetOp);
+ return success();
+}
+
+// Sets |rewriter| to point immediately before the parent execution region.
+// Example:
+// %0 =
+// <-- insertion point set to here -->
+// stream.async.execute ... {
+// %1 = op
+// }
+static void setInsertionPointToParentExecutionScope(Operation *op,
+ PatternRewriter &rewriter) {
+ if (auto parentOp = op->getParentOfType<AsyncExecuteOp>()) {
+ rewriter.setInsertionPoint(parentOp);
+ } else if (auto parentOp = op->getParentOfType<CmdExecuteOp>()) {
+ rewriter.setInsertionPoint(parentOp);
+ } else {
+ llvm_unreachable("must be nested within an execution region");
+ }
+}
+
+namespace {
+
+// Erases an op if it has no uses.
+// This is to support ops that are "pure" but can't be marked as such because
+// the MLIR CSE pass would deduplicate them.
+template <typename Op>
+struct ElideUnusedOp : public OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const override {
+ if (!op.use_empty()) return failure();
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+// Materialize copy-on-write (🐄) ops where required for |rootValue|.
+// Only valid in tensor/async ops - don't use with stream.cmd.*.
+static bool materializeCOW(Location loc, Value rootValue, OpBuilder &builder) {
+ auto valueType = rootValue.getType().dyn_cast<IREE::Stream::ResourceType>();
+ if (!valueType) return false;
+
+  // If rootValue is a constant then consumers must not end up tied to a
+  // constant operand; in that case we force a clone so they get a
+  // non-constant value to mutate.
+ bool forceClone = valueType.getLifetime() == IREE::Stream::Lifetime::Constant;
+
+ // Identify if we need to insert a copy-on-write clone.
+  // We do this per use as a single consuming op may use this result multiple
+  // times - some tied and some not - and each tied use will need its own
+  // clone.
+ struct TiedUse {
+ Operation *user;
+ unsigned operandIndex;
+ Value value;
+ };
+ SmallVector<TiedUse> tiedUses;
+ unsigned untiedUses = 0;
+ for (auto &use : rootValue.getUses()) {
+ if (isa<IREE::Stream::TimepointAwaitOp>(use.getOwner())) continue;
+ auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(use.getOwner());
+ bool isTied = tiedOp && tiedOp.isOperandTied(use.getOperandNumber());
+ if (isTied) {
+ tiedUses.push_back({use.getOwner(), use.getOperandNumber(), rootValue});
+ } else {
+ ++untiedUses;
+ }
+ }
+ if (tiedUses.empty()) {
+ // All uses are as normal capturing SSA values.
+ return false;
+ } else if (tiedUses.size() == 1 && untiedUses == 0 && !forceClone) {
+ // Only one use and it's tied - we've already reserved our results for it.
+ return false;
+ }
+
+ // Mixed/multiple tied uses. Clone for each tied use but leave the untied
+ // ones referencing us.
+ IREE::Stream::AffinityAttr sourceAffinity;
+ if (auto affinityOp = dyn_cast_or_null<IREE::Stream::AffinityOpInterface>(
+ rootValue.getDefiningOp())) {
+ sourceAffinity = affinityOp.getAffinity();
+ }
+ for (auto &tiedUse : tiedUses) {
+ auto cloneLoc =
+ FusedLoc::get(builder.getContext(), {loc, tiedUse.user->getLoc()});
+
+ builder.setInsertionPoint(tiedUse.user);
+
+ auto sizeAwareType =
+ tiedUse.value.getType()
+ .template cast<IREE::Util::SizeAwareTypeInterface>();
+ auto targetSize =
+ sizeAwareType.queryValueSize(cloneLoc, tiedUse.value, builder);
+
+ IREE::Stream::AffinityAttr targetAffinity;
+ if (auto affinityOp =
+ dyn_cast<IREE::Stream::AffinityOpInterface>(tiedUse.user)) {
+ targetAffinity = affinityOp.getAffinity();
+ }
+
+ auto cloneOp = builder.create<IREE::Stream::AsyncCloneOp>(
+ cloneLoc, tiedUse.value.getType(), tiedUse.value, targetSize,
+ targetSize, targetAffinity ? targetAffinity : sourceAffinity);
+ tiedUse.user->setOperand(tiedUse.operandIndex, cloneOp.result());
+ }
+
+ return true;
+}
+
+// Materialize copy-on-write (🐄) ops where required.
+// This models what a runtime normally does with copy-on-write but uses the
+// information we have in the SSA use-def chain to identify ties that write and
+// covering reads.
+template <typename Op>
+struct MaterializeCOW : public OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const override {
+ bool didChange = false;
+
+ // Handle results of this op (primary use case).
+ for (auto result : op->getResults()) {
+ didChange = materializeCOW(op.getLoc(), result, rewriter) || didChange;
+ }
+
+ return didChange ? success() : failure();
+ }
+};
+
+// Ties the results of an execution region to its operands when the region's
+// operations are tied throughout the entire body.
+//
+// Example:
+// %ret:2 = stream.async.execute with(%src as %arg0) -> !stream.resource<*> {
+// %2 = stream.async.dispatch ... (%arg0) -> %arg0
+// stream.yield %2
+// }
+// ->
+// %ret:2 = stream.async.execute with(%src as %arg0) -> %src {
+// %2 = stream.async.dispatch ... (%arg0) -> %arg0
+// stream.yield %2
+// }
+template <typename Op>
+struct TieRegionResults : public OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const override {
+ assert(op.getRegion().getBlocks().size() == 1 &&
+ "only one stream block supported");
+ bool didModify = false;
+ for (auto yieldOp : op.template getOps<IREE::Stream::YieldOp>()) {
+ for (auto result : llvm::enumerate(yieldOp.operands())) {
+ if (op.getTiedResultOperandIndex(result.index()).hasValue()) {
+ continue; // Already tied.
+ }
+ auto baseValue =
+ IREE::Util::TiedOpInterface::findTiedBaseValue(result.value());
+ if (auto blockArg = baseValue.template dyn_cast<BlockArgument>()) {
+ unsigned operandIndex = blockArg.getArgNumber();
+ op.setTiedResultOperandIndex(result.index(), operandIndex);
+ didModify = true;
+ }
+ }
+ }
+ return didModify ? success() : failure();
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// stream.resource.alloc
+//===----------------------------------------------------------------------===//
+
+void ResourceAllocOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): sink to first user.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.alloca
+//===----------------------------------------------------------------------===//
+
+void ResourceAllocaOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): sink to first user.
+ // TODO(benvanik): elide if only user is dealloc.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.dealloca
+//===----------------------------------------------------------------------===//
+
+void ResourceDeallocaOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): move up to producer of timepoint.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.size
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ResourceSizeOp::fold(ArrayRef<Attribute> operands) {
+ auto sizeAwareType =
+ operand().getType().cast<IREE::Util::SizeAwareTypeInterface>();
+ return sizeAwareType.findSizeValue(operand(), *this);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.map
+//===----------------------------------------------------------------------===//
+
+void ResourceMapOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): fold subviews up into maps to limit range.
+ results.insert<ElideUnusedOp<ResourceMapOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.try_map
+//===----------------------------------------------------------------------===//
+
+void ResourceTryMapOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): fold subviews up into maps to limit range.
+ // TODO(benvanik): if mapping for staging then turn into a map?
+ results.insert<ElideUnusedOp<ResourceTryMapOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.load
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview offsets into loads.
+//
+// Example:
+// %0 = stream.resource.subview %src[%subview_offset] ... -> {%subview_length}
+// %1 = stream.resource.load %0[%offset]
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+// %1 = stream.resource.load %src[%new_offset]
+struct FoldSubviewIntoLoadOp : public OpRewritePattern<ResourceLoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ResourceLoadOp op,
+ PatternRewriter &rewriter) const override {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(op.source());
+ if (!subviewOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subviewOp.source_offset(), op.source_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.sourceMutable().assign(subviewOp.source());
+ op.source_sizeMutable().assign(subviewOp.source_size());
+ op.source_offsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void ResourceLoadOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): if staging resource comes from splat (through transfers)
+ // then pull splat value.
+ // TODO(benvanik): combine multiple loads from the same target if contiguous.
+ // TODO(benvanik): value->transfer->load -> value->slice->transfer->load?
+ results.insert<FoldSubviewIntoLoadOp>(context);
+ results.insert<ElideUnusedOp<ResourceLoadOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.store
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview offsets into stores.
+//
+// Example:
+// %0 = stream.resource.subview %dst[%subview_offset] ... -> {%subview_length}
+// stream.resource.store %c123_i32, %0[%offset]
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+// stream.resource.store %c123_i32, %dst[%new_offset]
+struct FoldSubviewIntoStoreOp : public OpRewritePattern<ResourceStoreOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ResourceStoreOp op,
+ PatternRewriter &rewriter) const override {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(op.target());
+ if (!subviewOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subviewOp.source_offset(), op.target_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.targetMutable().assign(subviewOp.source());
+ op.target_sizeMutable().assign(subviewOp.source_size());
+ op.target_offsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void ResourceStoreOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): combine multiple stores to the same target if contiguous.
+ // TODO(benvanik): if value is a constant splat then turn into fill?
+ results.insert<FoldSubviewIntoStoreOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.pack
+//===----------------------------------------------------------------------===//
+
+LogicalResult ResourcePackOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ Builder builder(getContext());
+
+ // If there are no slices then the entire pack results in a zero-length slab.
+ if (packed_offsets().empty()) {
+ results.push_back(builder.getZeroAttr(builder.getIndexType()));
+ return success();
+ }
+
+ // If there's a single slice then we just use that as there is no packing to
+ // perform.
+ if (packed_offsets().size() == 1) {
+ // Total length is the slice size and offset is always either 0 or the
+ // provided optional base offset.
+ results.push_back(dynamic_slice_sizes()[0]);
+ if (offset()) {
+ results.push_back(offset());
+ } else {
+ results.push_back(builder.getZeroAttr(builder.getIndexType()));
+ }
+ return success();
+ }
+
+ return failure();
+}
+
+namespace {
+
+// Propagates base offsets on a pack op to its results.
+// This allows for better folding of the results after packing has completed.
+// The offset value is just a convenience for when splitting pack ops and has
+// no impact on the actual packing operation.
+struct PropagateResourcePackBaseOffset
+ : public OpRewritePattern<ResourcePackOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ResourcePackOp op,
+ PatternRewriter &rewriter) const override {
+ // Offset is optional.
+ auto baseOffset = op.offset();
+ if (!baseOffset) return failure();
+
+ // We always strip the offset here.
+ rewriter.updateRootInPlace(op, [&]() { op.offsetMutable().clear(); });
+
+ // Zero offsets don't do anything and can just be removed so we can avoid
+ // inserting a bunch of additional IR.
+ if (auto constantOp = dyn_cast_or_null<arith::ConstantIndexOp>(
+ baseOffset.getDefiningOp())) {
+ if (constantOp.value() == 0) {
+ return success();
+ }
+ }
+
+ // Propagate the offset to all returned slice offsets.
+ rewriter.setInsertionPointAfter(op);
+ for (auto sliceOffset : op.packed_offsets()) {
+ auto addOp =
+ rewriter.create<arith::AddIOp>(op.getLoc(), baseOffset, sliceOffset);
+ SmallPtrSet<Operation *, 1> exclusions;
+ exclusions.insert(addOp);
+ sliceOffset.replaceAllUsesExcept(addOp.result(), exclusions);
+ }
+
+ return success();
+ }
+};
+
+// Sorts and compacts the slice intervals into a dense ascending order set.
+// This is not required by the packing algorithm but yields more
+// consistent-looking IR and makes the range overlaps easier to see for us
+// meatbags.
+//
+// Example:
+// %0:3 = stream.resource.pack slices({
+// [1, 2] = %size,
+// [0, 4] = %size,
+// }) : index
+// ->
+// %0:3 = stream.resource.pack slices({
+// [0, 4] = %size,
+// [1, 2] = %size,
+// }) : index
+struct CanonicalizeResourcePackIntervals
+ : public OpRewritePattern<ResourcePackOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ResourcePackOp op,
+ PatternRewriter &rewriter) const override {
+ // Get the slices in a possibly unsorted order and sort.
+ auto slices = op.getSlices();
+ std::stable_sort(slices.begin(), slices.end());
+
+ // See if the sorted order is different than how they are stored in the op.
+ bool orderChanged = false;
+ for (auto it : llvm::zip(slices, op.packed_offsets())) {
+ if (std::get<0>(it).packedOffset != std::get<1>(it)) {
+ orderChanged = true;
+ break;
+ }
+ }
+ if (!orderChanged) return failure();
+
+ // TODO(benvanik): compact the slice ranges.
+
+ // Rebuild the op with the sorted values.
+ SmallVector<int64_t> lifetimeIntervals(slices.size() * 2);
+ SmallVector<Value> dynamicSliceSizes(slices.size());
+ for (size_t i = 0; i < slices.size(); ++i) {
+ const auto &slice = slices[i];
+ lifetimeIntervals[2 * i + 0] = slice.lifetimeStart;
+ lifetimeIntervals[2 * i + 1] = slice.lifetimeEnd;
+ dynamicSliceSizes[i] = slice.dynamicSize;
+ }
+ SmallVector<Type> packedOffsetTypes(slices.size(), rewriter.getIndexType());
+ auto newOp = rewriter.create<ResourcePackOp>(
+ op.getLoc(), op.total_length().getType(), packedOffsetTypes,
+ op.offset(), rewriter.getIndexArrayAttr(lifetimeIntervals),
+ dynamicSliceSizes, op.affinityAttr());
+
+ // Remap existing values to the new values.
+ op.total_length().replaceAllUsesWith(newOp.total_length());
+ for (size_t i = 0; i < newOp.packed_offsets().size(); ++i) {
+ slices[i].packedOffset.replaceAllUsesWith(newOp.packed_offsets()[i]);
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+} // namespace
+
+void ResourcePackOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<PropagateResourcePackBaseOffset>(context);
+ results.insert<CanonicalizeResourcePackIntervals>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.subview
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ResourceSubviewOp::fold(ArrayRef<Attribute> operands) {
+ if (source_size() == result_size()) {
+ // Entire range is covered; return it all.
+ return source();
+ }
+ return {};
+}
+
+namespace {
+
+// Folds subview -> subview to point at the original source resource with an
+// updated range.
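+//
+// Illustrative example (subview assembly syntax abbreviated):
+//  %0 = stream.resource.subview %src[%offset0] ... -> {%length0}
+//  %1 = stream.resource.subview %0[%offset1] ... -> {%length1}
+// ->
+//  %new_offset = arith.addi %offset0, %offset1
+//  %1 = stream.resource.subview %src[%new_offset] ... -> {%length1}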
+struct FoldResourceSubviewOps : public OpRewritePattern<ResourceSubviewOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ResourceSubviewOp op,
+ PatternRewriter &rewriter) const override {
+ auto parentOp = ResourceSubviewOp::findSubviewOp(op.source());
+ if (!parentOp) return failure();
+ auto fusedLoc = rewriter.getFusedLoc({parentOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, parentOp.source_offset(), op.source_offset());
+ auto newOp = rewriter.create<ResourceSubviewOp>(
+ fusedLoc, parentOp.source(), parentOp.source_size(), newOffset,
+ op.result_size());
+ rewriter.replaceOp(op, newOp.result());
+ return success();
+ }
+};
+
+} // namespace
+
+void ResourceSubviewOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<FoldResourceSubviewOps>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.import
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorImportOp::fold(ArrayRef<Attribute> operands) {
+ // TODO(benvanik): if operand comes from export then fold.
+ return {};
+}
+
+void TensorImportOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): check operand and dedupe imports.
+ results.insert<MaterializeCOW<TensorImportOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.export
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorExportOp::fold(ArrayRef<Attribute> operands) {
+ // TODO(benvanik): if operand comes from import then fold.
+ return {};
+}
+
+void TensorExportOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): check operand and dedupe exports.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.sizeof
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.constant
+//===----------------------------------------------------------------------===//
+
+namespace {
+
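+// Converts a constant with a splat value into a stream.tensor.splat of the
+// splat element followed by a stream.async.transfer to the original result
+// type.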
+struct TensorConstantToSplat : public OpRewritePattern<TensorConstantOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TensorConstantOp constantOp,
+ PatternRewriter &rewriter) const override {
+ auto splatAttr = constantOp.value().dyn_cast<SplatElementsAttr>();
+ if (!splatAttr || !splatAttr.isSplat()) {
+ return rewriter.notifyMatchFailure(
+ constantOp,
+ "only constant splat attrs can be converted to splat ops");
+ }
+
+ auto splatElementAttr = splatAttr.getSplatValue();
+ auto splatValue = rewriter.create<arith::ConstantOp>(
+ constantOp.getLoc(), splatElementAttr.getType(), splatElementAttr);
+ auto resultType = IREE::Stream::ResourceType::get(constantOp.getContext());
+ auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ constantOp.getLoc(), rewriter.getIndexType(),
+ TypeAttr::get(constantOp.result_encoding()),
+ constantOp.result_encoding_dims(), /*affinity=*/nullptr);
+ auto splatOp = rewriter.create<TensorSplatOp>(
+ constantOp.getLoc(), resultType, splatValue,
+ constantOp.result_encoding(), constantOp.result_encoding_dims(),
+ resultSize,
+ /*affinity=*/nullptr);
+ rewriter.replaceOpWithNewOp<AsyncTransferOp>(
+ constantOp, constantOp.result().getType(), splatOp.result(), resultSize,
+ resultSize, /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ return success();
+ }
+};
+
+} // namespace
+
+void TensorConstantOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): if value is _mostly_ a splat, turn into splat + updates.
+ results.insert<TensorConstantToSplat>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.splat
+//===----------------------------------------------------------------------===//
+
+void TensorSplatOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ElideUnusedOp<TensorSplatOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.clone
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorCloneOp::fold(ArrayRef<Attribute> operands) {
+  // A non-empty user range whose second element is the end iterator means the
+  // result has exactly one user; clones with a single user fold to their
+  // source.
+  auto users = result().getUsers();
+  if (!users.empty() && std::next(users.begin()) == users.end()) {
+    return source();
+  }
+ return {};
+}
+
+namespace {
+
+// Elides clones that don't do anything meaningful, i.e. clones whose results
+// have no tied uses, by replacing them with their source.
+struct ElideUnneededTensorClones : public OpRewritePattern<TensorCloneOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TensorCloneOp cloneOp,
+ PatternRewriter &rewriter) const override {
+ if (!IREE::Util::TiedOpInterface::hasAnyTiedUses(cloneOp.result())) {
+ rewriter.replaceOp(cloneOp, cloneOp.source());
+ return success();
+ }
+ return failure();
+ }
+};
+
+} // namespace
+
+void TensorCloneOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): splat -> clone duplicates splat.
+ // TODO(benvanik): some way to reduce deep clone->clone->clone chains.
+ // TODO(benvanik): clone + slice => slice.
+ // TODO(benvanik): if both operand and result are used once then elide.
+ // (if not tied block/fn arguments)
+ results.insert<ElideUnneededTensorClones>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.slice
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorSliceOp::fold(ArrayRef<Attribute> operands) {
+ // TODO(benvanik): fold if source_size == result_size and affinity/lifetime.
+ return {};
+}
+
+void TensorSliceOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): turn into a transfer if target_size == update_size and
+ // affinity/lifetime differ.
+ // TODO(benvanik): splat->slice -> splat.
+ // TODO(benvanik): clone->slice -> slice.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.fill
+//===----------------------------------------------------------------------===//
+
+void TensorFillOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): if target_size == sizeof(value) turn into splat.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.update
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorUpdateOp::fold(ArrayRef<Attribute> operands) {
+ // TODO(benvanik): fold if target_size == update_size and affinity/lifetime.
+ return {};
+}
+
+void TensorUpdateOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): turn into a transfer if target_size == update_size and
+ // affinity/lifetime differ.
+ // TODO(benvanik): turn into fill if source is a splat.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.load
+//===----------------------------------------------------------------------===//
+
+void TensorLoadOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): splat + load -> splat value.
+ // TODO(benvanik): clone + ex load -> slice (ranged) + load.
+ // TODO(benvanik): slice + ex load -> slice (ranged) + load.
+ // TODO(benvanik): value->transfer->load -> value->slice->transfer->load?
+ // TODO(benvanik): combine multiple loads from the same target if contiguous.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.store
+//===----------------------------------------------------------------------===//
+
+void TensorStoreOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): if value is a constant splat then turn into fill.
+ // TODO(benvanik): combine multiple stores to the same target if contiguous.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.alloca
+//===----------------------------------------------------------------------===//
+
+void AsyncAllocaOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): alloca (staging) -> non-staging change to target.
+ // TODO(benvanik): alloca (non-staging) -> staging change to target.
+ // TODO(benvanik): sink to first user.
+ results.insert<MaterializeCOW<AsyncAllocaOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.constant
+//===----------------------------------------------------------------------===//
+
+void AsyncConstantOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): if value is a splat turn into splat.
+ // TODO(benvanik): if value is _mostly_ a splat, turn into splat + updates.
+ results.insert<MaterializeCOW<AsyncConstantOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.splat
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Sinks splat ops down to their consumers to avoid cases where we splat a
+// value and then keep it live (or copy-on-write it) longer than needed.
+struct SinkSplatsToConsumers : public OpRewritePattern<AsyncSplatOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncSplatOp splatOp,
+ PatternRewriter &rewriter) const override {
+ auto users = llvm::to_vector<4>(splatOp->getUsers());
+ if (users.size() == 0) return failure();
+
+ // If we have a single user then we can sink right to it.
+ if (users.size() == 1) {
+ return sinkOp(splatOp, users.front());
+ }
+
+ // If we only have users in the same block then we can safely move to the
+ // first (as no change to cross-block SSA dominance can happen).
+ if (!splatOp.result().isUsedOutsideOfBlock(splatOp->getBlock())) {
+ Operation *targetOp = nullptr;
+ for (auto user : users) {
+ if (!targetOp || user->isBeforeInBlock(targetOp)) {
+ targetOp = user;
+ }
+ }
+ assert(targetOp);
+ return sinkOp(splatOp, targetOp);
+ }
+
+ // Redundant computation here, but only in cases where we have multiple
+ // users that may live outside the block the op is in.
+ DominanceInfo domInfo(splatOp->getParentOp());
+
+ // Find the common dominator block across all uses. This may be the
+ // entry block itself.
+ Block *commonDominator = users.front()->getBlock();
+ for (auto user : users) {
+ commonDominator =
+ domInfo.findNearestCommonDominator(commonDominator, user->getBlock());
+ }
+
+ // Find the first use within the dominator block (if any) so that we
+ // can sink down to it.
+ Operation *firstUserInDominator = commonDominator->getTerminator();
+ for (auto user : users) {
+ if (user->getBlock() == commonDominator) {
+ if (user->isBeforeInBlock(firstUserInDominator)) {
+ firstUserInDominator = user;
+ }
+ }
+ }
+
+ // Sink to the common dominator - which may not even use the op but will
+ // at least prevent us from doing extra work.
+ return sinkOp(splatOp, firstUserInDominator);
+ }
+};
+
+} // namespace
+
+void AsyncSplatOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(#6972): find splat+update-from and turn into fill.
+ // TODO(#6972): find splat+copy-from and turn into fill.
+ // TODO(#6972): find splat+update-into and turn into alloca+fill+update.
+ // TODO(#6972): find splat+copy-into and turn into alloca+fill+copy.
+ // TODO(#6972): clone instead of sinking to common dominator.
+ results.insert<SinkSplatsToConsumers>(context);
+ results.insert<ElideUnusedOp<AsyncSplatOp>>(context);
+ results.insert<MaterializeCOW<AsyncSplatOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.clone
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AsyncCloneOp::fold(ArrayRef<Attribute> operands) {
+ // TODO(benvanik): trivial elides when there are no tied users/one user.
+ return {};
+}
+
+namespace {
+
+// Clones ops that prefer to be cloned directly.
+// This prevents us from splatting out a value and then cloning that (keeping
+// the memory live/etc) instead of just splatting it again on-demand.
+//
+// Example:
+// %0 = stream.async.splat %c123_i32
+// %1 = stream.async.clone %0
+// ->
+// %1 = stream.async.splat %c123_i32
+struct PropagateClonableOps : public OpRewritePattern<AsyncCloneOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncCloneOp cloneOp,
+ PatternRewriter &rewriter) const override {
+ if (cloneOp.use_empty()) return failure();
+ auto sourceOp = dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(
+ cloneOp.source().getDefiningOp());
+ if (!sourceOp || !sourceOp.preferCloneToConsumers()) return failure();
+ for (auto &use : llvm::make_early_inc_range(cloneOp.result().getUses())) {
+ rewriter.setInsertionPoint(use.getOwner());
+ auto clonedOp = rewriter.clone(*sourceOp);
+ use.set(clonedOp->getResult(0));
+ }
+ if (cloneOp.use_empty()) {
+ rewriter.eraseOp(cloneOp);
+ }
+ return success();
+ }
+};
+
+} // namespace
+
+void AsyncCloneOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): some way to reduce deep clone->clone->clone chains.
+ results.insert<PropagateClonableOps>(context);
+ results.insert<ElideUnusedOp<AsyncCloneOp>>(context);
+ results.insert<MaterializeCOW<AsyncCloneOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.slice
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AsyncSliceOp::fold(ArrayRef<Attribute> operands) {
+ if (source_size() == result_size()) {
+ // Slicing entire source - just reroute to source.
+ // Note that this breaks copy-on-write semantics but will be fixed up during
+ // canonicalization if needed.
+ return source();
+ }
+ return {};
+}
+
+namespace {
+
+// Clones a splat op through a slice as a splat+slice is just a smaller splat.
+//
+// Example:
+// %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%sz0}
+// %1 = stream.async.slice %0[%c0 to %c128] ... {%c128}
+// ->
+// %1 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%c128}
+struct PropagateSplatsThroughSlices : public OpRewritePattern<AsyncSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto splatOp = dyn_cast_or_null<IREE::Stream::AsyncSplatOp>(
+ sliceOp.source().getDefiningOp());
+ if (!splatOp) return failure();
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncSplatOp>(
+ sliceOp, sliceOp.result().getType(), splatOp.value(),
+ sliceOp.result_size(), sliceOp.affinityAttr());
+ return success();
+ }
+};
+
+} // namespace
+
+void AsyncSliceOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): turn into a transfer if target_size == update_size and
+ // affinity/lifetime differ.
+ results.insert<PropagateSplatsThroughSlices>(context);
+ results.insert<ElideUnusedOp<AsyncSliceOp>>(context);
+ results.insert<MaterializeCOW<AsyncSliceOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.fill
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Turns fills that cover an entire target resource into splats.
+// This acts as a discard as it indicates we don't care about the previous
+// resource contents.
+//
+// Example:
+// %0 = stream.async.fill %cst, %dst[%c0 to %dstsz for %dstsz] ... {%dstsz}
+// ->
+// %0 = stream.async.splat %cst : f32 -> !stream.resource<*>{%dstsz}
+struct FlattenFullFillToSplat : public OpRewritePattern<AsyncFillOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncFillOp fillOp,
+ PatternRewriter &rewriter) const override {
+ if (fillOp.target_length() == fillOp.target_size()) {
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncSplatOp>(
+ fillOp, fillOp.result().getType(), fillOp.value(),
+ fillOp.target_size(), fillOp.affinityAttr());
+ return success();
+ }
+ return failure();
+ }
+};
+
+} // namespace
+
+void AsyncFillOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FlattenFullFillToSplat>(context);
+ results.insert<ElideUnusedOp<AsyncFillOp>>(context);
+ results.insert<MaterializeCOW<AsyncFillOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.update
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AsyncUpdateOp::fold(ArrayRef<Attribute> operands) {
+ if (update_size() == target_size()) {
+ // If updating the entire target then just replace with the update.
+ // Note that this breaks copy-on-write semantics but will be fixed up during
+ // canonicalization if needed.
+ return update();
+ }
+ return {};
+}
+
+namespace {
+
+// Turns a splat+update-from into a fill.
+//
+// Example:
+// %0 = stream.async.splat %c123_i32 ... {%c128}
+// %1 = stream.async.update %0, %dst[%c0 to %c128]
+// ->
+// %1 = stream.async.fill %c123_i32, %dst[%c0 to %c128 for %c128]
+struct CombineSplatUpdateFromToFill : public OpRewritePattern<AsyncUpdateOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncUpdateOp updateOp,
+ PatternRewriter &rewriter) const override {
+ auto splatOp = dyn_cast_or_null<IREE::Stream::AsyncSplatOp>(
+ updateOp.update().getDefiningOp());
+ if (!splatOp) return failure();
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncFillOp>(
+ updateOp, updateOp.result().getType(), updateOp.target(),
+ updateOp.target_size(), updateOp.target_offset(), updateOp.target_end(),
+ updateOp.update_size(), splatOp.value(), updateOp.tied_operandsAttr(),
+ updateOp.affinityAttr());
+ return success();
+ }
+};
+
+// Turns slice+update-from into a copy.
+// This is equivalent behavior at runtime but better to schedule as a single
+// operation.
+//
+// This could pessimize memory consumption if the slice is far from the
+// consumer update: it's better to slice away a small part of a resource to
+// retain than to keep the whole resource live.
+//
+// Example:
+// %0 = stream.async.slice %src[%c0 to %c128]
+// %1 = stream.async.update %0, %dst[%c0 to %c128]
+// ->
+// %1 stream.async.copy %src[%c0 to %c128], %dst[%c0 to %c128], %c128
+//
+// TODO(benvanik): evaluate if we want to do this in all cases - we may only
+// want to do it if there are users of the source after this op such that we
+// wouldn't be the op keeping the entire unsliced source resource live.
+struct CombineSliceUpdateFromToCopy : public OpRewritePattern<AsyncUpdateOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncUpdateOp updateOp,
+ PatternRewriter &rewriter) const override {
+ auto sliceOp = dyn_cast_or_null<IREE::Stream::AsyncSliceOp>(
+ updateOp.update().getDefiningOp());
+ if (!sliceOp || sliceOp->getBlock() != updateOp->getBlock()) {
+ // Source is not a slice or a slice from out-of-block. We don't want to
+ // grow memory usage by sinking the slice here (we may slice into the
+ // body of a for loop, for example).
+ return failure();
+ }
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCopyOp>(
+ updateOp, updateOp.result().getType(), updateOp.target(),
+ updateOp.target_size(), updateOp.target_offset(), updateOp.target_end(),
+ sliceOp.source(), sliceOp.source_size(), sliceOp.source_offset(),
+ sliceOp.source_end(), sliceOp.result_size(),
+ updateOp.tied_operandsAttr(), updateOp.affinityAttr());
+ return success();
+ }
+};
+
+} // namespace
+
+void AsyncUpdateOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): turn into a transfer if target_size == update_size and
+ // affinity/lifetime differ.
+ // TODO(#6972): updates into splats could become alloca + fill exclusive
+ // region + update into undefined contents (used in padding).
+ results.insert<CombineSplatUpdateFromToFill>(context);
+ results.insert<CombineSliceUpdateFromToCopy>(context);
+ results.insert<ElideUnusedOp<AsyncUpdateOp>>(context);
+ results.insert<MaterializeCOW<AsyncUpdateOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.copy
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Turns a copy from an entire resource into an update. Updates can be more
+// efficient during allocation as we know the producer can write directly into
+// the target.
+//
+// Example:
+// %2 = stream.async.copy %0[%c0 to %sz0], %1[%c0 to %sz1], %sz0
+// ->
+// %2 = stream.async.update %0, %1[%c0 to %sz1]
+struct AsyncCopyFullSourceToUpdate : public OpRewritePattern<AsyncCopyOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncCopyOp copyOp,
+ PatternRewriter &rewriter) const override {
+ if (copyOp.source_end() == copyOp.source_size()) {
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncUpdateOp>(
+ copyOp, copyOp.result().getType(), copyOp.target(),
+ copyOp.target_size(), copyOp.target_offset(), copyOp.target_end(),
+ copyOp.source(), copyOp.source_size(), copyOp.tied_operandsAttr(),
+ copyOp.affinityAttr());
+ return success();
+ }
+ return failure();
+ }
+};
+
+} // namespace
+
+void AsyncCopyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<AsyncCopyFullSourceToUpdate>(context);
+ results.insert<ElideUnusedOp<AsyncCopyOp>>(context);
+ results.insert<MaterializeCOW<AsyncCopyOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.transfer
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AsyncTransferOp::fold(ArrayRef<Attribute> operands) {
+ if (auto sourceTransferOp =
+ dyn_cast_or_null<AsyncTransferOp>(source().getDefiningOp())) {
+ if (sourceTransferOp.source().getType() == result().getType() &&
+ sourceTransferOp.source_affinity() == result_affinity()) {
+ return sourceTransferOp.source();
+ }
+ }
+ return {};
+}
+
+namespace {
+
+// Elides transfer operations that are a no-op (from/to the same affinity and
+// same resource type).
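+//
+// Illustrative example (transfer assembly syntax approximate):
+//  %1 = stream.async.transfer %0 : !stream.resource<*>{%sz} ->
+//       !stream.resource<*>{%sz}
+// ->
+//  (elided; uses of %1 are replaced with %0)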
+struct RedundantTransferElision : public OpRewritePattern<AsyncTransferOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncTransferOp transferOp,
+ PatternRewriter &rewriter) const override {
+ if (transferOp.source_affinityAttr() == transferOp.result_affinityAttr() &&
+ transferOp.source().getType() == transferOp.result().getType()) {
+ // Transfer performs no work, elide.
+ rewriter.replaceOp(transferOp, transferOp.source());
+ return success();
+ }
+ return failure();
+ }
+};
+
+} // namespace
+
+void AsyncTransferOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): staging propagation (fill of staging -> fill on device).
+ results.insert<RedundantTransferElision>(context);
+ results.insert<ElideUnusedOp<AsyncTransferOp>>(context);
+ results.insert<MaterializeCOW<AsyncTransferOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.dispatch
+//===----------------------------------------------------------------------===//
+
+void AsyncDispatchOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): nothing? maybe tied type/lifetime updates?
+ results.insert<ElideUnusedOp<AsyncDispatchOp>>(context);
+ results.insert<MaterializeCOW<AsyncDispatchOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.execute
+//===----------------------------------------------------------------------===//
+
+// Combines await dependencies on |newTimepoints| with the op's optional
+// |existingTimepoint|, returning a single timepoint for the op to await.
+// This may just pass through the provided timepoint or create a join based on
+// the existing await behavior of the op and the new values.
+static Value joinAwaitTimepoints(Location loc, Value existingTimepoint,
+ ArrayRef<Value> newTimepoints,
+ OpBuilder &builder) {
+ if (newTimepoints.empty()) {
+ // No new timepoints - preserve existing.
+ return existingTimepoint;
+ } else if (newTimepoints.size() == 1 && !existingTimepoint) {
+ // Adding a single new timepoint.
+ return newTimepoints.front();
+ }
+
+ // Materialize a join of the new timepoints + the existing (if present).
+ SmallVector<Value> joinTimepoints;
+ if (existingTimepoint) {
+ joinTimepoints.push_back(existingTimepoint);
+ }
+ llvm::append_range(joinTimepoints, newTimepoints);
+ return builder.create<IREE::Stream::TimepointJoinOp>(
+ loc, builder.getType<IREE::Stream::TimepointType>(), joinTimepoints);
+}
+
+namespace {
+
+// Elides waits that are known to be immediately resolved.
+//
+// Example:
+// %0 = stream.timepoint.immediate
+// %1 = stream.async.execute await(%0) => with(...)
+// ->
+// %1 = stream.async.execute with(...)
+struct ElideImmediateAsyncExecuteWaits
+ : public OpRewritePattern<AsyncExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ bool isImmediate =
+ op.await_timepoint() && isa_and_nonnull<TimepointImmediateOp>(
+ op.await_timepoint().getDefiningOp());
+ if (!isImmediate) return failure();
+ rewriter.updateRootInPlace(op,
+ [&]() { op.await_timepointMutable().clear(); });
+ return success();
+ }
+};
+
+// Chains operand resources produced by an await to dependent execution
+// regions. This elides host waits and allows for device-side wait resolution.
+//
+// Example:
+//  %0 = stream.timepoint.await %tp => %resource
+//  %1 = stream.async.execute with(%0 as %arg0)
+// ->
+//  %1 = stream.async.execute await(%tp) => with(%resource as %arg0)
+struct ChainAsyncExecuteWaits : public OpRewritePattern<AsyncExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> newTimepoints;
+ SmallVector<std::pair<unsigned, Value>> replacements;
+ for (auto operand : llvm::enumerate(op.operands())) {
+ if (auto awaitOp = dyn_cast_or_null<TimepointAwaitOp>(
+ operand.value().getDefiningOp())) {
+ newTimepoints.push_back(awaitOp.timepoint());
+ replacements.push_back(std::make_pair(
+ operand.index(), awaitOp.getTiedResultOperand(operand.value())));
+ }
+ }
+ if (replacements.empty()) return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ auto newTimepoint = joinAwaitTimepoints(op.getLoc(), op.await_timepoint(),
+ newTimepoints, rewriter);
+ op.await_timepointMutable().assign(newTimepoint);
+
+ for (auto replacement : replacements) {
+ op.operandsMutable()
+ .slice(replacement.first, 1)
+ .assign(replacement.second);
+ }
+ });
+ return success();
+ }
+};
+
+// If any operands are sourced from subviews clone those subviews into the
+// region and rewrite the operands to point at the original resource. This
+// allows us to progressively fold the subviews into the ops consuming them.
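+//
+// Example:
+//  %0 = stream.resource.subview %src[%offset] ...
+//  %1 = stream.async.execute with(%0 as %arg0)
+// ->
+//  %1 = stream.async.execute with(%src as %arg0) {
+//    %2 = stream.resource.subview %arg0[%offset] ...
+//  }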
+struct CloneCapturedAsyncExecuteSubviewOps
+ : public OpRewritePattern<AsyncExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ struct SubviewCapture {
+ unsigned operandIdx;
+ IREE::Stream::ResourceSubviewOp subviewOp;
+ };
+ SmallVector<SubviewCapture> captures;
+ for (auto operand : llvm::enumerate(op.operands())) {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(operand.value());
+ if (!subviewOp) continue;
+ captures.push_back(
+ SubviewCapture{static_cast<unsigned>(operand.index()), subviewOp});
+ }
+ if (captures.empty()) return failure();
+ rewriter.startRootUpdate(op);
+
+ auto &entryBlock = op.body().front();
+ rewriter.setInsertionPointToStart(&entryBlock);
+ for (auto &capture : captures) {
+ // Replace operand with the source subview resource.
+ op.operandsMutable()
+ .slice(capture.operandIdx, 1)
+ .assign(capture.subviewOp.source());
+ op.operand_sizesMutable()
+ .slice(capture.operandIdx, 1)
+ .assign(capture.subviewOp.source_size());
+
+ // Clone the subview into the region and wire it up to take the same
+ // range as the original.
+ auto arg = entryBlock.getArgument(capture.operandIdx);
+ auto newOp = rewriter.create<ResourceSubviewOp>(
+ capture.subviewOp.getLoc(), arg, capture.subviewOp.source_size(),
+ capture.subviewOp.source_offset(), capture.subviewOp.result_size());
+ arg.replaceAllUsesExcept(newOp.result(), newOp);
+ }
+
+ rewriter.finalizeRootUpdate(op);
+ return success();
+ }
+};
+
+// Elides stream.async.execute ops when they have no meaningful work.
+// The returned timepoint is replaced with an immediately resolved timepoint.
+//
+// Example:
+// %result, %timepoint = stream.async.execute with(%capture as %arg0) {
+// stream.yield %arg0
+// }
+// ->
+// %result = %capture
+// %timepoint = stream.timepoint.immediate
+struct ElideNoOpAsyncExecuteOp : public OpRewritePattern<AsyncExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ auto &entryBlock = op.body().front();
+ auto yieldOp = getYieldIfOnlyOp(entryBlock);
+ if (!yieldOp.hasValue()) {
+ // Has non-yield ops.
+ return failure();
+ }
+ SmallVector<Value> newResults;
+ for (auto operand : yieldOp->operands()) {
+ auto arg = operand.cast<BlockArgument>();
+ auto capture = op.operands()[arg.getArgNumber()];
+ assert(arg.getType() == capture.getType() &&
+ "expect 1:1 types on captures to results");
+ newResults.push_back(capture);
+ }
+ auto immediateTimepoint = rewriter.create<TimepointImmediateOp>(
+ op.getLoc(), op.result_timepoint().getType());
+ newResults.push_back(immediateTimepoint);
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+};
+
+} // namespace
+
+void AsyncExecuteOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ElideImmediateAsyncExecuteWaits>(context);
+ results.insert<ChainAsyncExecuteWaits>(context);
+ results.insert<CloneCapturedAsyncExecuteSubviewOps>(context);
+ results.insert<ElideNoOpAsyncExecuteOp>(context);
+ results.insert<IREE::Util::ClosureOptimizationPattern<AsyncExecuteOp>>(
+ context);
+ results.insert<TieRegionResults<AsyncExecuteOp>>(context);
+ results.insert<ElideUnusedOp<AsyncExecuteOp>>(context);
+ results.insert<MaterializeCOW<AsyncExecuteOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.concurrent
+//===----------------------------------------------------------------------===//
+
+void AsyncConcurrentOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<IREE::Util::ClosureOptimizationPattern<AsyncConcurrentOp>>(
+ context);
+ results.insert<TieRegionResults<AsyncConcurrentOp>>(context);
+ results.insert<ElideUnusedOp<AsyncConcurrentOp>>(context);
+ results.insert<MaterializeCOW<AsyncConcurrentOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.flush
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview ranges into flush ranges.
+//
+// Example:
+// %0 = stream.resource.subview %dst[%subview_offset] ... -> {%subview_length}
+// stream.cmd.flush %0[%offset for %length]
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+//  stream.cmd.flush %dst[%new_offset for %length]
+struct FoldSubviewsIntoCmdFlushOp : public OpRewritePattern<CmdFlushOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdFlushOp op,
+ PatternRewriter &rewriter) const override {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(op.target());
+ if (!subviewOp) return failure();
+ setInsertionPointToParentExecutionScope(op, rewriter);
+ auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subviewOp.source_offset(), op.target_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.targetMutable().assign(subviewOp.source());
+ op.target_sizeMutable().assign(subviewOp.source_size());
+ op.target_offsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdFlushOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldSubviewsIntoCmdFlushOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.invalidate
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview ranges into invalidate ranges.
+//
+// Example:
+// %0 = stream.resource.subview %dst[%subview_offset] ... -> {%subview_length}
+// stream.cmd.invalidate %0[%offset for %length]
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+//  stream.cmd.invalidate %dst[%new_offset for %length]
+struct FoldSubviewsIntoCmdInvalidateOp
+ : public OpRewritePattern<CmdInvalidateOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdInvalidateOp op,
+ PatternRewriter &rewriter) const override {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(op.target());
+ if (!subviewOp) return failure();
+ setInsertionPointToParentExecutionScope(op, rewriter);
+ auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subviewOp.source_offset(), op.target_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.targetMutable().assign(subviewOp.source());
+ op.target_sizeMutable().assign(subviewOp.source_size());
+ op.target_offsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdInvalidateOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<FoldSubviewsIntoCmdInvalidateOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.discard
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview ranges into discard ranges.
+//
+// Example:
+// %0 = stream.resource.subview %dst[%subview_offset] ... -> {%subview_length}
+// stream.cmd.discard %0[%offset for %length]
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+//  stream.cmd.discard %dst[%new_offset for %length]
+struct FoldSubviewsIntoCmdDiscardOp : public OpRewritePattern<CmdDiscardOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdDiscardOp op,
+ PatternRewriter &rewriter) const override {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(op.target());
+ if (!subviewOp) return failure();
+ setInsertionPointToParentExecutionScope(op, rewriter);
+ auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subviewOp.source_offset(), op.target_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.targetMutable().assign(subviewOp.source());
+ op.target_sizeMutable().assign(subviewOp.source_size());
+ op.target_offsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdDiscardOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<FoldSubviewsIntoCmdDiscardOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.fill
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview ranges into fill ranges.
+//
+// Example:
+// %0 = stream.resource.subview %dst[%subview_offset] ... -> {%subview_length}
+// stream.cmd.fill %cst, %0[%offset for %length]
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+//  stream.cmd.fill %cst, %dst[%new_offset for %length]
+struct FoldSubviewsIntoCmdFillOp : public OpRewritePattern<CmdFillOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdFillOp op,
+ PatternRewriter &rewriter) const override {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(op.target());
+ if (!subviewOp) return failure();
+ setInsertionPointToParentExecutionScope(op, rewriter);
+ auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subviewOp.source_offset(), op.target_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.targetMutable().assign(subviewOp.source());
+ op.target_sizeMutable().assign(subviewOp.source_size());
+ op.target_offsetMutable().assign(newOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdFillOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldSubviewsIntoCmdFillOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.copy
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview ranges into copy ranges.
+//
+// Example:
+// %0 = stream.resource.subview %src[%subview_offset] ... -> {%subview_length}
+// %1 = stream.resource.subview %dst[%subview_offset] ... -> {%subview_length}
+// stream.cmd.copy %0[%offset], %1[%offset], %length
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+//  stream.cmd.copy %src[%new_offset], %dst[%new_offset], %length
+struct FoldSubviewsIntoCmdCopyOp : public OpRewritePattern<CmdCopyOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdCopyOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceSubviewOp = ResourceSubviewOp::findSubviewOp(op.source());
+ auto targetSubviewOp = ResourceSubviewOp::findSubviewOp(op.target());
+ if (!sourceSubviewOp && !targetSubviewOp) return failure();
+ setInsertionPointToParentExecutionScope(op, rewriter);
+ if (sourceSubviewOp) {
+ auto fusedLoc =
+ rewriter.getFusedLoc({sourceSubviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, sourceSubviewOp.source_offset(), op.source_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.sourceMutable().assign(sourceSubviewOp.source());
+ op.source_sizeMutable().assign(sourceSubviewOp.source_size());
+ op.source_offsetMutable().assign(newOffset);
+ });
+ }
+ if (targetSubviewOp) {
+ auto fusedLoc =
+ rewriter.getFusedLoc({targetSubviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, targetSubviewOp.source_offset(), op.target_offset());
+ rewriter.updateRootInPlace(op, [&]() {
+ op.targetMutable().assign(targetSubviewOp.source());
+ op.target_sizeMutable().assign(targetSubviewOp.source_size());
+ op.target_offsetMutable().assign(newOffset);
+ });
+ }
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdCopyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldSubviewsIntoCmdCopyOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.dispatch
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Folds subview ranges into dispatch ranges.
+//
+// Example:
+// %0 = stream.resource.subview %src[%subview_offset] ... -> {%subview_length}
+// stream.cmd.dispatch ... {
+// rw %0[%offset] ... {%length}
+// }
+// ->
+// %new_offset = arith.addi %offset, %subview_offset
+// stream.cmd.dispatch ... {
+//    rw %src[%new_offset] ... {%subview_length}
+// }
+struct FoldSubviewsIntoCmdDispatchOp : public OpRewritePattern<CmdDispatchOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdDispatchOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<ResourceSubviewOp> resourceSubviewOps;
+ resourceSubviewOps.reserve(op.resources().size());
+ bool anySubviewOps = false;
+ for (auto operand : op.resources()) {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(operand);
+ if (subviewOp) anySubviewOps = true;
+ resourceSubviewOps.push_back(subviewOp);
+ }
+ if (!anySubviewOps) return failure();
+ rewriter.startRootUpdate(op);
+
+ setInsertionPointToParentExecutionScope(op, rewriter);
+ for (auto it : llvm::enumerate(resourceSubviewOps)) {
+ unsigned resourceIdx = static_cast<unsigned>(it.index());
+ auto subviewOp = it.value();
+ if (!subviewOp) continue;
+ auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()});
+ auto newOffset = rewriter.createOrFold<arith::AddIOp>(
+ fusedLoc, subviewOp.source_offset(),
+ op.resource_offsets()[resourceIdx]);
+ op.resourcesMutable().slice(resourceIdx, 1).assign(subviewOp.source());
+ op.resource_sizesMutable()
+ .slice(resourceIdx, 1)
+ .assign(subviewOp.source_size());
+ op.resource_offsetsMutable().slice(resourceIdx, 1).assign(newOffset);
+ }
+
+ rewriter.finalizeRootUpdate(op);
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdDispatchOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<FoldSubviewsIntoCmdDispatchOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.execute
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Elides waits that are known to be immediately resolved.
+//
+// Example:
+// %0 = stream.timepoint.immediate
+// %1 = stream.cmd.execute await(%0) => with(...)
+// ->
+// %1 = stream.cmd.execute with(...)
+struct ElideImmediateCmdExecuteWaits : public OpRewritePattern<CmdExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ bool isImmediate =
+ op.await_timepoint() && isa_and_nonnull<TimepointImmediateOp>(
+ op.await_timepoint().getDefiningOp());
+ if (!isImmediate) return failure();
+ rewriter.updateRootInPlace(op,
+ [&]() { op.await_timepointMutable().clear(); });
+ return success();
+ }
+};
+
+// Chains operand resources produced by an await to dependent execution regions.
+// This elides host waits and allows for device-side wait resolution.
+//
+// Example:
+// %0 = stream.cmd.execute with(%resource)
+// %1 = stream.timepoint.await %0 => %resource
+// %2 = stream.cmd.execute with(%resource)
+// ->
+// %0 = stream.cmd.execute with(%resource)
+// %2 = stream.cmd.execute await(%0) => with(%resource)
+struct ChainCmdExecuteWaits : public OpRewritePattern<CmdExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> newTimepoints;
+ SmallVector<std::pair<unsigned, Value>> replacements;
+ for (auto operand : llvm::enumerate(op.operands())) {
+ if (auto awaitOp = dyn_cast_or_null<TimepointAwaitOp>(
+ operand.value().getDefiningOp())) {
+ newTimepoints.push_back(awaitOp.timepoint());
+ replacements.push_back(std::make_pair(
+ operand.index(), awaitOp.getTiedResultOperand(operand.value())));
+ }
+ }
+ if (replacements.empty()) return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ auto newTimepoint = joinAwaitTimepoints(op.getLoc(), op.await_timepoint(),
+ newTimepoints, rewriter);
+ op.await_timepointMutable().assign(newTimepoint);
+ for (auto replacement : replacements) {
+ op.operandsMutable()
+ .slice(replacement.first, 1)
+ .assign(replacement.second);
+ }
+ });
+ return success();
+ }
+};
+
+// If any operands are sourced from subviews clone those subviews into the
+// region and rewrite the operands to point at the original resource. This
+// allows us to progressively fold the subviews into the ops consuming them.
+//
+// Example:
+// %0 = stream.resource.subview %src[%offset] ...
+// %1 = stream.cmd.execute with(%0 as %arg0)
+// ->
+// %1 = stream.cmd.execute with(%src as %arg0) {
+// %2 = stream.resource.subview %arg0[%offset] ...
+// }
+struct CloneCapturedCmdExecuteSubviewOps
+ : public OpRewritePattern<CmdExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ struct SubviewCapture {
+ unsigned operandIdx;
+ IREE::Stream::ResourceSubviewOp subviewOp;
+ };
+ SmallVector<SubviewCapture> captures;
+ for (auto operand : llvm::enumerate(op.operands())) {
+ auto subviewOp = ResourceSubviewOp::findSubviewOp(operand.value());
+ if (!subviewOp) continue;
+ captures.push_back(
+ SubviewCapture{static_cast<unsigned>(operand.index()), subviewOp});
+ }
+ if (captures.empty()) return failure();
+ rewriter.startRootUpdate(op);
+
+ auto &entryBlock = op.body().front();
+ rewriter.setInsertionPointToStart(&entryBlock);
+ for (auto &capture : captures) {
+ // Replace operand with the source subview resource.
+ op.operandsMutable()
+ .slice(capture.operandIdx, 1)
+ .assign(capture.subviewOp.source());
+ op.operand_sizesMutable()
+ .slice(capture.operandIdx, 1)
+ .assign(capture.subviewOp.source_size());
+
+ // Clone the subview into the region and wire it up to take the same
+ // range as the original.
+ auto arg = entryBlock.getArgument(capture.operandIdx);
+ auto newOp = rewriter.create<ResourceSubviewOp>(
+ capture.subviewOp.getLoc(), arg, capture.subviewOp.source_size(),
+ capture.subviewOp.source_offset(), capture.subviewOp.result_size());
+ arg.replaceAllUsesExcept(newOp.result(), newOp);
+ }
+
+ rewriter.finalizeRootUpdate(op);
+ return success();
+ }
+};
+
+// Elides stream.cmd.execute ops when they have no meaningful work.
+// The returned timepoint is replaced with an immediately resolved timepoint.
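+//
+// Example (illustrative):
+//  %0 = stream.cmd.execute with(...) {}
+// ->
+//  %0 = stream.timepoint.immediate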
+struct ElideNoOpCmdExecuteOp : public OpRewritePattern<CmdExecuteOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CmdExecuteOp op,
+ PatternRewriter &rewriter) const override {
+ auto &entryBlock = op.body().front();
+ auto yieldOp = getYieldIfOnlyOp(entryBlock);
+ if (!yieldOp.hasValue()) {
+ // Has non-yield ops.
+ return failure();
+ }
+ if (yieldOp->getNumOperands() != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "no ops in execute region but still passing through operands");
+ }
+ rewriter.replaceOpWithNewOp<TimepointImmediateOp>(
+ op, op.result_timepoint().getType());
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdExecuteOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ElideImmediateCmdExecuteWaits>(context);
+ results.insert<ChainCmdExecuteWaits>(context);
+ results.insert<CloneCapturedCmdExecuteSubviewOps>(context);
+ results.insert<ElideNoOpCmdExecuteOp>(context);
+ results.insert<IREE::Util::ClosureOptimizationPattern<CmdExecuteOp>>(context);
+ results.insert<ElideUnusedOp<CmdExecuteOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.serial
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Elides a region-carrying op when the region is empty (contains only its
+// yield). Only valid for ops that have no results needing replacement.
+template <typename OpT>
+struct ElideEmptyCmdRegionOp : public OpRewritePattern<OpT> {
+ using OpRewritePattern<OpT>::OpRewritePattern;
+ LogicalResult matchAndRewrite(OpT op,
+ PatternRewriter &rewriter) const override {
+ auto &entryBlock = op.body().front();
+ auto yieldOp = getYieldIfOnlyOp(entryBlock);
+ if (!yieldOp.hasValue()) {
+ // Has non-yield ops.
+ return failure();
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+} // namespace
+
+void CmdSerialOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ElideEmptyCmdRegionOp<CmdSerialOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.concurrent
+//===----------------------------------------------------------------------===//
+
+void CmdConcurrentOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ElideEmptyCmdRegionOp<CmdConcurrentOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.timepoint.immediate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TimepointImmediateOp::fold(ArrayRef<Attribute> operands) {
+ return IREE::Stream::TimepointAttr::get(getContext(), getResult().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// stream.timepoint.join
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TimepointJoinOp::fold(ArrayRef<Attribute> operands) {
+ if (llvm::all_of(operands, [](auto operand) { return operand != nullptr; })) {
+ // Immediate wait; fold into immediate.
+ return IREE::Stream::TimepointAttr::get(getContext(),
+ getResult().getType());
+ } else if (timepoints().size() == 1) {
+ // Join of a single timepoint => that timepoint.
+ return timepoints().front();
+ }
+ return {};
+}
+
+namespace {
+
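+// Removes join operands that are known-immediate timepoints; if all operands
+// are immediate the join itself is replaced with an immediate timepoint.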
+struct ElideImmediateTimepointJoinOperands
+ : public OpRewritePattern<TimepointJoinOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointJoinOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> newTimepoints;
+ newTimepoints.reserve(op.timepoints().size());
+ for (auto timepoint : op.timepoints()) {
+ if (!isa_and_nonnull<TimepointImmediateOp>(timepoint.getDefiningOp())) {
+ newTimepoints.push_back(timepoint);
+ }
+ }
+ if (newTimepoints.size() == op.timepoints().size()) return failure();
+ if (newTimepoints.empty()) {
+ // Fully immediate; replace entire join with immediate.
+ rewriter.replaceOpWithNewOp<TimepointImmediateOp>(op,
+ op.result().getType());
+ } else {
+ rewriter.updateRootInPlace(
+ op, [&]() { op.timepointsMutable().assign(newTimepoints); });
+ }
+ return success();
+ }
+};
+
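+// Deduplicates repeated timepoints in a join's operand list while preserving
+// their original order.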
+struct FoldDuplicateTimepointJoinOperands
+ : public OpRewritePattern<TimepointJoinOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointJoinOp op,
+ PatternRewriter &rewriter) const override {
+ SetVector<Value> newTimepoints;
+ newTimepoints.insert(op.timepoints().begin(), op.timepoints().end());
+ if (newTimepoints.size() == op.timepoints().size()) return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.timepointsMutable().assign(newTimepoints.takeVector());
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void TimepointJoinOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): elide operands if timepoint must be satisfied in use-def.
+ // TODO(benvanik): sink and pull in other timepoints (join on all needed).
+ results.insert<ElideImmediateTimepointJoinOperands>(context);
+ results.insert<FoldDuplicateTimepointJoinOperands>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.timepoint.await
+//===----------------------------------------------------------------------===//
+
+LogicalResult TimepointAwaitOp::fold(ArrayRef<Attribute> foldOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ if (foldOperands[0]) {
+ // Immediate wait; fold to all captured operands.
+ results.append(operands().begin(), operands().end());
+ return success();
+ }
+ return failure();
+}
+
+namespace {
+
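+// Elides awaits on timepoints that are known to be immediately resolved by
+// forwarding the awaited resources directly to their uses.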
+struct ElideImmediateAwaits : public OpRewritePattern<TimepointAwaitOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointAwaitOp op,
+ PatternRewriter &rewriter) const override {
+ if (isa_and_nonnull<TimepointImmediateOp>(op.timepoint().getDefiningOp())) {
+ rewriter.replaceOp(op, op.operands());
+ return success();
+ }
+ return failure();
+ }
+};
+
+// Sinks an await down to the first consumer of any resource. Note that there
+// may be multiple resources guarded by the await.
+struct SinkAwaitToFirstConsumer : public OpRewritePattern<TimepointAwaitOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointAwaitOp op,
+ PatternRewriter &rewriter) const override {
+ // TODO(benvanik): amortize this dominance calculation.
+ DominanceInfo domInfo(op->getParentOp());
+
+ // Gather all direct users of the awaited resources and find the common
+ // dominator block across all uses. This may be the entry block itself.
+ SetVector<Operation *> allUsers;
+ Block *commonDominator = nullptr;
+ for (auto result : op.results()) {
+ for (auto &use : result.getUses()) {
+ if (allUsers.insert(use.getOwner())) {
+ auto *userBlock = use.getOwner()->getBlock();
+ commonDominator = commonDominator
+ ? domInfo.findNearestCommonDominator(
+ commonDominator, userBlock)
+ : userBlock;
+ }
+ }
+ }
+ if (!commonDominator) return failure();
+
+ // Find the first use within the dominator block (if any) so that we
+ // can sink down to it.
+ Operation *firstUserInDominator = commonDominator->getTerminator();
+ for (auto *user : allUsers) {
+ if (user->getBlock() == commonDominator) {
+ if (user->isBeforeInBlock(firstUserInDominator)) {
+ firstUserInDominator = user;
+ }
+ }
+ }
+
+ // Find the earliest point before |firstUserInDominator| that is safe to
+ // insert into. If it ends up being where we already are then this is a no-op.
+ auto ip = findInsertionPointBefore(op, firstUserInDominator);
+ if (ip == Block::iterator(op)) return failure();
+
+ rewriter.updateRootInPlace(op,
+ [&]() { op->moveBefore(ip->getBlock(), ip); });
+ return success();
+ }
+};
+
+// Moves stream.resource.subview ops across to results of an await.
+// This allows us to pass-through the subviews to consumers that can hopefully
+// fold the range.
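+//
+// Example (illustrative; the exact printed form may differ):
+//  %sub = stream.resource.subview %src[%off] : ...{%src_sz} -> ...{%len}
+//  %1 = stream.timepoint.await %tp => %sub : !stream.resource<external>{%len}
+// ->
+//  %1 = stream.timepoint.await %tp => %src : !stream.resource<external>{%src_sz}
+//  %2 = stream.resource.subview %1[%off] : ...{%src_sz} -> ...{%len}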
+struct SinkSubviewsAcrossAwaits : public OpRewritePattern<TimepointAwaitOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointAwaitOp op,
+ PatternRewriter &rewriter) const override {
+ rewriter.startRootUpdate(op);
+ bool didChange = false;
+ for (auto operand : llvm::enumerate(op.operands())) {
+ auto subviewOp = dyn_cast_or_null<IREE::Stream::ResourceSubviewOp>(
+ operand.value().getDefiningOp());
+ if (!subviewOp) continue;
+ didChange = true;
+ unsigned operandIdx = static_cast<unsigned>(operand.index());
+
+ // Create a new subview op matching the original on our result and swap
+ // users to it.
+ auto result = op.results()[operandIdx];
+ auto newOp = rewriter.create<IREE::Stream::ResourceSubviewOp>(
+ subviewOp.getLoc(), result, subviewOp.source_size(),
+ subviewOp.source_offset(), subviewOp.result_size());
+ result.replaceAllUsesExcept(newOp.result(), newOp);
+
+ // Update our bound size to the subview source size (not the subrange).
+ op.operand_sizesMutable()
+ .slice(operandIdx, 1)
+ .assign(subviewOp.source_size());
+
+ // Replace our resource usage with the source of the subview op.
+ op.operandsMutable().slice(operandIdx, 1).assign(subviewOp.source());
+ }
+ if (didChange) {
+ rewriter.finalizeRootUpdate(op);
+ return success();
+ } else {
+ rewriter.cancelRootUpdate(op);
+ return failure();
+ }
+ }
+};
+
+// Finds timepoint awaits on the same timepoint within the same domination
+// paths and groups them together.
+//
+// Example:
+// %6 = stream.timepoint.await %tp => %3 : !stream.resource<external>{%c4000}
+// %7 = stream.tensor.export %6 ...
+// %8 = stream.timepoint.await %tp => %4 : !stream.resource<external>{%c4000}
+// %9 = stream.tensor.export %8 ...
+// ->
+// %6:2 = stream.timepoint.await %tp => %3, %4 :
+// !stream.resource<external>{%c4000}, !stream.resource<external>{%c4000}
+// %7 = stream.tensor.export %6#0 ...
+// %9 = stream.tensor.export %6#1 ...
+struct GroupAwaitsByTimepoint : public OpRewritePattern<TimepointAwaitOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointAwaitOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<TimepointAwaitOp> coveredOps;
+ for (auto &use : op.timepoint().getUses()) {
+ // TODO(benvanik): make this handle joins/ties; today we get blocked
+ // there. We rely on other canonicalizers to sink things such that
+ // (hopefully) we get them directly accessible here.
+ if (use.getOwner() == op) continue;
+ if (use.getOwner()->getBlock() != op->getBlock() ||
+ use.getOwner()->isBeforeInBlock(op)) {
+ // TODO(benvanik): allow dominated blocks.
+ continue;
+ }
+ auto awaitOp = dyn_cast<TimepointAwaitOp>(use.getOwner());
+ if (!awaitOp ||
+ !AffinityAttr::areCompatible(
+ op.affinityAttr().dyn_cast_or_null<AffinityAttr>(),
+ awaitOp.affinityAttr().dyn_cast_or_null<AffinityAttr>())) {
+ // Can't combine if the affinities differ as the wait semantics are
+ // load-bearing. Probably. They really shouldn't be.
+ // TODO(benvanik): remove affinity from stream.timepoint.await.
+ continue;
+ }
+ coveredOps.push_back(awaitOp);
+ }
+ if (coveredOps.empty()) return failure();
+ coveredOps.push_back(op);
+
+ // Sort the ops by their definition order; this gives us a deterministic
+ // operand ordering regardless of the order the patterns are run.
+ llvm::sort(coveredOps, [&](TimepointAwaitOp lhs, TimepointAwaitOp rhs) {
+ return lhs->isBeforeInBlock(rhs);
+ });
+
+ // Combine all awaits into a single one.
+ SmallVector<Value> newOperands;
+ SmallVector<Value> newOperandSizes;
+ for (auto coveredOp : coveredOps) {
+ llvm::append_range(newOperands, coveredOp.operands());
+ llvm::append_range(newOperandSizes, coveredOp.operand_sizes());
+ }
+ auto newOp = rewriter.create<TimepointAwaitOp>(
+ op.getLoc(), newOperands, newOperandSizes, op.timepoint());
+ if (op.affinity().hasValue()) {
+ newOp.affinityAttr(op.affinityAttr());
+ }
+
+ // Replace covered ops with the new results.
+ unsigned resultIdx = 0;
+ for (auto coveredOp : coveredOps) {
+ for (auto result : coveredOp.results()) {
+ result.replaceAllUsesWith(newOp.results()[resultIdx++]);
+ }
+ rewriter.eraseOp(coveredOp);
+ }
+ return success();
+ }
+};
+
+// Folds duplicate resources passing through an await op.
+//
+// Example:
+// %r:4 = stream.timepoint.await %tp => %1, %1, %2, %2
+// ->
+// %r:2 = stream.timepoint.await %tp => %1, %2
+struct FoldDuplicateAwaitResources : public OpRewritePattern<TimepointAwaitOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(TimepointAwaitOp op,
+ PatternRewriter &rewriter) const override {
+ DenseMap<Value, unsigned> baseMap;
+ SmallVector<std::pair<Value, unsigned>> replacements;
+ SmallVector<Value> newOperands;
+ SmallVector<Value> newOperandSizes;
+ for (auto it : llvm::zip(op.operands(), op.operand_sizes(), op.results())) {
+ auto operand = std::get<0>(it);
+ auto operandSize = std::get<1>(it);
+ auto result = std::get<2>(it);
+ auto insertion =
+ baseMap.insert(std::make_pair(operand, newOperands.size()));
+ if (insertion.second) {
+ // Inserted as a new unique operand.
+ newOperands.push_back(operand);
+ newOperandSizes.push_back(operandSize);
+ }
+ unsigned resultIdx = insertion.first->second;
+ replacements.push_back(std::make_pair(result, resultIdx));
+ }
+ if (newOperands.size() == op.operands().size()) {
+ return failure(); // No change.
+ }
+
+ // Create replacement op with deduped operands/results.
+ auto newOp = rewriter.create<IREE::Stream::TimepointAwaitOp>(
+ op.getLoc(), newOperands, newOperandSizes, op.timepoint());
+ if (op.affinity().hasValue()) {
+ newOp.affinityAttr(op.affinityAttr());
+ }
+
+ // Replace all duplicate results with the base results.
+ for (auto &replacement : replacements) {
+ auto oldResult = replacement.first;
+ auto newResult = newOp.results()[replacement.second];
+ oldResult.replaceAllUsesWith(newResult);
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+} // namespace
+
+void TimepointAwaitOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): elide waits if timepoint must be satisfied in use-def.
+ results.insert<ElideImmediateAwaits>(context);
+ results.insert<SinkAwaitToFirstConsumer>(context);
+ results.insert<SinkSubviewsAcrossAwaits>(context);
+ results.insert<GroupAwaitsByTimepoint>(context);
+ results.insert<FoldDuplicateAwaitResources>(context);
+ results.insert<ElideUnusedOp<TimepointAwaitOp>>(context);
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
new file mode 100644
index 0000000..cd39e77
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -0,0 +1,1798 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+
+#include "iree/compiler/Dialect/Shape/IR/Builders.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+//===----------------------------------------------------------------------===//
+// Op utilities used within the stream dialect
+//===----------------------------------------------------------------------===//
+
+// Verifies that |dynamicDims| contains the appropriate number of dims for all
+// of the dynamic dimensions in |values|.
+static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
+ ValueRange dynamicDims) {
+ unsigned requiredCount = 0;
+ for (auto value : values) {
+ if (auto shapedType = value.getType().dyn_cast<ShapedType>()) {
+ requiredCount += shapedType.getNumDynamicDims();
+ }
+ }
+ if (dynamicDims.size() != requiredCount) {
+ return op->emitOpError()
+ << "value set has " << requiredCount
+ << " dynamic dimensions but only " << dynamicDims.size()
+ << " dimension values are attached";
+ }
+ return success();
+}
+
+// Verifies that |dynamicDims| contains the appropriate number of dims for all
+// of the dynamic dimensions in |types|.
+static LogicalResult verifyOpDynamicDims(Operation *op, TypeRange types,
+ ValueRange dynamicDims) {
+ unsigned requiredCount = 0;
+ for (auto type : types) {
+ if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ requiredCount += shapedType.getNumDynamicDims();
+ }
+ }
+ if (dynamicDims.size() != requiredCount) {
+ return op->emitOpError()
+ << "type set has " << requiredCount
+ << " dynamic dimensions but only " << dynamicDims.size()
+ << " dimension values are attached";
+ }
+ return success();
+}
+
+// Verifies that |sizes| contains the appropriate number of sizes for all of the
+// sized types in |values|.
+static LogicalResult verifyOpValueSizes(Operation *op, ValueRange values,
+ ValueRange sizes) {
+ unsigned requiredCount = 0;
+ for (auto value : values) {
+ if (value.getType().isa<IREE::Util::SizeAwareTypeInterface>()) {
+ ++requiredCount;
+ }
+ }
+ if (sizes.size() != requiredCount) {
+ return op->emitOpError() << "value set has " << requiredCount
+ << " size-aware values but only " << sizes.size()
+ << " size values are attached";
+ }
+ return success();
+}
+
+// Verifies that all !stream.resources used within |region| are captured by
+// the entry arguments to the region.
+static LogicalResult verifyAllResourcesCaptured(Region &region) {
+ SetVector<Value> availableResources;
+ for (auto arg : region.front().getArguments()) {
+ availableResources.insert(arg);
+ }
+ for (auto &op : region.front()) {
+ for (auto result : op.getResults()) {
+ availableResources.insert(result);
+ }
+ for (auto operand : op.getOperands()) {
+ if (!operand.getType().isa<IREE::Stream::ResourceType>()) continue;
+ if (!availableResources.contains(operand)) {
+ return op.emitOpError() << "used resource not listed in explicit "
+ "captures (or produced internally)";
+ }
+ }
+ }
+ return success();
+}
+
+// Verifies that the sizes of escaping !stream.resources, as yielded from
+// |region|, match the sizes declared on the parent op. This information is
+// redundant but keeps analysis local and agnostic to the parent op structure,
+// which is useful when we outline things.
+static LogicalResult verifyEscapingResources(Region &region,
+ ResultRange results,
+ ValueRange resultSizes) {
+ // Ensure yielded resources match the signature.
+ for (auto yieldOp : region.getOps<IREE::Stream::YieldOp>()) {
+ if (results.size() != yieldOp.operands().size()) {
+ return yieldOp.emitOpError()
+ << "yield result count mismatch with parent op";
+ }
+ for (auto it : llvm::zip(results, yieldOp.operands())) {
+ auto outerValue = std::get<0>(it);
+ auto innerValue = std::get<1>(it);
+ if (outerValue.getType() != innerValue.getType()) {
+ return yieldOp.emitOpError()
+ << "result type mismatch: expected " << outerValue.getType()
+ << " but got " << innerValue.getType();
+ }
+ }
+ for (auto it : llvm::zip(resultSizes, yieldOp.operand_sizes())) {
+ auto outerSize = std::get<0>(it);
+ auto innerSize = std::get<1>(it);
+ if (outerSize != innerSize) {
+ return yieldOp.emitOpError() << "result size mismatch";
+ }
+ }
+ }
+ return success();
+}
+
+// Computes the value access bits starting from |rootValue|.
+// Traverses the IR graph along tied ops but does not handle branches.
+static IREE::Util::ValueAccess computeValueAccess(Value rootValue) {
+ IREE::Util::ValueAccess access;
+ DenseSet<Value> processedValues;
+ SmallVector<Value> worklist;
+ auto enqueueValue = [&](Value value) {
+ if (processedValues.contains(value)) return;
+ processedValues.insert(value);
+ worklist.push_back(value);
+ };
+ enqueueValue(rootValue);
+ while (!worklist.empty()) {
+ Value value = worklist.back();
+ worklist.pop_back();
+
+ // Walk up the definition chain.
+ if (auto definingOp = value.getDefiningOp()) {
+ // Value is produced within the region and thus written.
+ access.isWrite = true;
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
+ access.isRead = true;
+ auto operand = tiedOp.getTiedResultOperand(value);
+ if (operand) {
+ // Value is tied back to another value; continue analyzing past it.
+ enqueueValue(operand);
+ } else {
+ // Value contents are fully produced by this op.
+ access.isDiscard = true;
+ }
+ } else if (isa<IREE::Stream::SubviewEffectOpInterface>(definingOp)) {
+ // TODO(benvanik): actually query; for now assume *.
+ access.isRead = true;
+ access.isWrite = true;
+ } else {
+ // Value contents are fully produced by this op.
+ access.isDiscard = true;
+ }
+ }
+
+ // Walk down the use chain.
+ for (auto user : value.getUsers()) {
+ // Used by an op.
+ access.isRead = true;
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(user)) {
+ auto tiedIndices = tiedOp.getTiedResultOperandIndices();
+ for (int64_t tiedIndex : tiedIndices) {
+ if (tiedIndex == IREE::Util::TiedOpInterface::kUntiedIndex) continue;
+ auto operand = user->getOperand(tiedIndex);
+ if (operand == value) {
+ // Tied operand.
+ access.isRead = true;
+ access.isWrite = true;
+ enqueueValue(operand);
+ }
+ }
+ } else if (isa<IREE::Stream::SubviewEffectOpInterface>(user)) {
+ // TODO(benvanik): actually query; for now assume *.
+ access.isRead = true;
+ access.isWrite = true;
+ }
+ }
+ }
+ return access;
+}
+
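+// Erases the |excludedResultIndices| operands from all stream.yield ops in
+// |region|, dropping both the yielded values and their associated sizes.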
+static void eraseStreamRegionResults(Region &region,
+ ArrayRef<unsigned> excludedResultIndices) {
+ for (auto &block : region.getBlocks()) {
+ auto yieldOp = dyn_cast<IREE::Stream::YieldOp>(block.getTerminator());
+ if (!yieldOp) continue;
+ for (auto i : llvm::reverse(excludedResultIndices)) {
+ yieldOp.operandsMutable().erase(i);
+ yieldOp.operand_sizesMutable().erase(i);
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ResourceRegion>($operands, type($operands), $operand_sizes,
+// type($results), $result_sizes,
+// $tied_operands, $body)
+//===----------------------------------------------------------------------===//
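+// Example (illustrative; the exact printed form may differ):
+//  (%capture as %arg0: !stream.resource<transient>{%size}) ->
+//      !stream.resource<transient>{%size} { ... }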
+
+static ParseResult parseResourceRegion(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
+ SmallVectorImpl<Type> &operandTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &operandSizes,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultSizes,
+ ArrayAttr &tiedOperands, Region &body) {
+ SmallVector<OpAsmParser::OperandType, 16> regionArgs;
+ if (failed(parser.parseLParen())) {
+ return failure();
+ }
+ if (failed(parser.parseOptionalRParen())) {
+ do {
+ // Reserve entries in the lists.
+ operands.emplace_back();
+ operandTypes.emplace_back();
+ operandSizes.emplace_back();
+ regionArgs.emplace_back();
+ if (failed(parser.parseOperand(operands.back())) ||
+ failed(parser.parseKeyword("as")) ||
+ failed(parser.parseRegionArgument(regionArgs.back())) ||
+ failed(parser.parseColon()) ||
+ failed(parseSizeAwareType(parser, operandTypes.back(),
+ operandSizes.back()))) {
+ return failure();
+ }
+ } while (succeeded(parser.parseOptionalComma()));
+ if (failed(parser.parseRParen())) {
+ return failure();
+ }
+ }
+
+ if (failed(parser.parseArrow())) return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (failed(parseShapedResultList(parser, operands, operandTypes,
+ operandSizes, resultTypes, resultSizes,
+ tiedOperands)) ||
+ failed(parser.parseRParen())) {
+ return failure();
+ }
+ } else {
+ if (failed(parseShapedResultList(parser, operands, operandTypes,
+ operandSizes, resultTypes, resultSizes,
+ tiedOperands))) {
+ return failure();
+ }
+ }
+ return parser.parseRegion(body, regionArgs, operandTypes,
+ /*enableNameShadowing=*/false);
+}
+
+static void printResourceRegion(OpAsmPrinter &p, Operation *op,
+ ValueRange operands, TypeRange operandTypes,
+ ValueRange operandSizes, TypeRange resultTypes,
+ ValueRange resultSizes, ArrayAttr tiedOperands,
+ Region &body) {
+ p << "(";
+ llvm::interleaveComma(
+ llvm::zip(operands, body.getArguments()), p, [&](auto it) {
+ auto operand = std::get<0>(it);
+ auto arg = std::get<1>(it);
+ p << operand;
+ p << " as ";
+ p << arg;
+ p << ": ";
+ p << arg.getType();
+ if (arg.getType().template isa<IREE::Util::SizeAwareTypeInterface>()) {
+ p << "{" << operandSizes.front() << "}";
+ operandSizes = operandSizes.drop_front(1);
+ }
+ });
+ p << ") -> ";
+ if (resultTypes.size() != 1) p << "(";
+ printShapedResultList(p, op, operands, operandTypes, operandSizes,
+ resultTypes, resultSizes, tiedOperands);
+ if (resultTypes.size() != 1) p << ")";
+ p.printRegion(body, /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ExplicitResourceRegion>($operands, type($operands), $operand_sizes,
+// $body)
+//===----------------------------------------------------------------------===//
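+// Example (illustrative; the exact printed form may differ):
+//  (%capture as %arg0: !stream.resource<transient>{%size}) { ... }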
+
+static ParseResult parseExplicitResourceRegion(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
+ SmallVectorImpl<Type> &operandTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &operandSizes, Region &body) {
+ SmallVector<OpAsmParser::OperandType, 16> regionArgs;
+ if (failed(parser.parseLParen())) {
+ return failure();
+ }
+ if (failed(parser.parseOptionalRParen())) {
+ do {
+ // Reserve entries in the lists.
+ operands.emplace_back();
+ operandTypes.emplace_back();
+ operandSizes.emplace_back();
+ regionArgs.emplace_back();
+ if (failed(parser.parseOperand(operands.back())) ||
+ failed(parser.parseKeyword("as")) ||
+ failed(parser.parseRegionArgument(regionArgs.back())) ||
+ failed(parser.parseColon()) ||
+ failed(parseSizeAwareType(parser, operandTypes.back(),
+ operandSizes.back()))) {
+ return failure();
+ }
+ } while (succeeded(parser.parseOptionalComma()));
+ if (failed(parser.parseRParen())) {
+ return failure();
+ }
+ }
+ if (failed(parser.parseRegion(body, regionArgs, operandTypes,
+ /*enableNameShadowing=*/false))) {
+ return failure();
+ }
+ // HACK: I can't figure out how to make this work with the default parsing -
+ // it doesn't call this like it should.
+ IREE::Stream::CmdExecuteOp::ensureTerminator(
+ body, parser.getBuilder(),
+ parser.getEncodedSourceLoc(parser.getCurrentLocation()));
+ return success();
+}
+
+static void printExplicitResourceRegion(OpAsmPrinter &p, Operation *op,
+ ValueRange operands,
+ TypeRange operandTypes,
+ ValueRange operandSizes, Region &body) {
+ p << "(";
+ llvm::interleaveComma(
+ llvm::zip(operands, body.getArguments()), p, [&](auto it) {
+ auto operand = std::get<0>(it);
+ auto arg = std::get<1>(it);
+ p << operand;
+ p << " as ";
+ p << arg;
+ p << ": ";
+ p << arg.getType();
+ if (arg.getType().template isa<IREE::Util::SizeAwareTypeInterface>()) {
+ p << "{" << operandSizes.front() << "}";
+ operandSizes = operandSizes.drop_front(1);
+ }
+ });
+ p << ")";
+ p.printRegion(body, /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// custom<PackSliceRanges>($lifetime_intervals,
+// $dynamic_slice_sizes,
+// type($packed_offsets))
+//===----------------------------------------------------------------------===//
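+// Example (illustrative; the exact printed form may differ):
+//  [0, 4] = %slice0_size,
+//  [1, 2] = %slice1_size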
+
+static ParseResult parsePackSliceRanges(
+ OpAsmParser &parser, ArrayAttr &lifetimeIntervals,
+ SmallVectorImpl<OpAsmParser::OperandType> &dynamicSliceSizes,
+ SmallVectorImpl<Type> &packedOffsetTypes) {
+ auto indexType = parser.getBuilder().getIndexType();
+ SmallVector<Attribute> lifetimeRangeValues;
+ do {
+ if (failed(parser.parseOptionalLSquare())) break;
+ IntegerAttr lifetimeStart;
+ IntegerAttr lifetimeEnd;
+ OpAsmParser::OperandType dynamicSliceSize;
+ if (failed(parser.parseAttribute(lifetimeStart, indexType)) ||
+ failed(parser.parseComma()) ||
+ failed(parser.parseAttribute(lifetimeEnd, indexType)) ||
+ failed(parser.parseRSquare()) || failed(parser.parseEqual()) ||
+ failed(parser.parseOperand(dynamicSliceSize))) {
+ return failure();
+ }
+ lifetimeRangeValues.push_back(lifetimeStart);
+ lifetimeRangeValues.push_back(lifetimeEnd);
+ dynamicSliceSizes.push_back(dynamicSliceSize);
+ packedOffsetTypes.push_back(indexType);
+ } while (succeeded(parser.parseOptionalComma()));
+ lifetimeIntervals = parser.getBuilder().getArrayAttr(lifetimeRangeValues);
+ return success();
+}
+
+static void printPackSliceRanges(OpAsmPrinter &p, Operation *op,
+ ArrayAttr lifetimeIntervals,
+ ValueRange dynamicSliceSizes,
+ TypeRange packedOffsetTypes) {
+ if (packedOffsetTypes.empty()) return;
+ for (unsigned i = 0; i < packedOffsetTypes.size(); ++i) {
+ auto lifetimeStart = lifetimeIntervals[i * 2];
+ auto lifetimeEnd = lifetimeIntervals[i * 2 + 1];
+ auto sliceSize = dynamicSliceSizes[i];
+ p.printNewline();
+ p << " [";
+ p.printAttributeWithoutType(lifetimeStart);
+ p << ", ";
+ p.printAttributeWithoutType(lifetimeEnd);
+ p << "] = ";
+ p.printOperand(sliceSize);
+ if (i < packedOffsetTypes.size() - 1) p << ",";
+ }
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ConstantValueList>(type($results),
+// $result_sizes,
+// $values)
+//===----------------------------------------------------------------------===//
+// !stream.resource<constant>{%sz} = #value,
+// !stream.resource<constant>{%sz} = #value
+
+static ParseResult parseConstantValueList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultSizes, ArrayAttr &values) {
+ SmallVector<Attribute> valueAttrs;
+ do {
+ Type resultType;
+ OpAsmParser::OperandType resultSize;
+ Attribute valueAttr;
+ if (failed(parseSizeAwareType(parser, resultType, resultSize)) ||
+ failed(parser.parseEqual()) ||
+ failed(parser.parseAttribute(valueAttr))) {
+ return failure();
+ }
+ resultTypes.push_back(resultType);
+ resultSizes.push_back(resultSize);
+ valueAttrs.push_back(valueAttr);
+ } while (succeeded(parser.parseOptionalComma()));
+ values = parser.getBuilder().getArrayAttr(valueAttrs);
+ return success();
+}
+
+static void printConstantValueList(OpAsmPrinter &p, Operation *op,
+ TypeRange resultTypes,
+ ValueRange resultSizes, ArrayAttr values) {
+ if (resultTypes.empty()) return;
+ for (unsigned i = 0; i < resultTypes.size(); ++i) {
+ p.printNewline();
+ p << " ";
+ printSizeAwareType(p, op, resultTypes[i], resultSizes[i]);
+ p << " = ";
+ p.printAttribute(values[i]);
+ if (i < resultTypes.size() - 1) p << ",";
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// custom<SymbolAlias>($sym_name, $alias)
+//===----------------------------------------------------------------------===//
+// @foo sym_name: @foo, alias: @foo
+// @foo as("bar") sym_name: @bar, alias: @foo
+
+static ParseResult parseSymbolAlias(OpAsmParser &parser, StringAttr &sym_name,
+ FlatSymbolRefAttr &alias) {
+ if (failed(parser.parseAttribute(alias))) {
+ return failure();
+ }
+ if (succeeded(parser.parseOptionalKeyword("as"))) {
+ if (failed(parser.parseLParen()) ||
+ failed(parser.parseAttribute(sym_name)) ||
+ failed(parser.parseRParen())) {
+ return failure();
+ }
+ } else {
+ sym_name = StringAttr::get(parser.getContext(), alias.getValue());
+ }
+ return success();
+}
+
+static void printSymbolAlias(OpAsmPrinter &p, Operation *op,
+ StringAttr sym_name, FlatSymbolRefAttr alias) {
+ p.printAttributeWithoutType(alias);
+ if (sym_name.getValue() != alias.getValue()) {
+ p << " as(\"";
+ p.printSymbolName(sym_name.getValue());
+ p << "\")";
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.alloc
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(ResourceAllocOp op) {
+ if (failed(verifyOpValueSizes(op, op.results(), op.storage_sizes()))) {
+ return failure();
+ }
+
+ // All allocated resources must have the same lifetime.
+ auto anyType = op.results().front().getType();
+ for (auto type : op.getResultTypes()) {
+ if (type != anyType) {
+ return op.emitError()
+ << "all allocated resources must have the same lifetime";
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.map
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(ResourceMapOp op) {
+ if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.try_map
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(ResourceTryMapOp op) {
+ if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.load
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(ResourceLoadOp op) {
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.store
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(ResourceStoreOp op) {
+ if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.pack
+//===----------------------------------------------------------------------===//
+
+void ResourcePackOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ // TODO(benvanik): figure out if we can get the names to coalesce when there
+ // are multiple results. Ideally we'd have `%total_length, %offsets:123` but
+ // unfortunately all get splatted out and create 10k+ char lines that are a
+ // pain to read.
+ // setNameFn(total_length(), "total_length");
+ // for (auto packedOffset : llvm::enumerate(packed_offsets())) {
+ // setNameFn(packedOffset.value(),
+ // "offset" + std::to_string(packedOffset.index()));
+ // }
+}
+
+static LogicalResult verifyOp(ResourcePackOp op) {
+ size_t sliceCount = op.packed_offsets().size();
+ if (op.lifetime_intervals().size() != sliceCount * 2) {
+ return op.emitOpError() << "requires a [start, end] range for each slice";
+ }
+ if (op.dynamic_slice_sizes().size() != sliceCount) {
+ return op.emitOpError() << "requires a size for each slice";
+ }
+ return success();
+}
+
+SmallVector<ResourcePackOp::Slice> ResourcePackOp::getSlices() {
+ auto intervalPairs = lifetime_intervals().getValue();
+ auto sizes = dynamic_slice_sizes();
+ auto offsets = packed_offsets();
+ SmallVector<ResourcePackOp::Slice> slices(offsets.size());
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ int64_t start = intervalPairs[i * 2 + 0].cast<IntegerAttr>().getInt();
+ int64_t end = intervalPairs[i * 2 + 1].cast<IntegerAttr>().getInt();
+ slices[i] = {start, end, sizes[i], offsets[i]};
+ }
+ return slices;
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.constants
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(ResourceConstantsOp op) {
+ size_t count = op.results().size();
+ if (op.result_sizes().size() != count || op.values().size() != count) {
+ return op.emitOpError() << "mismatched constant/result counts";
+ }
+
+ // All resources must have the same lifetime.
+ auto anyType = op.results().front().getType();
+ for (auto result : op.results()) {
+ if (result.getType() != anyType) {
+ return op.emitError()
+ << "all constant resources must have the same lifetime";
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.resource.subview
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(ResourceSubviewOp op) {
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+bool ResourceSubviewOp::isMetadata() { return true; }
+
+Value ResourceSubviewOp::getViewSource() { return source(); }
+
+Value ResourceSubviewOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
+}
+
+::llvm::Optional<unsigned> ResourceSubviewOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // source
+}
+
+SmallVector<int64_t, 4> ResourceSubviewOp::getTiedResultOperandIndices() {
+ return {0}; // source
+}
+
+// static
+IREE::Stream::ResourceSubviewOp ResourceSubviewOp::findSubviewOp(Value value) {
+ while (value) {
+ auto *definingOp = value.getDefiningOp();
+ if (!definingOp) {
+ // Defined as a block argument - stop walk.
+ break;
+ } else if (auto subviewOp =
+ dyn_cast<IREE::Stream::ResourceSubviewOp>(definingOp)) {
+ // Found!
+ return subviewOp;
+ } else if (auto tiedOp =
+ dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
+ // Continue walking up through the tied operand.
+ value = tiedOp.getTiedResultOperand(value);
+ } else {
+ break;
+ }
+ }
+ return {};
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.import
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorImportOp op) {
+ if (failed(verifyOpDynamicDims(op, op.result_encoding(),
+ op.result_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+Value TensorImportOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
+}
+
+::llvm::Optional<unsigned> TensorImportOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // source
+}
+
+SmallVector<int64_t, 4> TensorImportOp::getTiedResultOperandIndices() {
+ return {0}; // source
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.export
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorExportOp op) {
+ if (failed(verifyOpDynamicDims(op, op.source_encoding(),
+ op.source_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+Value TensorExportOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
+}
+
+::llvm::Optional<unsigned> TensorExportOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // source
+}
+
+SmallVector<int64_t, 4> TensorExportOp::getTiedResultOperandIndices() {
+ return {0}; // source
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.sizeof
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorSizeOfOp op) {
+ if (failed(verifyOpDynamicDims(op, op.encoding(), op.encoding_dims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.constant
+//===----------------------------------------------------------------------===//
+
+void TensorConstantOp::getAsmResultNames(mlir::OpAsmSetValueNameFn setNameFn) {
+ setNameFn(result(), "cst");
+}
+
+static LogicalResult verifyOp(TensorConstantOp op) {
+ if (failed(verifyOpDynamicDims(op, op.result_encoding(),
+ op.result_encoding_dims()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.splat
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorSplatOp op) {
+ if (failed(verifyOpDynamicDims(op, op.result_encoding(),
+ op.result_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.clone
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorCloneOp op) {
+ // Clones can't change encodings but they can change shape information.
+ auto sourceEncoding = op.source_encoding().cast<RankedTensorType>();
+ auto resultEncoding = op.result_encoding().cast<RankedTensorType>();
+ if (sourceEncoding.getEncoding() != resultEncoding.getEncoding()) {
+ return op.emitOpError() << "clones changing tensor encoding from "
+ << sourceEncoding.getEncoding() << " to "
+ << resultEncoding.getEncoding() << "; not allowed";
+ }
+ if (failed(verifyOpDynamicDims(op, op.source_encoding(),
+ op.source_encoding_dims())) ||
+ failed(verifyOpDynamicDims(op, op.result_encoding(),
+ op.result_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.slice
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorSliceOp op) {
+ if (failed(verifyOpDynamicDims(op, op.source_encoding(),
+ op.source_encoding_dims())) ||
+ failed(verifyOpDynamicDims(op, op.result_encoding(),
+ op.result_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ auto sourceType = op.source_encoding().cast<ShapedType>();
+ if (op.start_indices().size() != sourceType.getRank() ||
+ op.lengths().size() != sourceType.getRank()) {
+ return op.emitOpError() << "start_indices/lengths rank mismatch";
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.update
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorUpdateOp op) {
+ if (failed(verifyOpDynamicDims(op, op.update_encoding(),
+ op.update_encoding_dims())) ||
+ failed(verifyOpDynamicDims(op, op.target_encoding(),
+ op.target_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.update(), op.update_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+Value TensorUpdateOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
+}
+
+::llvm::Optional<unsigned> TensorUpdateOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // target
+}
+
+SmallVector<int64_t, 4> TensorUpdateOp::getTiedResultOperandIndices() {
+ return {0}; // target
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.fill
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorFillOp op) {
+ if (failed(verifyOpDynamicDims(op, op.target_encoding(),
+ op.target_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+Value TensorFillOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
+}
+
+::llvm::Optional<unsigned> TensorFillOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // target
+}
+
+SmallVector<int64_t, 4> TensorFillOp::getTiedResultOperandIndices() {
+ return {0}; // target
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.load
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorLoadOp op) {
+ if (failed(verifyOpDynamicDims(op, op.source_encoding(),
+ op.source_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
+ return failure();
+ }
+ auto sourceType = op.source_encoding().cast<ShapedType>();
+ if (op.indices().size() != sourceType.getRank()) {
+ return op.emitOpError() << "indices rank mismatch";
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.tensor.store
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TensorStoreOp op) {
+ if (failed(verifyOpDynamicDims(op, op.target_encoding(),
+ op.target_encoding_dims())) ||
+ failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ auto targetType = op.target_encoding().cast<ShapedType>();
+ if (op.indices().size() != targetType.getRank()) {
+ return op.emitOpError() << "indices rank mismatch";
+ }
+ return success();
+}
+
+Value TensorStoreOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
+}
+
+::llvm::Optional<unsigned> TensorStoreOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // target
+}
+
+SmallVector<int64_t, 4> TensorStoreOp::getTiedResultOperandIndices() {
+ return {0}; // target
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.alloca
+//===----------------------------------------------------------------------===//
+
+bool AsyncAllocaOp::isMetadata() { return true; }
+
+bool AsyncAllocaOp::preferCloneToConsumers() { return true; }
+
+//===----------------------------------------------------------------------===//
+// stream.async.constant
+//===----------------------------------------------------------------------===//
+
+bool AsyncConstantOp::isMetadata() { return true; }
+
+void AsyncConstantOp::getAsmResultNames(mlir::OpAsmSetValueNameFn setNameFn) {
+ setNameFn(result(), "cst");
+}
+
+static LogicalResult verifyOp(AsyncConstantOp op) {
+ if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.splat
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncSplatOp op) {
+ if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+bool AsyncSplatOp::preferCloneToConsumers() { return true; }
+
+//===----------------------------------------------------------------------===//
+// stream.async.clone
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncCloneOp op) {
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+bool AsyncCloneOp::preferCloneToConsumers() { return true; }
+
+//===----------------------------------------------------------------------===//
+// stream.async.slice
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncSliceOp op) {
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+bool AsyncSliceOp::isMetadata() { return true; }
+
+//===----------------------------------------------------------------------===//
+// stream.async.fill
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncFillOp op) {
+ if (failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+Value AsyncFillOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
+}
+
+::llvm::Optional<unsigned> AsyncFillOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // target
+}
+
+SmallVector<int64_t, 4> AsyncFillOp::getTiedResultOperandIndices() {
+ return {0}; // target
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.update
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncUpdateOp op) {
+ if (failed(verifyOpValueSizes(op, op.update(), op.update_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+bool AsyncUpdateOp::isMetadata() { return true; }
+
+Value AsyncUpdateOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
+}
+
+::llvm::Optional<unsigned> AsyncUpdateOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // target
+}
+
+SmallVector<int64_t, 4> AsyncUpdateOp::getTiedResultOperandIndices() {
+ return {0}; // target
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.copy
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncCopyOp op) {
+ if (op.source() == op.target()) {
+ // If we want to perform memmove-like operations where it's safe to copy
+ // overlapping ranges we'll need to emit some runtime checks. We can in
+ // many cases statically detect a lack of overlap just based on symbolic
+ // offset equality but that requires some analysis we don't have yet.
+ return op.emitOpError() << "cannot copy within the same resource (yet)";
+ }
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+Value AsyncCopyOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
+}
+
+::llvm::Optional<unsigned> AsyncCopyOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // target
+}
+
+SmallVector<int64_t, 4> AsyncCopyOp::getTiedResultOperandIndices() {
+ return {0}; // target
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.transfer
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncTransferOp op) {
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.dispatch
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncDispatchOp op) {
+ if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
+ failed(verifyOpValueSizes(op, op.results(), op.result_sizes()))) {
+ return failure();
+ }
+ return success();
+}
+
+std::pair<unsigned, unsigned> AsyncDispatchOp::getTiedOperandsIndexAndLength() {
+ return getODSOperandIndexAndLength(1); // $operands
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.execute
+//===----------------------------------------------------------------------===//
+
+void AsyncExecuteOp::build(OpBuilder &builder, OperationState &state,
+ TypeRange resultTypes, ValueRange resultSizes,
+ Value awaitTimepoint, ValueRange operands,
+ ValueRange operandSizes,
+ ArrayRef<int64_t> tiedOperands,
+ ArrayRef<NamedAttribute> attributes) {
+ state.addTypes(resultTypes);
+ state.addTypes(IREE::Stream::TimepointType::get(builder.getContext()));
+ state.addOperands(operands);
+ state.addOperands(operandSizes);
+ state.addOperands(resultSizes);
+ if (awaitTimepoint) state.addOperands(awaitTimepoint);
+ state.addAttributes(attributes);
+ state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
+ state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
+ builder.getIndexArrayAttr(tiedOperands));
+ state.attributes.erase("operand_segment_sizes");
+ state.addAttribute("operand_segment_sizes",
+ builder.getI32VectorAttr({
+ static_cast<int32_t>(operands.size()),
+ static_cast<int32_t>(operandSizes.size()),
+ static_cast<int32_t>(resultSizes.size()),
+ awaitTimepoint ? 1 : 0,
+ }));
+ state.addRegion();
+}
+
+static LogicalResult verifyOp(AsyncExecuteOp op) {
+ if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
+ if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
+ failed(verifyOpValueSizes(op, op.results(), op.result_sizes()))) {
+ return failure();
+ }
+ if (failed(verifyAllResourcesCaptured(op.body())) ||
+ failed(verifyEscapingResources(op.body(), op.results(),
+ op.result_sizes()))) {
+ return failure();
+ }
+ return success();
+}
+
+std::pair<unsigned, unsigned> AsyncExecuteOp::getTiedResultsIndexAndLength() {
+ return {0, results().size()};
+}
+
+OperandRange AsyncExecuteOp::getSuccessorEntryOperands(unsigned index) {
+ assert(index == 0 && "invalid region index");
+ return operands();
+}
+
+void AsyncExecuteOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ // Unconditional control flow into the region and back to the parent, so
+ // return the correct RegionSuccessor purely based on the index being None or
+ // 0.
+ if (index.hasValue()) {
+ regions.push_back(RegionSuccessor(results()));
+ } else {
+ regions.push_back(RegionSuccessor(&body(), body().getArguments()));
+ }
+}
+
+Operation::operand_range AsyncExecuteOp::getClosureOperands() {
+ return operands();
+}
+
+Operation::result_range AsyncExecuteOp::getClosureResults() {
+ return results();
+}
+
+bool AsyncExecuteOp::canClosureContainOp(Operation *op) { return false; }
+
+IREE::Util::ValueAccess AsyncExecuteOp::getOperandAccess(
+ unsigned operandIndex) {
+ auto arg = body().getArgument(operandIndex);
+ return computeValueAccess(arg);
+}
+
+IREE::Util::ValueAccess AsyncExecuteOp::getResultAccess(unsigned resultIndex) {
+ auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
+ return computeValueAccess(yieldOp.getOperand(resultIndex));
+}
+
+IREE::Util::ClosureOpInterface
+AsyncExecuteOp::cloneReplacementExcludingOperandsAndResults(
+ ArrayRef<unsigned> excludedOperandIndices,
+ ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
+ auto newResultTypes = llvm::to_vector<4>(
+ llvm::map_range(results(), [](auto value) { return value.getType(); }));
+ auto newResultSizes = llvm::to_vector<4>(result_sizes());
+ auto newOperandsValues = llvm::to_vector<4>(operands());
+ auto newOperandSizes = llvm::to_vector<4>(operand_sizes());
+ IREE::Util::excludeClosureOperandsAndResults(
+ newOperandsValues, newOperandSizes, excludedOperandIndices,
+ newResultTypes, newResultSizes, excludedResultIndices);
+
+ auto newTiedOperandIndices =
+ llvm::to_vector<4>(getTiedResultOperandIndices());
+ IREE::Util::excludeTiedOperandAndResultIndices(
+ excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
+ assert(getTiedOperandsIndexAndLength().first == 0 &&
+ "operands must be the first ODS group");
+
+ auto newOp = rewriter.create<AsyncExecuteOp>(
+ getLoc(), newResultTypes, newResultSizes, await_timepoint(),
+ newOperandsValues, newOperandSizes, newTiedOperandIndices,
+ getOperation()->getAttrs());
+ auto &newBody = newOp.getClosureBodyRegion();
+ newBody.takeBody(getClosureBodyRegion());
+ eraseStreamRegionResults(newBody, excludedResultIndices);
+ newBody.front().eraseArguments(excludedOperandIndices);
+ return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.concurrent
+//===----------------------------------------------------------------------===//
+
+void AsyncConcurrentOp::build(OpBuilder &builder, OperationState &state,
+ TypeRange resultTypes, ValueRange resultSizes,
+ ValueRange operands, ValueRange operandSizes,
+ ArrayRef<int64_t> tiedOperands,
+ ArrayRef<NamedAttribute> attributes) {
+ state.addTypes(resultTypes);
+ state.addOperands(operands);
+ state.addOperands(operandSizes);
+ state.addOperands(resultSizes);
+ state.addAttributes(attributes);
+ state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
+ state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
+ builder.getIndexArrayAttr(tiedOperands));
+ state.attributes.erase("operand_segment_sizes");
+ state.addAttribute("operand_segment_sizes",
+ builder.getI32VectorAttr({
+ static_cast<int32_t>(operands.size()),
+ static_cast<int32_t>(operandSizes.size()),
+ static_cast<int32_t>(resultSizes.size()),
+ }));
+ state.addRegion();
+}
+
+static LogicalResult verifyOp(AsyncConcurrentOp op) {
+ if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
+ if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
+ failed(verifyOpValueSizes(op, op.results(), op.result_sizes()))) {
+ return failure();
+ }
+ if (failed(verifyAllResourcesCaptured(op.body())) ||
+ failed(verifyEscapingResources(op.body(), op.results(),
+ op.result_sizes()))) {
+ return failure();
+ }
+ return success();
+}
+
+OperandRange AsyncConcurrentOp::getSuccessorEntryOperands(unsigned index) {
+ assert(index == 0 && "invalid region index");
+ return operands();
+}
+
+void AsyncConcurrentOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ // Unconditional control flow into the region and back to the parent, so
+ // return the correct RegionSuccessor purely based on the index being None or
+ // 0.
+ if (index.hasValue()) {
+ regions.push_back(RegionSuccessor(results()));
+ } else {
+ regions.push_back(RegionSuccessor(&body(), body().getArguments()));
+ }
+}
+
+Operation::operand_range AsyncConcurrentOp::getClosureOperands() {
+ return operands();
+}
+
+Operation::result_range AsyncConcurrentOp::getClosureResults() {
+ return results();
+}
+
+bool AsyncConcurrentOp::canClosureContainOp(Operation *op) { return false; }
+
+IREE::Util::ValueAccess AsyncConcurrentOp::getOperandAccess(
+ unsigned operandIndex) {
+ auto arg = body().getArgument(operandIndex);
+ return computeValueAccess(arg);
+}
+
+IREE::Util::ValueAccess AsyncConcurrentOp::getResultAccess(
+ unsigned resultIndex) {
+ auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
+ return computeValueAccess(yieldOp.getOperand(resultIndex));
+}
+
+IREE::Util::ClosureOpInterface
+AsyncConcurrentOp::cloneReplacementExcludingOperandsAndResults(
+ ArrayRef<unsigned> excludedOperandIndices,
+ ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
+ auto newResultTypes = llvm::to_vector<4>(getResultTypes());
+ auto newResultSizes = llvm::to_vector<4>(result_sizes());
+ auto newOperandsValues = llvm::to_vector<4>(operands());
+ auto newOperandSizes = llvm::to_vector<4>(operand_sizes());
+ IREE::Util::excludeClosureOperandsAndResults(
+ newOperandsValues, newOperandSizes, excludedOperandIndices,
+ newResultTypes, newResultSizes, excludedResultIndices);
+
+ auto newTiedOperandIndices =
+ llvm::to_vector<4>(getTiedResultOperandIndices());
+ IREE::Util::excludeTiedOperandAndResultIndices(
+ excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
+ assert(getTiedOperandsIndexAndLength().first == 0 &&
+ "operands must be the first ODS group");
+
+ auto newOp = rewriter.create<AsyncConcurrentOp>(
+ getLoc(), newResultTypes, newResultSizes, newOperandsValues,
+ newOperandSizes, newTiedOperandIndices, getOperation()->getAttrs());
+ auto &newBody = newOp.getClosureBodyRegion();
+ newBody.takeBody(getClosureBodyRegion());
+ eraseStreamRegionResults(newBody, excludedResultIndices);
+ newBody.front().eraseArguments(excludedOperandIndices);
+ return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.flush
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdFlushOp op) {
+ if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.invalidate
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdInvalidateOp op) {
+ if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.discard
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdDiscardOp op) {
+ if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.fill
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdFillOp op) {
+ if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.copy
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdCopyOp op) {
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
+ failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.dispatch
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdDispatchOp op) {
+ size_t resourceCount = op.resources().size();
+ if (op.resource_sizes().size() != resourceCount ||
+ op.resource_offsets().size() != resourceCount ||
+ op.resource_lengths().size() != resourceCount ||
+ op.resource_accesses().size() != resourceCount) {
+ return op->emitOpError() << "dispatch with " << resourceCount
+ << " resources has mismatched associated ranges";
+ }
+ return success();
+}
+
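+// Parses the custom resource binding list used by stream.cmd.dispatch. Each
+// entry is expected to look roughly like the following (illustrative; the SSA
+// names are placeholders):
+//   ro %resource[%offset for %length] : !stream.resource<transient>{%size}
+// where the leading `ro`/`wo`/`rw` keyword maps to the resource access bits
+// assigned below.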
+static ParseResult parseDispatchResources(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &resources,
+ SmallVectorImpl<Type> &resourceTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &resourceSizes,
+ SmallVectorImpl<OpAsmParser::OperandType> &resourceOffsets,
+ SmallVectorImpl<OpAsmParser::OperandType> &resourceLengths,
+ ArrayAttr &resourceAccesses) {
+ SmallVector<Attribute> accessAttrs;
+ do {
+ // Reserve entries in the lists.
+ resources.emplace_back();
+ resourceTypes.emplace_back();
+ resourceSizes.emplace_back();
+ resourceOffsets.emplace_back();
+ resourceLengths.emplace_back();
+ StringRef accessStr;
+ if (failed(parser.parseKeyword(&accessStr)) ||
+ failed(parser.parseOperand(resources.back())) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(resourceOffsets.back())) ||
+ failed(parser.parseKeyword("for")) ||
+ failed(parser.parseOperand(resourceLengths.back())) ||
+ failed(parser.parseRSquare()) || failed(parser.parseColon()) ||
+ failed(parseSizeAwareType(parser, resourceTypes.back(),
+ resourceSizes.back()))) {
+ return failure();
+ }
+ IREE::Stream::ResourceAccessBitfield accessBits =
+ IREE::Stream::ResourceAccessBitfield::None;
+ if (accessStr == "ro") {
+ accessBits = IREE::Stream::ResourceAccessBitfield::Read;
+ } else if (accessStr == "wo") {
+ accessBits = IREE::Stream::ResourceAccessBitfield::Write;
+ } else if (accessStr == "rw") {
+ accessBits = IREE::Stream::ResourceAccessBitfield::Read |
+ IREE::Stream::ResourceAccessBitfield::Write;
+ }
+ accessAttrs.push_back(IREE::Stream::ResourceAccessBitfieldAttr::get(
+ parser.getBuilder().getContext(), accessBits));
+ } while (succeeded(parser.parseOptionalComma()));
+ resourceAccesses = parser.getBuilder().getArrayAttr(accessAttrs);
+ return success();
+}
+
+static void printDispatchResources(OpAsmPrinter &p, Operation *op,
+ ValueRange resources,
+ TypeRange resourceTypes,
+ ValueRange resourceSizes,
+ ValueRange resourceOffsets,
+ ValueRange resourceLengths,
+ ArrayAttr resourceAccesses) {
+ for (size_t i = 0; i < resources.size(); ++i) {
+ auto resource = resources[i];
+ auto resourceType = resourceTypes[i];
+ auto resourceSize = resourceSizes[i];
+ auto resourceOffset = resourceOffsets[i];
+ auto resourceLength = resourceLengths[i];
+ auto resourceAccess = resourceAccesses[i]
+ .cast<IREE::Stream::ResourceAccessBitfieldAttr>()
+ .getValue();
+ p.printNewline();
+ p << " ";
+ if (bitEnumContains(resourceAccess,
+ IREE::Stream::ResourceAccessBitfield::Read) &&
+ bitEnumContains(resourceAccess,
+ IREE::Stream::ResourceAccessBitfield::Write)) {
+ p << "rw";
+ } else if (bitEnumContains(resourceAccess,
+ IREE::Stream::ResourceAccessBitfield::Read)) {
+ p << "ro";
+ } else if (bitEnumContains(resourceAccess,
+ IREE::Stream::ResourceAccessBitfield::Write)) {
+ p << "wo";
+ }
+ p << ' ';
+ p.printOperand(resource);
+ p << "[";
+ p.printOperand(resourceOffset);
+ p << " for ";
+ p.printOperand(resourceLength);
+ p << "] : ";
+ printSizeAwareType(p, op, resourceType, resourceSize);
+ if (i < resources.size() - 1) p << ",";
+ }
+}
+
+// This is sloppy because the function has interleaved bindings and operands;
+// if we had our own op we could just reuse the map we have for operands.
+// static
+SmallVector<unsigned> CmdDispatchOp::makeOperandToArgMap(mlir::FuncOp funcOp) {
+ unsigned operandCount = llvm::count_if(
+ funcOp.getArgumentTypes(),
+ [](Type type) { return !type.isa<IREE::Stream::BindingType>(); });
+ SmallVector<unsigned> map(operandCount);
+ unsigned operandIdx = 0;
+ for (auto it : llvm::enumerate(funcOp.getArgumentTypes())) {
+ unsigned argIdx = it.index();
+ auto argType = it.value();
+ if (!argType.isa<IREE::Stream::BindingType>()) {
+ map[operandIdx++] = argIdx;
+ }
+ }
+ return map;
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.execute
+//===----------------------------------------------------------------------===//
+
+void CmdExecuteOp::build(OpBuilder &builder, OperationState &state,
+ Value awaitTimepoint, ValueRange operands,
+ ValueRange operandSizes,
+ ArrayRef<NamedAttribute> attributes) {
+ state.addTypes(IREE::Stream::TimepointType::get(builder.getContext()));
+ state.addOperands(operands);
+ state.addOperands(operandSizes);
+ if (awaitTimepoint) state.addOperands(awaitTimepoint);
+ state.addAttributes(attributes);
+ state.attributes.erase("operand_segment_sizes");
+ state.addAttribute("operand_segment_sizes",
+ builder.getI32VectorAttr({
+ static_cast<int32_t>(operands.size()),
+ static_cast<int32_t>(operandSizes.size()),
+ awaitTimepoint ? 1 : 0,
+ }));
+ state.addRegion();
+}
+
+// Returns success if the given op is a known valid stream.cmd.* op for use
+// within an execution region.
+static LogicalResult verifyCmdOp(Operation *op) {
+ // TODO(benvanik): add a trait that lets us avoid this switch.
+ if (!TypeSwitch<Operation *, bool>(op)
+ .Case<IREE::Stream::CmdFlushOp, IREE::Stream::CmdInvalidateOp,
+ IREE::Stream::CmdDiscardOp, IREE::Stream::CmdFillOp,
+ IREE::Stream::CmdCopyOp, IREE::Stream::CmdDispatchOp,
+ IREE::Stream::CmdSerialOp, IREE::Stream::CmdConcurrentOp>(
+ [](auto op) { return true; })
+ .Case<IREE::Stream::YieldOp>([](auto op) { return true; })
+ .Default(false)) {
+ return op->emitOpError()
+ << "explicit execution regions must only contain explicit ops";
+ }
+ return success();
+}
+
+static LogicalResult verifyOp(CmdExecuteOp op) {
+ if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
+ if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes()))) {
+ return failure();
+ }
+ if (failed(verifyAllResourcesCaptured(op.body()))) {
+ return failure();
+ }
+ for (auto &nestedOp : op.body().front()) {
+ if (failed(verifyCmdOp(&nestedOp))) return failure();
+ }
+ return success();
+}
+
+OperandRange CmdExecuteOp::getSuccessorEntryOperands(unsigned index) {
+ assert(index == 0 && "invalid region index");
+ return operands();
+}
+
+void CmdExecuteOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ // Unconditional control flow into the region and back to the parent, so
+ // return the correct RegionSuccessor purely based on the index being None or
+ // 0.
+ if (index.hasValue()) {
+ regions.push_back(RegionSuccessor({}));
+ } else {
+ regions.push_back(RegionSuccessor(&body(), body().getArguments()));
+ }
+}
+
+Operation::operand_range CmdExecuteOp::getClosureOperands() {
+ return operands();
+}
+
+Operation::result_range CmdExecuteOp::getClosureResults() {
+ return Operation::result_range(nullptr, 0);
+}
+
+bool CmdExecuteOp::canClosureContainOp(Operation *op) { return false; }
+
+IREE::Util::ValueAccess CmdExecuteOp::getOperandAccess(unsigned operandIndex) {
+ auto arg = body().getArgument(operandIndex);
+ return computeValueAccess(arg);
+}
+
+IREE::Util::ValueAccess CmdExecuteOp::getResultAccess(unsigned resultIndex) {
+ return IREE::Util::ValueAccess::None();
+}
+
+IREE::Util::ClosureOpInterface
+CmdExecuteOp::cloneReplacementExcludingOperandsAndResults(
+ ArrayRef<unsigned> excludedOperandIndices,
+ ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
+ SmallVector<Type, 4> newResultTypes;
+ SmallVector<Value, 4> newResultSizes;
+ auto newOperandsValues = llvm::to_vector<4>(operands());
+ auto newOperandSizes = llvm::to_vector<4>(operand_sizes());
+ IREE::Util::excludeClosureOperandsAndResults(
+ newOperandsValues, newOperandSizes, excludedOperandIndices,
+ newResultTypes, newResultSizes, excludedResultIndices);
+
+ auto newOp = rewriter.create<CmdExecuteOp>(getLoc(), await_timepoint(),
+ newOperandsValues, newOperandSizes,
+ getOperation()->getAttrs());
+ auto &newBody = newOp.getClosureBodyRegion();
+ newBody.takeBody(getClosureBodyRegion());
+ newBody.front().eraseArguments(excludedOperandIndices);
+ return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.serial
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdSerialOp op) {
+ for (auto &nestedOp : op.body().front()) {
+ if (failed(verifyCmdOp(&nestedOp))) return failure();
+ }
+ return success();
+}
+
+void CmdSerialOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ // Unconditional control flow into the region and back to the parent, so
+ // return the correct RegionSuccessor purely based on the index being None or
+ // 0.
+ if (index.hasValue()) {
+ regions.push_back(RegionSuccessor({}));
+ } else {
+ regions.push_back(RegionSuccessor(&body(), {}));
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// stream.cmd.concurrent
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(CmdConcurrentOp op) {
+ for (auto &nestedOp : op.body().front()) {
+ if (failed(verifyCmdOp(&nestedOp))) return failure();
+ }
+ return success();
+}
+
+void CmdConcurrentOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ // Unconditional control flow into the region and back to the parent, so
+ // return the correct RegionSuccessor purely based on the index being None or
+ // 0.
+ if (index.hasValue()) {
+ regions.push_back(RegionSuccessor({}));
+ } else {
+ regions.push_back(RegionSuccessor(&body(), {}));
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// stream.timepoint.join
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(TimepointJoinOp op) {
+  // We could test whether the timepoints all come from the same place - this
+  // is not strictly required, but if we can avoid joining timepoints from
+  // different sources things will be easier to implement at runtime (we won't
+  // have to do a cuda<->vulkan sync, etc).
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.timepoint.await
+//===----------------------------------------------------------------------===//
+
+void TimepointAwaitOp::build(OpBuilder &builder, OperationState &state,
+ ValueRange operands, ValueRange operandSizes,
+ Value timepoint,
+ ArrayRef<NamedAttribute> attributes) {
+ state.addTypes(llvm::map_range(
+ operands, [&](Value operand) { return operand.getType(); }));
+ state.addOperands(operands);
+ state.addOperands(operandSizes);
+ state.addOperands(timepoint);
+ state.addAttributes(attributes);
+ state.attributes.erase("operand_segment_sizes");
+ state.addAttribute("operand_segment_sizes",
+ builder.getI32VectorAttr({
+ static_cast<int32_t>(operands.size()),
+ static_cast<int32_t>(operandSizes.size()),
+ static_cast<int32_t>(1), // timepoint
+ }));
+}
+
+static LogicalResult verifyOp(TimepointAwaitOp op) {
+ if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
+ failed(verifyOpValueSizes(op, op.results(), op.operand_sizes()))) {
+ return failure();
+ }
+ return success();
+}
+
+::llvm::Optional<unsigned> TimepointAwaitOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {resultIndex};
+}
+
+SmallVector<int64_t, 4> TimepointAwaitOp::getTiedResultOperandIndices() {
+ return llvm::to_vector<4>(llvm::seq<int64_t>(0, operands().size()));
+}
+
+//===----------------------------------------------------------------------===//
+// stream.executable
+//===----------------------------------------------------------------------===//
+
+void ExecutableOp::build(OpBuilder &builder, OperationState &state,
+ StringRef sym_name) {
+ ensureTerminator(*state.addRegion(), builder, state.location);
+ state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+ builder.getStringAttr(sym_name));
+}
+
+static LogicalResult verifyOp(ExecutableOp op) {
+ // TODO(benvanik): check export name conflicts.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.executable.export
+//===----------------------------------------------------------------------===//
+
+void ExecutableExportOp::build(OpBuilder &builder, OperationState &state,
+ StringRef sym_name,
+ FlatSymbolRefAttr function_ref) {
+ build(builder, state, /*sym_visibility=*/nullptr,
+ builder.getStringAttr(sym_name), function_ref);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.binding.subspan
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(BindingSubspanOp op) {
+ if (auto shapedType = op.getType().dyn_cast<ShapedType>()) {
+ if (failed(verifyOpDynamicDims(op, shapedType, op.dynamic_dims()))) {
+ return failure();
+ }
+ }
+
+ return success();
+}
+
+Value BindingSubspanOp::buildOperandRankedShape(unsigned idx,
+ OpBuilder &builder) {
+ return {};
+}
+
+Value BindingSubspanOp::buildResultRankedShape(unsigned idx,
+ OpBuilder &builder) {
+ return Shape::buildRankedShapeForValue(getLoc(), result(), dynamic_dims(),
+ builder);
+}
+
+//===----------------------------------------------------------------------===//
+// stream.yield
+//===----------------------------------------------------------------------===//
+
+MutableOperandRange YieldOp::getMutableSuccessorOperands(
+ Optional<unsigned> index) {
+ return operandsMutable();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// TableGen definitions (intentionally last)
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.cpp.inc" // IWYU pragma: keep
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.h b/iree/compiler/Dialect/Stream/IR/StreamOps.h
new file mode 100644
index 0000000..a57fea1
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.h
@@ -0,0 +1,34 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMOPS_H_
+#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMOPS_H_
+
+#include <cstdint>
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h.inc" // IWYU pragma: export
+
+#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAMOPS_H_
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.td b/iree/compiler/Dialect/Stream/IR/StreamOps.td
new file mode 100644
index 0000000..9d77a6e
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -0,0 +1,2585 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_STREAM_OPS
+#define IREE_DIALECT_STREAM_OPS
+
+include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td"
+include "iree/compiler/Dialect/Stream/IR/StreamBase.td"
+include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td"
+include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
+
+class Stream_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ Stream_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
+
+//===----------------------------------------------------------------------===//
+// Generic resource ops
+//===----------------------------------------------------------------------===//
+
+def Stream_ResourceAllocOp : Stream_PureOp<"resource.alloc", [
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Util_SizeAwareOp,
+ MemoryEffects<[MemAlloc]>,
+]> {
+ let summary = [{allocates a persistent value with undefined contents}];
+ let description = [{
+ Allocates a persistent value (one that is long-lived and possibly external
+ to the program) with undefined contents. Consumers of the allocated
+ result must assume nothing of the contents and use `discard` access.
+
+ Uninitialized allocations will have undefined contents and must only be used
+ when all bytes are discarded prior to any reads. Runtimes decide what
+ "undefined contents" means and here it only indicates that execution will be
+ correct even if the memory starts with non-zero values.
+
+ If multiple values are allocated from the same operation it implies that
+ they have matching lifetimes. When lowering to execution environments the
+ separate allocations may be fused into one or more slab allocations in order
+ to reduce overheads. How many allocations can be fused is based on the size
+ of the individual resources and the target constraints (how large any single
+ buffer may be, etc). At the stream dialect level treat a multi-result alloc
+ as a way to indicate similar lifetimes.
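+
+    For example, a minimal illustrative form (the lifetime and `%size` value
+    are placeholders):
+
+    ```mlir
+    %0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%size}
+    ```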
+ }];
+
+ let arguments = (ins
+ Variadic<Stream_Size>:$storage_sizes,
+ UnitAttr:$uninitialized,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<Stream_AnyResource>:$results
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`uninitialized` $uninitialized^)?
+ attr-dict `:` custom<SizeAwareTypeList>(type($results), $storage_sizes)
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return storage_sizes()[idx]; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ResourceAllocaOp : Stream_PureOp<"resource.alloca", [
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Util_SizeAwareOp,
+ MemoryEffects<[MemAlloc]>,
+]> {
+ let summary = [{allocates a transient value with undefined contents}];
+ let description = [{
+ Allocates a transient value (one that is short-lived and local to the
+ current computation) with undefined contents. Consumers of the allocated
+ result must assume nothing of the contents and use `discard` access.
+
+ The resource returned is not valid for use until the timepoint is reached;
+ execution using this resource must await on the timepoint.
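+
+    For example, a sketch of the expected form (illustrative values):
+
+    ```mlir
+    %0, %t = stream.resource.alloca uninitialized :
+        !stream.resource<transient>{%size} => !stream.timepoint
+    ```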
+ }];
+
+ let arguments = (ins
+ Stream_Size:$storage_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ AnyTypeOf<[
+ Stream_StagingResource,
+ Stream_TransientResource,
+ ]>:$result,
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ `uninitialized`
+ attr-dict `:`
+ type($result) `{` $storage_size `}`
+ `=` `` `>`
+ type($result_timepoint)
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return storage_size(); }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ResourceDeallocaOp : Stream_Op<"resource.dealloca", [
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Util_SizeAwareOp,
+ MemoryEffects<[MemFree]>,
+]> {
+ let summary = [{frees a transient value when available}];
+ let description = [{
+ Deallocates a transient value (one that is short-lived and local to the
+ current computation) previously allocated using `stream.resource.alloca`.
+
+ The resource is considered live and valid until the provided timepoint is
+ reached and the memory is only made available for future requests afterward.
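+
+    For example (illustrative; `%t` is the timepoint guarding the last use):
+
+    ```mlir
+    stream.resource.dealloca await(%t) => %0 : !stream.resource<transient>{%size}
+    ```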
+ }];
+
+ let arguments = (ins
+ AnyTypeOf<[
+ Stream_StagingResource,
+ Stream_TransientResource,
+ ]>:$operand,
+ Stream_Size:$operand_size,
+ Stream_Timepoint:$timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ `await` `(` $timepoint `)`
+ `=` `` `>`
+ $operand `:` type($operand) `{` $operand_size `}`
+ attr-dict
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return operand_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ResourceSizeOp : Stream_PureOp<"resource.size", [
+ Stream_AffinityOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{returns the size of the resource storage in bytes}];
+ let description = [{
+ Returns a possibly runtime-dynamic byte size of the resource backing
+ storage. This may differ from the logical storage size of a value based on
+    the alignment requirements of the target as well as the encoding of
+    higher-level values such as sparse tensor formats.
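+
+    For example (illustrative):
+
+    ```mlir
+    %size = stream.resource.size %0 : !stream.resource<*>
+    ```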
+ }];
+
+ let arguments = (ins
+ Stream_AnyResource:$operand,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Size:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $operand
+ attr-dict `:` type($operand)
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return result(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasFolder = 1;
+}
+
+def Stream_ResourceMapOp : Stream_Op<"resource.map", [
+ Stream_AffinityOp,
+ Util_SizeAwareOp,
+ MemoryEffects<[MemAlloc]>,
+]> {
+ let summary = [{maps read-only memory into a staging resource}];
+ let description = [{
+ Synchronously maps a host heap buffer into a stream-accessible staging
+    resource. This will never fail, but it may induce a copy if required; as
+    such the mapped resource is not coherent with the original source buffer,
+    and changing the source buffer after mapping has undefined behavior.
+ }];
+
+ let arguments = (ins
+ ByteBufferType:$source,
+ Stream_Offset:$source_offset,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_StagingResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `[` $source_offset `]` `:`
+ type($source)
+ `->`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ResourceTryMapOp : Stream_PureOp<"resource.try_map", [
+ Stream_AffinityOp,
+ Util_SizeAwareOp,
+ MemoryEffects<[MemAlloc]>,
+]> {
+ let summary = [{maps read-only memory into a resource}];
+ let description = [{
+ Synchronously maps a host heap buffer into a stream-accessible resource
+    with constant lifetime. If the given source cannot be mapped into a
+    constant then a failure is returned and the resulting resource value is
+    null. As with `stream.resource.map`, the resulting resource is not coherent
+    with the source and changes to the source will not be reflected.
+ }];
+
+ let arguments = (ins
+ ByteBufferType:$source,
+ Stream_Offset:$source_offset,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ I1:$did_map,
+ Stream_ConstantResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `[` $source_offset `]` `:`
+ type($source)
+ `->`
+ type($did_map) `,` type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ResourceLoadOp : Stream_Op<"resource.load", [
+ Util_SizeAwareOp,
+]> {
+ let summary = [{loads a value from a staging resource}];
+ let description = [{
+ Returns the element(s) at the given offset in the staging resource.
+    The operation will complete synchronously against the resource, though it
+    may introduce a yield point if the staging resource needs to be transferred.
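+
+    For example, loading an i32 from a staging resource (illustrative):
+
+    ```mlir
+    %value = stream.resource.load %staging[%offset] :
+        !stream.resource<staging>{%size} -> i32
+    ```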
+ }];
+
+ let arguments = (ins
+ Stream_StagingResource:$source,
+ Stream_Size:$source_size,
+ Stream_Offset:$source_offset
+ );
+ let results = (outs
+ AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$result
+ );
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]` `:`
+ type($source) `` `{` $source_size `}`
+ `->`
+ type($result)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ResourceStoreOp : Stream_Op<"resource.store", [
+ Util_SizeAwareOp,
+ MemoryEffects<[MemWrite]>,
+]> {
+ let summary = [{stores a value into a staging resource}];
+ let description = [{
+    Stores a value into the staging resource at the given offset. The operation
+    will complete synchronously against the resource, though it may introduce a
+    yield point if the staging resource needs to be acquired.
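+
+    For example, storing an i32 into a staging resource (illustrative):
+
+    ```mlir
+    stream.resource.store %value, %staging[%offset] :
+        i32 -> !stream.resource<staging>{%size}
+    ```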
+ }];
+
+ let arguments = (ins
+ Stream_StagingResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$value
+ );
+
+ let assemblyFormat = [{
+ $value `,`
+ $target `[` $target_offset `]` `:`
+ type($value)
+ `->`
+ type($target) `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_ResourcePackOp : Stream_PureOp<"resource.pack", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ Stream_AffinityOp,
+]> {
+ let summary = [{packs variable-sized slices into a single slab}];
+ let description = [{
+ Performs a greedy packing of one or more sized slices with specified
+ lifetimes and returns their relative offsets in an aliased linear space.
+
+ Slices are `[start, end] = %slice_byte_size`, where the start and end values
+ define an inclusive lifetime range and the size is the total number of bytes
+ required to be live for that range.
+
+ ```mlir
+ // Computes the total length required for the packed values and the offsets
+ // of the 3 slices requested relative to the base of the packed memory:
+ %total_length, %offset_0, %offset_1, %offset_2 =
+ stream.resource.pack
+ // Each slice gets one result offset:
+ slices({
+ // 3 slices where A and B overlap and will get unique offsets
+ // while B and C do not overlap and are allowed to alias.
+ [0, 10] = %size_0, // A => %offset_0
+ [3, 8] = %size_1, // B => %offset_1
+ [9, 10] = %size_2, // C => %offset_2
+ ...
+ }) : index
+ ```
+
+ The lifetime start and end points (inclusive) are only used for relative
+ comparisons and may originate with any meaning (op order in block, epoch,
+ phase of the moon, etc). The packing algorithm uses the intervals to
+ determine slice liveness and when aliasing is safe.
+
+    The size of each slice may be either a constant or a runtime-computed
+    dynamic value. Constant slices can achieve denser packing than dynamic
+    values, so CSE/canonicalization should be applied to ensure that as many of
+    the dynamic values as possible are equivalent.
+
+ The total length required to pack all slices is returned and can be used to
+    acquire storage. The individual slice offsets are 0-based and as such, if
+    used directly as buffer offsets, may need additional offsetting. This can
+    be applied either via the optional `offset` operand or by slicing the
+ underlying allocation buffer.
+ }];
+
+ let arguments = (ins
+ Optional<Stream_Offset>:$offset,
+ Stream_IndexArrayAttr:$lifetime_intervals,
+ Variadic<Stream_Size>:$dynamic_slice_sizes,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Size:$total_length,
+ Variadic<Stream_Offset>:$packed_offsets
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`offset` `(` $offset^ `)`)?
+ `slices` `(` `{`
+ custom<PackSliceRanges>($lifetime_intervals,
+ $dynamic_slice_sizes,
+ type($packed_offsets))
+ `}` `)`
+ `:` type($total_length)
+ attr-dict-with-keyword
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let extraClassDeclaration = [{
+ struct Slice {
+ int64_t lifetimeStart;
+ int64_t lifetimeEnd;
+ Value dynamicSize;
+ Value packedOffset;
+
+ bool operator==(const Slice &rhs) const {
+ return lifetimeStart == rhs.lifetimeStart &&
+ lifetimeEnd == rhs.lifetimeEnd;
+ }
+ bool operator!=(const Slice &rhs) const {
+ return !(*this == rhs);
+ }
+ bool operator<(const Slice &rhs) const {
+ return std::make_pair(lifetimeStart, lifetimeEnd) <
+ std::make_pair(rhs.lifetimeStart, rhs.lifetimeEnd);
+ }
+ bool intersects(const Slice &rhs) const {
+ return lifetimeEnd >= rhs.lifetimeStart &&
+ rhs.lifetimeEnd >= lifetimeStart;
+ }
+ };
+
+ /// Returns all of the slices to be packed.
+ /// Order is ascending by lifetime interval (post-canonicalization).
+ SmallVector<Slice> getSlices();
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_ResourceConstantsOp : Stream_PureOp<"resource.constants", [
+ SameVariadicResultSize,
+ Stream_AffinityOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{asynchronously uploads or maps constant values}];
+ let description = [{
+ Represents an upload of constant resources that may be packed, suballocated,
+ and mapped depending on the final lowering target.
+
+ In runtime environments where memory is shared between host and device this
+ turns into a mapping operation that avoids additional memory allocation and
+ copies. When memory cannot be shared an asynchronous stream will be created
+ to allocate and copy all of the constant values.
+
+    Though this op returns a unique resource for each constant value, it's
+ expected that almost all end up aliasing into the same storage. The exact
+ packing and number of storage resources that are needed are not known until
+ lowering to a particular backend, though, so they are separate here for
+ proper usage tracking.
+
+ Both constant and variable resources can be produced; a constant is
+ immutable while a variable will be treated as a constant-value initializer
+ for a mutable resource. By modeling these together it's not required that
+ variable initializers first be allocated, copied to the target, and then
+ copied into the variable storage if the target is capable of doing a direct
+ upload or mapping.
+ }];
+
+ let arguments = (ins
+ TypedArrayAttrBase<ElementsAttr, "constant value array attribute">:$values,
+ Variadic<Stream_Size>:$result_sizes,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<AnyTypeOf<[
+ Stream_ConstantResource,
+ Stream_VariableResource,
+ ]>>:$results,
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ attr-dict `:`
+ custom<ConstantValueList>(type($results),
+ $result_sizes,
+ $values)
+ `\n` ` ` ` ` `=` `` `>` type($result_timepoint)
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return result_sizes()[idx]; }
+ }];
+}
+
+def Stream_ResourceSubviewOp : Stream_PureOp<"resource.subview", [
+ AllTypesMatch<["source", "result"]>,
+ DeclareOpInterfaceMethods<ViewLikeOpInterface>,
+ DeclareOpInterfaceMethods<Stream_StreamableOp, [
+ "isMetadata",
+ ]>,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{slices out a cloned subview of a value}];
+ let description = [{
+ Aliases a byte subrange of a resource.
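+
+    For example, aliasing `%length` bytes starting at `%offset` (illustrative):
+
+    ```mlir
+    %1 = stream.resource.subview %0[%offset] :
+        !stream.resource<*>{%size} -> !stream.resource<*>{%length}
+    ```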
+ }];
+
+ let arguments = (ins
+ Stream_AnyResource:$source,
+ Stream_Size:$source_size,
+ Stream_Offset:$source_offset,
+ Stream_Size:$result_size
+ );
+ let results = (outs
+ Stream_AnyResource:$result
+ );
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]` `:`
+ type($source) `` `{` $source_size `}` `->`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return result_size(); }
+
+ // Walks up the use-def chain to find a subview op that feeds into |value|.
+ static IREE::Stream::ResourceSubviewOp findSubviewOp(Value value);
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Pseudo ops for conversion support
+//===----------------------------------------------------------------------===//
+
+def Stream_TensorImportOp : Stream_PureOp<"tensor.import", [
+ Stream_AffinityOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{conversion placeholder for other->stream type conversion}];
+ let description = [{
+ Defines a conversion from a higher-level dialect type such as `tensor` that
+ is resolved during lowering into the stream dialect. This can be used to
+    interoperate between levels of the stack that require specifying stream
+    types and those that do not handle them prior to lowering.
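+
+    For example, a sketch assuming a `!hal.buffer_view` source (the source type
+    and all sizes are placeholders):
+
+    ```mlir
+    %0 = stream.tensor.import %view : !hal.buffer_view ->
+        tensor<4x?xf32>{%dim} in !stream.resource<external>{%size}
+    ```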
+ }];
+
+ let arguments = (ins
+ AnyType:$source,
+ TypeAttr:$result_encoding,
+ Stream_ShapeDynamicDims:$result_encoding_dims,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_ExternalResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `:`
+ type($source)
+ `->`
+ $result_encoding (`` `{` $result_encoding_dims^ `}`)?
+ `in`
+ type($result) `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return ValueRange{}; }
+ ValueRange getResultDynamicDims(unsigned idx) { return result_encoding_dims(); }
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_TensorExportOp : Stream_PureOp<"tensor.export", [
+ Stream_AffinityOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{conversion placeholder for stream->other type conversion}];
+ let description = [{
+ Defines a conversion to a higher-level dialect type such as `tensor` that
+ is resolved during lowering into the stream dialect. This can be used to
+    interoperate between levels of the stack that require specifying stream
+    types and those that do not handle them prior to lowering.
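+
+    For example, a sketch assuming a `!hal.buffer_view` destination type (the
+    result type and all sizes are placeholders):
+
+    ```mlir
+    %view = stream.tensor.export %0 :
+        tensor<4x?xf32>{%dim} in !stream.resource<external>{%size} ->
+        !hal.buffer_view
+    ```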
+ }];
+
+ let arguments = (ins
+ Stream_ExternalResource:$source,
+ TypeAttr:$source_encoding,
+ Stream_ShapeDynamicDims:$source_encoding_dims,
+ Stream_Size:$source_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `:`
+ $source_encoding (`` `{` $source_encoding_dims^ `}`)?
+ `in`
+ type($source) `` `{` $source_size `}`
+ `->`
+ type($result)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return source_encoding_dims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return ValueRange{}; }
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Builtin tensor ops
+//===----------------------------------------------------------------------===//
+
+def Stream_TensorSizeOfOp : Stream_PureOp<"tensor.sizeof", [
+ Stream_AffinityOp,
+ Stream_TensorPhaseOp,
+]> {
+ let summary = [{calculates the storage size of a given high-level type}];
+ let description = [{
+ Target-dependent storage size calculation using a high-level annotated type.
+    While within the stream dialect, the storage size of a value is left as a
+    placeholder using this op. The requisite target-specific parameters for
+ expanding the size calculation are only available after affinities have been
+ assigned.
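+
+    For example (illustrative; `%dim` provides the dynamic dimension):
+
+    ```mlir
+    %size = stream.tensor.sizeof tensor<4x?xf32>{%dim} : index
+    ```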
+ }];
+
+ let arguments = (ins
+ TypeAttr:$encoding,
+ Stream_ShapeDynamicDims:$encoding_dims,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Size:$storage_size
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $encoding (`{` $encoding_dims^ `}`)?
+ attr-dict `:` type($storage_size)
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+}
+
+def Stream_TensorConstantOp : Stream_PureOp<"tensor.constant", [
+ Stream_AffinityOp,
+ Stream_StreamableOp,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{defines a constant tensor value}];
+ let description = [{
+ Returns a typed resource initialized to the given constant value.
+ }];
+
+ let arguments = (ins
+ ElementsAttr:$value,
+ TypeAttr:$result_encoding,
+ Stream_ShapeDynamicDims:$result_encoding_dims,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ `:`
+ $result_encoding (`` `{` $result_encoding_dims^ `}`)?
+ `in`
+ type($result)
+ `=`
+ $value
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return ValueRange{}; }
+ ValueRange getResultDynamicDims(unsigned idx) { return result_encoding_dims(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_TensorSplatOp : Stream_PureOp<"tensor.splat", [
+ Stream_AffinityOp,
+ Stream_StreamableOp,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{splats a value into a shaped tensor}];
+ let description = [{
+ Returns a typed resource initialized to the given primitive value.
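+
+    For example (illustrative values and sizes):
+
+    ```mlir
+    %0 = stream.tensor.splat %value : f32 ->
+        tensor<4x?xf32>{%dim} in !stream.resource<*>{%size}
+    ```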
+ }];
+
+ let arguments = (ins
+ Stream_PrimitiveType:$value,
+ TypeAttr:$result_encoding,
+ Stream_ShapeDynamicDims:$result_encoding_dims,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $value
+ `:` type($value)
+ `->`
+ $result_encoding (`` `{` $result_encoding_dims^ `}`)?
+ `in`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return ValueRange{}; }
+ ValueRange getResultDynamicDims(unsigned idx) { return result_encoding_dims(); }
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_TensorCloneOp : Stream_PureOp<"tensor.clone", [
+ AttrSizedOperandSegments,
+ Stream_AffinityOp,
+ Stream_StreamableOp,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{clones the contents of a value}];
+ let description = [{
+ Clones the contents of a value at a snapshot in time. Future changes to the
+    cloned value will not affect the result. Acts as a copy-on-write operation.
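+
+    For example (illustrative; the source and result sizes happen to match):
+
+    ```mlir
+    %1 = stream.tensor.clone %0 :
+        tensor<4x?xf32>{%dim} in !stream.resource<*>{%size} ->
+        tensor<4x?xf32>{%dim} in !stream.resource<*>{%size}
+    ```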
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$source,
+ TypeAttr:$source_encoding,
+ Stream_ShapeDynamicDims:$source_encoding_dims,
+ Stream_Size:$source_size,
+ TypeAttr:$result_encoding,
+ Stream_ShapeDynamicDims:$result_encoding_dims,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `:`
+ $source_encoding (`` `{` $source_encoding_dims^ `}`)?
+ `in`
+ type($source) `` `{` $source_size `}`
+ `->`
+ $result_encoding (`` `{` $result_encoding_dims^ `}`)?
+ `in`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return source_encoding_dims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return result_encoding_dims(); }
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_TensorSliceOp : Stream_PureOp<"tensor.slice", [
+ AttrSizedOperandSegments,
+ Stream_AffinityOp,
+ Stream_StreamableOp,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{slices out a cloned subview of a value}];
+ let description = [{
+ Slices a subrange of a stream resource based on a tensor encoding. Acts as a
+ copy-on-write operation.
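+
+    For example, slicing a 2x3 region out of a 4x4 tensor (illustrative
+    indices, lengths, and sizes):
+
+    ```mlir
+    %0 = stream.tensor.slice %src[%c0, %c1 for %c2, %c3] :
+        tensor<4x4xf32> in !stream.resource<*>{%src_size} ->
+        tensor<2x3xf32> in !stream.resource<*>{%dst_size}
+    ```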
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$source,
+ TypeAttr:$source_encoding,
+ Stream_ShapeDynamicDims:$source_encoding_dims,
+ Stream_Size:$source_size,
+ Variadic<Stream_Dim>:$start_indices,
+ Variadic<Stream_Dim>:$lengths,
+ TypeAttr:$result_encoding,
+ Stream_ShapeDynamicDims:$result_encoding_dims,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `[` $start_indices `for` $lengths `]` `:`
+ $source_encoding (`` `{` $source_encoding_dims^ `}`)?
+ `in`
+ type($source) `` `{` $source_size `}`
+ `->`
+ $result_encoding (`` `{` $result_encoding_dims^ `}`)?
+ `in`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return source_encoding_dims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return result_encoding_dims(); }
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_TensorFillOp : Stream_Op<"tensor.fill", [
+ AttrSizedOperandSegments,
+ AllTypesMatch<["target", "result"]>,
+ Stream_AffinityOp,
+ Stream_StreamableOp,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{fills a subview of a stream resource with a value}];
+ let description = [{
+ Splats a value into a subview of the given stream resource and returns the
+ resource with the update applied.
+
+ Equivalent to a stream.tensor.splat + stream.tensor.update.
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$target,
+ TypeAttr:$target_encoding,
+ Stream_ShapeDynamicDims:$target_encoding_dims,
+ Stream_Size:$target_size,
+ Variadic<Stream_Dim>:$start_indices,
+ Variadic<Stream_Dim>:$lengths,
+ Stream_PrimitiveType:$value,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $value `,` $target `[` $start_indices `for` $lengths `]` `:`
+ type($value)
+ `->`
+ $target_encoding (`` `{` $target_encoding_dims^ `}`)?
+ `in`
+ custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return target_encoding_dims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return target_encoding_dims(); }
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return target_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_TensorUpdateOp : Stream_Op<"tensor.update", [
+ AttrSizedOperandSegments,
+ AllTypesMatch<["target", "result"]>,
+ Stream_AffinityOp,
+ Stream_StreamableOp,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{updates a slice of a subview of a resource in-place}];
+ let description = [{
+ Copies a value into a resource based on tensor encodings. The returned value
+ is the entire updated target value.
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$target,
+ TypeAttr:$target_encoding,
+ Stream_ShapeDynamicDims:$target_encoding_dims,
+ Stream_Size:$target_size,
+ Variadic<Stream_Dim>:$start_indices,
+ Stream_AnyStreamResource:$update,
+ TypeAttr:$update_encoding,
+ Stream_ShapeDynamicDims:$update_encoding_dims,
+ Stream_Size:$update_size,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $update `,` $target `[` $start_indices `]` `:`
+ $update_encoding (`` `{` $update_encoding_dims^ `}`)?
+ `in`
+ type($update) `` `{` $update_size `}`
+ `->`
+ $target_encoding (`` `{` $target_encoding_dims^ `}`)?
+ `in`
+ custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) {
+ return idx == 0 ? target_encoding_dims() : update_encoding_dims();
+ }
+ ValueRange getResultDynamicDims(unsigned idx) { return target_encoding_dims(); }
+ Value getOperandSize(unsigned idx) {
+ return idx == 0 ? target_size() : update_size();
+ }
+ Value getResultSize(unsigned idx) { return target_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_TensorLoadOp : Stream_PureOp<"tensor.load", [
+ AttrSizedOperandSegments,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{loads a value from a tensor element}];
+ let description = [{
+ Returns the element at the given location from within the tensor.
+ }];
+
+ let arguments = (ins
+ Stream_StagingResource:$source,
+ TypeAttr:$source_encoding,
+ Stream_ShapeDynamicDims:$source_encoding_dims,
+ Stream_Size:$source_size,
+ Variadic<Stream_Dim>:$indices
+ );
+ let results = (outs
+ AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$result
+ );
+
+ let assemblyFormat = [{
+ $source (`[` $indices^ `]`)? `:`
+ $source_encoding (`` `{` $source_encoding_dims^ `}`)?
+ `in`
+ type($source) `` `{` $source_size `}`
+ `->`
+ type($result)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return source_encoding_dims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return ValueRange{}; }
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_TensorStoreOp : Stream_Op<"tensor.store", [
+ AttrSizedOperandSegments,
+ AllTypesMatch<["target", "result"]>,
+ Stream_TensorPhaseOp,
+ Util_ShapeAwareOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{stores a value into a tensor element}];
+ let description = [{
+ Returns a tensor with the element at the given index set to the given value.
+ }];
+
+ let arguments = (ins
+ Stream_StagingResource:$target,
+ TypeAttr:$target_encoding,
+ Stream_ShapeDynamicDims:$target_encoding_dims,
+ Stream_Size:$target_size,
+ Variadic<Stream_Dim>:$indices,
+ AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$value,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands
+ );
+ let results = (outs
+ Stream_StagingResource:$result
+ );
+
+ let assemblyFormat = [{
+ $value `,`
+ $target (`[` $indices^ `]`)? `:`
+ type($value)
+ `->`
+ $target_encoding (`` `{` $target_encoding_dims^ `}`)?
+ `in`
+ custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return target_encoding_dims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return target_encoding_dims(); }
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return target_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Resource transfer ops
+//===----------------------------------------------------------------------===//
+
+def Stream_AsyncAllocaOp : Stream_PureOp<"async.alloca", [
+ DeclareOpInterfaceMethods<Stream_AffinityOp, [
+ "getAffinity",
+ "setAffinity",
+ ]>,
+ Stream_AsyncPhaseOp,
+ DeclareOpInterfaceMethods<Stream_StreamableOp, [
+ "isMetadata",
+ "preferCloneToConsumers",
+ ]>,
+ Util_SizeAwareOp,
+ MemoryEffects<[MemAlloc]>,
+]> {
+ let summary = [{allocates a transient value with undefined contents}];
+ let description = [{
+ Allocates a transient value (one that is short-lived and local to the
+ current computation) with undefined contents. Consumers of the allocated
+ result must assume nothing of the contents and use `discard` access.
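+
+    For example (illustrative):
+
+    ```mlir
+    %0 = stream.async.alloca : !stream.resource<transient>{%size}
+    ```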
+ }];
+
+ let arguments = (ins
+ Stream_Size:$storage_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_TransientResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ attr-dict `:` type($result) `{` $storage_size `}`
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return storage_size(); }
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_AsyncConstantOp : Stream_PureOp<"async.constant", [
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ DeclareOpInterfaceMethods<Stream_StreamableOp, [
+ "isMetadata",
+ ]>,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{defines a constant resource}];
+ let description = [{
+ Returns a new resource with the given constant value.
+ }];
+
+ let arguments = (ins
+ ElementsAttr:$value,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ `:`
+ type($result) `` `{` $result_size `}`
+ `=`
+ $value
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_AsyncSplatOp : Stream_Op<"async.splat", [
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ DeclareOpInterfaceMethods<Stream_StreamableOp, [
+ "preferCloneToConsumers",
+ ]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{splats a value into a resource}];
+ let description = [{
+ Returns a new resource with the given primitive value splatted out to fill
+ the entire contents.
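+
+    For example, splatting an i32 pattern (illustrative):
+
+    ```mlir
+    %0 = stream.async.splat %value : i32 -> !stream.resource<*>{%size}
+    ```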
+ }];
+
+ let arguments = (ins
+ Stream_PrimitiveType:$value,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $value `:` type($value) `->` type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return {}; }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_AsyncCloneOp : Stream_Op<"async.clone", [
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ DeclareOpInterfaceMethods<Stream_StreamableOp, [
+ "preferCloneToConsumers",
+ ]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{clones the contents of a value}];
+ let description = [{
+ Clones the contents of a value at a snapshot in time. Future changes to the
+    cloned value will not affect the result. Acts as a copy-on-write operation.
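+
+    For example (illustrative; sizes are placeholders):
+
+    ```mlir
+    %1 = stream.async.clone %0 : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
+    ```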
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$source,
+ Stream_Size:$source_size,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `:`
+ type($source) `` `{` $source_size `}` `->`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_AsyncSliceOp : Stream_PureOp<"async.slice", [
+ AllTypesMatch<["source", "result"]>,
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ DeclareOpInterfaceMethods<Stream_StreamableOp, [
+ "isMetadata",
+ ]>,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{slices out a cloned subview of a value}];
+ let description = [{
+ Slices a subrange of a stream resource based on a byte range. Acts as a
+ copy-on-write operation.
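+
+    For example, slicing the byte range [`%offset`, `%end`) out of a resource
+    (illustrative):
+
+    ```mlir
+    %1 = stream.async.slice %0[%offset to %end] :
+        !stream.resource<*>{%size} -> !stream.resource<*>{%length}
+    ```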
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$source,
+ Stream_Size:$source_size,
+ Stream_Offset:$source_offset,
+ Stream_Offset:$source_end,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `[` $source_offset `to` $source_end `]` `:`
+ type($source) `` `{` $source_size `}` `->`
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_AsyncFillOp : Stream_Op<"async.fill", [
+ AllTypesMatch<["target", "result"]>,
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ Stream_StreamableOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{fills a subview of a stream resource with a value}];
+ let description = [{
+ Splats a value into a subview of the given stream resource and returns the
+ resource with the update applied.
+
+ Equivalent to a stream.async.splat + stream.async.update.
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Offset:$target_end,
+ Stream_Size:$target_length,
+ Stream_PrimitiveType:$value,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $value `,`
+ $target `[` $target_offset `to` $target_end `for` $target_length `]` `:`
+ type($value) `->`
+ custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return target_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_AsyncUpdateOp : Stream_Op<"async.update", [
+ AllTypesMatch<["target", "result"]>,
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ DeclareOpInterfaceMethods<Stream_StreamableOp, [
+ "isMetadata",
+ ]>,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{updates a slice of a subview of a resource in-place}];
+ let description = [{
+ Copies a value into a resource based on a byte range. The returned value
+    is the entire updated target value. Updates can be turned into placement
+    allocations to avoid copies.
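+
+    Example (illustrative; SSA names and sizes are placeholders):
+
+    ```mlir
+    %1 = stream.async.update %update, %target[%offset to %end] : !stream.resource<*>{%update_size} -> %target as !stream.resource<*>{%target_size}
+    ```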
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Offset:$target_end,
+ Stream_AnyStreamResource:$update,
+ Stream_Size:$update_size,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $update `,`
+ $target `[` $target_offset `to` $target_end `]` `:`
+ type($update) `` `{` $update_size `}` `->`
+ custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return idx == 0 ? target_size() : update_size();
+ }
+ Value getResultSize(unsigned idx) { return target_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_AsyncCopyOp : Stream_Op<"async.copy", [
+ AllTypesMatch<["target", "result"]>,
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ Stream_StreamableOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{copies a subview of a stream resource to another}];
+ let description = [{
+ Copies a subview of a resource into a subview of another.
+    As with memcpy, this does not support overlapping updates into the same
+ resource. Unlike `stream.async.update` copy sources cannot be allocated
+ in-place.
+
+ Equivalent to a stream.async.slice + stream.async.update.
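+
+    Example (illustrative; SSA names and sizes are placeholders):
+
+    ```mlir
+    %1 = stream.async.copy %src[%offset to %end], %dst[%offset to %end], %length : !stream.resource<*>{%src_size} -> %dst as !stream.resource<*>{%dst_size}
+    ```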
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Offset:$target_end,
+ Stream_AnyStreamResource:$source,
+ Stream_Size:$source_size,
+ Stream_Offset:$source_offset,
+ Stream_Offset:$source_end,
+ Stream_Size:$length,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_AnyStreamResource:$result
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $source `[` $source_offset `to` $source_end `]` `,`
+ $target `[` $target_offset `to` $target_end `]` `,`
+ $length `:`
+ type($source) `` `{` $source_size `}` `->`
+ custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return idx == 0 ? target_size() : source_size();
+ }
+ Value getResultSize(unsigned idx) { return target_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ Stream_StreamableOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{transfers a resource from one location/state to another}];
+ let description = [{
+ Transfers a resource between different states (such as a `staging` lifetime
+ to a `local` lifetime) or different affinities. This is roughly equivalent
+ to a cast but may have special semantics when later lowered to one or more
+ devices with discrete memory spaces or pools.
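+
+    Example (illustrative; SSA names and sizes are placeholders):
+
+    ```mlir
+    %staged = stream.async.transfer %0 : !stream.resource<transient>{%size} -> !stream.resource<staging>{%size}
+    ```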
+ }];
+
+ let arguments = (ins
+ AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>:$source,
+ Stream_Size:$source_size,
+ Stream_Size:$result_size,
+ OptionalAttr<Stream_AffinityAttr>:$source_affinity,
+ OptionalAttr<Stream_AffinityAttr>:$result_affinity
+ );
+ let results = (outs
+ AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>:$result
+ );
+
+ let assemblyFormat = [{
+ (`from` `(` $source_affinity^ `)`)?
+ $source `:`
+ type($source) `` `{` $source_size `}` `->`
+ (`to` `(` $result_affinity^ `)`)?
+ type($result) `` `{` $result_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return result_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_AsyncDispatchOp : Stream_Op<"async.dispatch", [
+ AttrSizedOperandSegments,
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ Stream_StreamableOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedOperandsIndexAndLength",
+ ]>,
+]> {
+ let summary = [{dispatches a parallelized grid of work}];
+ let description = [{
+    Calls the specified entry point function once for each workgroup in the
+    specified workgroup count grid. Each workgroup has access to the same
+    operands and results and is able to load/store at will.
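+
+    Example (illustrative; the executable symbol, SSA names, and sizes are
+    placeholders):
+
+    ```mlir
+    %0 = stream.async.dispatch @ex::@entry[%x, %y, %z](%input) : (!stream.resource<*>{%size}) -> !stream.resource<*>{%size}
+    ```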
+ }];
+
+ let arguments = (ins
+ Variadic<Index>:$workgroup_count,
+ SymbolRefAttr:$entry_point,
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_PrimitiveType,
+ ]>>:$operands,
+ Variadic<Stream_Size>:$operand_sizes,
+ Variadic<Stream_Size>:$result_sizes,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<Stream_AnyStreamResource>:$results
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $entry_point `[` $workgroup_count `]` ``
+ `(` $operands `)` attr-dict `:`
+ custom<ShapedFunctionType>(ref($operands),
+ type($operands), $operand_sizes,
+ type($results), $result_sizes,
+ $tied_operands)
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return findValueSizeInList(idx, getOperands(), operand_sizes());
+ }
+ Value getResultSize(unsigned idx) {
+ return findValueSizeInList(idx, getResults(), result_sizes());
+ }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Async control flow ops
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): stream.async.if
+// TODO(benvanik): stream.async.select
+// TODO(benvanik): stream.async.for
+
+def Stream_AsyncExecuteOp : Stream_Op<"async.execute", [
+ AttrSizedOperandSegments,
+ RecursiveSideEffects,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+ "getSuccessorEntryOperands",
+ ]>,
+ SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResultsIndexAndLength",
+ ]>,
+]> {
+ let summary = [{executes a dependency-aware sequence of streamable ops}];
+ let description = [{
+ Evaluates the operations within the region by dependency order while obeying
+ ties when present. Nested ops execute serially in block order and nested
+ `stream.async.concurrent` ops can be used to run multiple ops concurrently
+ within the stream. All resource inputs must be captured explicitly. All
+ results are only ready once all nested ops complete execution and the
+    returned timepoint is reached. An optional await timepoint may be provided
+    to block execution until it is reached; when omitted, execution may begin
+    immediately.
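+
+    Example (a sketch; SSA names, sizes, and the region body are placeholders):
+
+    ```mlir
+    %result, %done = stream.async.execute await(%await_tp) =>
+        with(%input as %capture: !stream.resource<*>{%size})
+        -> %input as !stream.resource<*>{%size} {
+      %0 = ...
+      stream.yield %0 : !stream.resource<*>{%size}
+    } => !stream.timepoint
+    ```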
+ }];
+
+ let arguments = (ins
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$operands,
+ Variadic<Stream_Size>:$operand_sizes,
+ Variadic<Stream_Size>:$result_sizes,
+ Optional<Stream_Timepoint>:$await_timepoint,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$results,
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let regions = (region AnyRegion:$body);
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`await` `(` $await_timepoint^ `)` `=` `` `>`)?
+ `with` ``
+ custom<ResourceRegion>($operands, type($operands), $operand_sizes,
+ type($results), $result_sizes,
+ $tied_operands, $body)
+ `=` `` `>` type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ "TypeRange":$resultTypes, "ValueRange":$resultSizes,
+ "Value":$awaitTimepoint,
+ "ValueRange":$operands, "ValueRange":$operandSizes,
+ "ArrayRef<int64_t>":$tiedOperands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ ];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return findValueSizeInList(idx, getOperands(), operand_sizes());
+ }
+ Value getResultSize(unsigned idx) {
+ return findValueSizeInList(idx, getResults(), result_sizes());
+ }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_AsyncConcurrentOp : Stream_Op<"async.concurrent", [
+ ParentOneOf<[
+ "IREE::Stream::AsyncExecuteOp",
+ "IREE::Stream::AsyncConcurrentOp",
+ ]>,
+ AttrSizedOperandSegments,
+ RecursiveSideEffects,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+ "getSuccessorEntryOperands",
+ ]>,
+ SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
+ Stream_AffinityOp,
+ Stream_AsyncPhaseOp,
+ Stream_StreamableOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface>,
+]> {
+ let summary = [{executes all ops concurrently}];
+ let description = [{
+ Represents a wave of work scheduled concurrently (each op executing at the
+ same time). All resource inputs must be captured explicitly. All results are
+ only ready once all nested ops complete execution.
+
+ Waves can be nested to create a DAG. For example, take the following graph:
+ ```
+ |
+ v---------+---------v
+ +-------|-------+ +-------|-------+
+ | v--+--v | | v--+--v |
+ | +----+ +----+ | | +----+ +----+ |
+ | | %a | | %b | | | | %c | | %d | |
+ | +----+ +----+ | | +----+ +----+ |
+ | +--v--+ | | +--v--+ |
+ +-------|-------+ +-------|-------+
+ +---------v---------+
+ |
+ ```
+
+ Represented with nested waves:
+ ```mlir
+ %0 = stream.async.concurrent with(%arg) -> ... {
+ %1 = stream.async.concurrent with(%arg as %arg0) -> ... {
+ %a = ...
+ %b = ...
+ stream.yield %a, %b
+ }
+ %2 = stream.async.concurrent with(%arg as %arg1) -> ... {
+ %c = ...
+ %d = ...
+ stream.yield %c, %d
+ }
+ stream.yield %1, %2
+ }
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$operands,
+ Variadic<Stream_Size>:$operand_sizes,
+ Variadic<Stream_Size>:$result_sizes,
+ OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$results
+ );
+
+ let regions = (region AnyRegion:$body);
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ `with` ``
+ custom<ResourceRegion>($operands, type($operands), $operand_sizes,
+ type($results), $result_sizes,
+ $tied_operands, $body)
+ attr-dict-with-keyword
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ "TypeRange":$resultTypes, "ValueRange":$resultSizes,
+ "ValueRange":$operands, "ValueRange":$operandSizes,
+ "ArrayRef<int64_t>":$tiedOperands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ ];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return findValueSizeInList(idx, getOperands(), operand_sizes());
+ }
+ Value getResultSize(unsigned idx) {
+ return findValueSizeInList(idx, getResults(), result_sizes());
+ }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Explicit command ops
+//===----------------------------------------------------------------------===//
+
+def Stream_CmdFlushOp : Stream_Op<"cmd.flush", [
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+ Stream_SubviewEffectOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{flushes a subview of a resource}];
+ let description = [{
+ Transfers a resource to an external target. The resource memory is made
+ available to the target and can be made visible there using
+ `stream.cmd.invalidate`.
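+
+    Example (a sketch derived from the assembly format; SSA names and sizes are
+    placeholders):
+
+    ```mlir
+    stream.cmd.flush %target[%offset for %length] : !stream.resource<external>{%size}
+    ```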
+ }];
+
+ let arguments = (ins
+ Stream_AnyResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Size:$target_length,
+ OptionalAttr<Stream_AffinityAttr>:$source_affinity
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ (`to` `(` $source_affinity^ `)`)?
+ $target `[` $target_offset `for` $target_length `]` `:`
+ type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_CmdInvalidateOp : Stream_Op<"cmd.invalidate", [
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+ Stream_SubviewEffectOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{invalidates a subview of a resource}];
+ let description = [{
+ Transfers a resource from an external source into the current target. The
+ resource memory is assumed to have been made available at the source via
+ `stream.cmd.flush`.
+ }];
+
+ let arguments = (ins
+ Stream_AnyResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Size:$target_length,
+ OptionalAttr<Stream_AffinityAttr>:$source_affinity
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ (`from` `(` $source_affinity^ `)`)?
+ $target `[` $target_offset `for` $target_length `]` `:`
+ type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_CmdDiscardOp : Stream_Op<"cmd.discard", [
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+ Stream_SubviewEffectOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{discards a subview of a resource}];
+ let description = [{
+ Discards a subview of a resource, indicating that after this command the
+ specified contents are no longer needed. This can be used to trim memory
+ or invalidate caches.
+ }];
+
+ let arguments = (ins
+ Stream_AnyResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Size:$target_length
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $target `[` $target_offset `for` $target_length `]` `:`
+ type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_CmdFillOp : Stream_Op<"cmd.fill", [
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+ Stream_SubviewEffectOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{fills a subview of a stream resource with a value}];
+ let description = [{
+ Splats a value into a subview of the given stream resource and returns the
+ resource with the update applied.
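+
+    Example (a sketch derived from the assembly format; SSA names and sizes are
+    placeholders):
+
+    ```mlir
+    stream.cmd.fill %cst, %target[%offset for %length] : i32 -> !stream.resource<transient>{%size}
+    ```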
+ }];
+
+ let arguments = (ins
+ Stream_AnyStreamResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Size:$target_length,
+ Stream_PrimitiveType:$value
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $value `,`
+ $target `[` $target_offset `for` $target_length `]` `:`
+ type($value) `->`
+ type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_CmdCopyOp : Stream_Op<"cmd.copy", [
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+ Stream_SubviewEffectOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{copies a subview of a stream resource to another}];
+ let description = [{
+ Copies a subview of a resource into a subview of another.
+    As with memcpy, this does not support overlapping updates into the same
+ resource.
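+
+    Example (a sketch derived from the assembly format; SSA names and sizes are
+    placeholders):
+
+    ```mlir
+    stream.cmd.copy %src[%src_offset], %dst[%dst_offset], %length : !stream.resource<transient>{%src_size} -> !stream.resource<transient>{%dst_size}
+    ```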
+ }];
+
+ let arguments = (ins
+ Stream_AnyResource:$source,
+ Stream_Size:$source_size,
+ Stream_Offset:$source_offset,
+ Stream_AnyResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ Stream_Size:$length
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]` `,`
+ $target `[` $target_offset `]` `,`
+ $length `:`
+ type($source) `` `{` $source_size `}` `->`
+ type($target) `` `{` $target_size `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return idx == 0 ? source_size() : target_size();
+ }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_CmdDispatchOp : Stream_Op<"cmd.dispatch", [
+ AttrSizedOperandSegments,
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+ Stream_SubviewEffectOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{dispatches a parallelized grid of work}];
+ let description = [{
+    Calls the specified entry point function once for each workgroup in the
+    specified workgroup count grid. Each workgroup has access to the same
+    operands and results and is able to load/store at will.
+ }];
+
+ let arguments = (ins
+ Variadic<Index>:$workgroup_count,
+ SymbolRefAttr:$entry_point,
+ Variadic<Stream_PrimitiveType>:$operands,
+ Variadic<Stream_AnyStreamResource>:$resources,
+ Variadic<Stream_Size>:$resource_sizes,
+ Variadic<Stream_Offset>:$resource_offsets,
+ Variadic<Stream_Size>:$resource_lengths,
+ Stream_ResourceAccessArrayAttr:$resource_accesses
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $entry_point `[` $workgroup_count `]` ``
+ (`(` $operands^ `:` type($operands) `)`)? `{`
+ custom<DispatchResources>($resources, type($resources), $resource_sizes,
+ $resource_offsets, $resource_lengths,
+ $resource_accesses)
+ `\n` `}`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return findValueSizeInList(
+ idx - getODSOperandIndexAndLength(2).first,
+ resources(), resource_sizes());
+ }
+ Value getResultSize(unsigned idx) { return {}; }
+
+ // Builds a map of operand index to argument index.
+ static SmallVector<unsigned> makeOperandToArgMap(mlir::FuncOp funcOp);
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_CmdExecuteOp : Stream_Op<"cmd.execute", [
+ AttrSizedOperandSegments,
+ RecursiveSideEffects,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+ "getSuccessorEntryOperands",
+ ]>,
+ SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
+ Stream_AffinityOp,
+ Stream_CmdPhaseOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
+]> {
+ let summary = [{executes a dependency-aware sequence of streamable ops}];
+ let description = [{
+ Evaluates the operations within the region by dependency order while obeying
+ ties when present. Nested ops execute serially in block order and nested
+ `stream.cmd.concurrent` ops can be used to run multiple ops concurrently
+ within the stream. All resource inputs must be captured explicitly. All
+ results are only ready once all nested ops complete execution and the
+    returned timepoint is reached. An optional await timepoint may be provided
+    to block execution until it is reached; when omitted, execution may begin
+    immediately.
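+
+    A sketch of the expected form (SSA names, sizes, and the region body are
+    placeholders):
+
+    ```mlir
+    %done = stream.cmd.execute await(%await_tp) =>
+        with(%input as %capture: !stream.resource<transient>{%size}) {
+      stream.cmd.fill %cst, %capture[%offset for %length] : i32 -> !stream.resource<transient>{%size}
+    } => !stream.timepoint
+    ```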
+ }];
+
+ let arguments = (ins
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$operands,
+ Variadic<Stream_Size>:$operand_sizes,
+ Optional<Stream_Timepoint>:$await_timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let regions = (region AnyRegion:$body);
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ (`await` `(` $await_timepoint^ `)` `=` `` `>`)?
+ `with` ``
+ custom<ExplicitResourceRegion>($operands, type($operands), $operand_sizes,
+ $body)
+ `=` `` `>` type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ "Value":$awaitTimepoint,
+ "ValueRange":$operands, "ValueRange":$operandSizes,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ ];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return findValueSizeInList(idx, getOperands(), operand_sizes());
+ }
+ Value getResultSize(unsigned idx) {
+ return {};
+ }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+// TODO(benvanik): stream.cmd.serial for serialized execution.
+def Stream_CmdSerialOp : Stream_Op<"cmd.serial", [
+ ParentOneOf<[
+ "IREE::Stream::CmdExecuteOp",
+ "IREE::Stream::CmdSerialOp",
+ "IREE::Stream::CmdConcurrentOp",
+ ]>,
+ RecursiveSideEffects,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+]> {
+ let summary = [{executes all ops serially (in-order)}];
+ let description = [{
+ Represents a sequence of work scheduled serially (each op executing one
+ after the other).
+
+ Regions can be nested to create a DAG. For example, take the following graph:
+ ```
+ |
+ v---------+-----v
+ +-------|-------+ +---|----+
+ | v--+--v | | v |
+ | +----+ +----+ | | +----+ |
+ | | @a | | @b | | | | @c | |
+ | +----+ +----+ | | +----+ |
+ | | | | | | |
+ | | | | | +-v--+ |
+ | | | | | | @d | |
+ | | | | | +----+ |
+ | +--v--+ | | | |
+ +-------|-------+ +---|----+
+ +---------v-----+
+ |
+ ```
+
+ Represented with nested regions:
+ ```mlir
+ stream.cmd.concurrent {
+ stream.cmd.concurrent {
+ stream.cmd.dispatch @a
+ stream.cmd.dispatch @b
+ }
+ stream.cmd.serial {
+ stream.cmd.dispatch @c
+ stream.cmd.dispatch @d
+ }
+ }
+ ```
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+
+ let regions = (region AnyRegion:$body);
+
+ let assemblyFormat = [{
+ `` $body
+ attr-dict-with-keyword
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_CmdConcurrentOp : Stream_Op<"cmd.concurrent", [
+ ParentOneOf<[
+ "IREE::Stream::CmdExecuteOp",
+ "IREE::Stream::CmdSerialOp",
+ "IREE::Stream::CmdConcurrentOp",
+ ]>,
+ RecursiveSideEffects,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
+ Stream_CmdPhaseOp,
+ Stream_StreamableOp,
+]> {
+ let summary = [{executes all ops concurrently}];
+ let description = [{
+ Represents a wave of work scheduled concurrently (each op executing at the
+ same time).
+
+ Waves can be nested to create a DAG. For example, take the following graph:
+ ```
+ |
+ v---------+---------v
+ +-------|-------+ +-------|-------+
+ | v--+--v | | v--+--v |
+ | +----+ +----+ | | +----+ +----+ |
+ | | @a | | @b | | | | @c | | @d | |
+ | +----+ +----+ | | +----+ +----+ |
+ | +--v--+ | | +--v--+ |
+ +-------|-------+ +-------|-------+
+ +---------v---------+
+ |
+ ```
+
+ Represented with nested waves:
+ ```mlir
+ stream.cmd.concurrent {
+ stream.cmd.concurrent {
+ stream.cmd.dispatch @a
+ stream.cmd.dispatch @b
+ }
+ stream.cmd.concurrent {
+ stream.cmd.dispatch @c
+ stream.cmd.dispatch @d
+ }
+ }
+ ```
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+
+ let regions = (region AnyRegion:$body);
+
+ let assemblyFormat = [{
+ `` $body
+ attr-dict-with-keyword
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Synchronization ops
+//===----------------------------------------------------------------------===//
+
+def Stream_TimepointImmediateOp : Stream_PureOp<"timepoint.immediate", [
+ ConstantLike,
+]> {
+  let summary = [{returns an immediately-available timepoint}];
+ let description = [{
+ Timepoints indicate a point in the execution timeline and this op can be
+ used to get a placeholder representing the start of the timeline. Any waits
+    on the returned timepoint will resolve immediately. This generally folds
+    away but can be useful when initializing globals or branch arguments.
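+
+    Example (illustrative):
+
+    ```mlir
+    %0 = stream.timepoint.immediate => !stream.timepoint
+    ```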
+ }];
+
+ let arguments = (ins);
+ let results = (outs
+ Stream_Timepoint:$timepoint
+ );
+
+ let assemblyFormat = [{
+ attr-dict
+ `=` `` `>` type($timepoint)
+ }];
+
+ let hasFolder = 1;
+}
+
+def Stream_TimepointJoinOp : Stream_PureOp<"timepoint.join", []> {
+ let summary = [{joins one or more timepoints into the max of all of them}];
+ let description = [{
+ Returns a timepoint that indicates that all of the input timepoints have
+ been reached.
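+
+    Example (illustrative; SSA names are placeholders):
+
+    ```mlir
+    %0 = stream.timepoint.join max(%tp0, %tp1) => !stream.timepoint
+    ```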
+ }];
+
+ let arguments = (ins
+ Variadic<Stream_Timepoint>:$timepoints
+ );
+ let results = (outs
+ Stream_Timepoint:$result
+ );
+
+ let assemblyFormat = [{
+ `max` `(` $timepoints `)` `=` `` `>` type($result)
+ attr-dict-with-keyword
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [
+ AttrSizedOperandSegments,
+ Stream_AffinityOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{awaits a timepoint before returning a set of resources}];
+ let description = [{
+    After asynchronous execution scheduling, resources may exist in different
+    states at different points in the execution timeline. This op enables
+    resolving the version of a resource after a particular point in the
+    timeline. Because timepoints chain transitively, the awaited timepoint only
+    needs to cover the availability of the resource and is not limited to its
+    original production timepoint.
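+
+    Example (illustrative; SSA names and sizes are placeholders):
+
+    ```mlir
+    %ready = stream.timepoint.await %tp => %resource : !stream.resource<*>{%size}
+    ```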
+ }];
+
+ let arguments = (ins
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$operands,
+ Variadic<Stream_Size>:$operand_sizes,
+ Stream_Timepoint:$timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$results
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $timepoint `=` `` `>`
+ $operands `:`
+ custom<SizeAwareTypeList>(type($operands), type($results), $operand_sizes)
+ attr-dict-with-keyword
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ "ValueRange":$operands, "ValueRange":$operandSizes,
+ "Value":$timepoint,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ ];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return operand_sizes()[idx];
+ }
+ Value getResultSize(unsigned idx) {
+ return operand_sizes()[idx];
+ }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Executables
+//===----------------------------------------------------------------------===//
+
+def Stream_ExecutableOp : Stream_Op<"executable", [
+ IsolatedFromAbove,
+ SingleBlockImplicitTerminator<"IREE::Stream::ExecutableEndOp">,
+ NativeOpTrait<"SymbolTable">,
+ Symbol,
+]> {
+ let summary = [{generic executable module}];
+ let description = [{
+ An executable module containing one or more public functions. The contents
+ of the functions are safe to dispatch and can be lowered further to
+ target-specific backend IR representations.
+ }];
+
+ let arguments = (ins
+ OptionalAttr<StrAttr>:$sym_visibility,
+ SymbolNameAttr:$sym_name
+ );
+
+ let regions = (region SizedRegion<1>:$body);
+
+ let assemblyFormat = [{
+ custom<SymbolVisibility>($sym_visibility)
+ $sym_name
+ attr-dict-with-keyword
+ ``
+ regions
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "StringRef":$sym_name)>,
+ ];
+
+ let extraClassDeclaration = [{
+ Block& getBlock() { return body().front(); }
+ ::mlir::ModuleOp getInnerModule() {
+ return *getBlock().getOps<::mlir::ModuleOp>().begin();
+ }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+}
+
+def Stream_ExecutableEndOp : Stream_Op<"executable.end", [
+ HasParent<"IREE::Stream::ExecutableOp">,
+ Terminator,
+]> {
+ let summary = [{terminator pseudo-op for the executable op}];
+ let assemblyFormat = "attr-dict";
+}
+
+def Stream_ExecutableExportOp : Stream_Op<"executable.export", [
+ HasParent<"IREE::Stream::ExecutableOp">,
+ Symbol,
+]> {
+ let summary = [{defines an executable entry point for dispatch operations}];
+ let description = [{
+ Specifies an exported function with an externally-visible alias. Multiple
+ exports can reference the same internal function.
+ }];
+
+ let arguments = (ins
+ OptionalAttr<StrAttr>:$sym_visibility,
+ SymbolNameAttr:$sym_name,
+ FlatSymbolRefAttr:$function_ref
+ );
+
+ let assemblyFormat = [{
+ custom<SymbolVisibility>($sym_visibility)
+ custom<SymbolAlias>($sym_name, $function_ref)
+ attr-dict-with-keyword
+ }];
+
+ let builders = [
+ OpBuilder<(ins
+ "StringRef":$sym_name,
+ "FlatSymbolRefAttr":$function_ref
+ )>,
+ ];
+}
+
+def Stream_BindingSubspanOp : Stream_PureOp<"binding.subspan", [
+ DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
+ Util_ShapeAwareOp,
+ ]> {
+ let summary = [{returns an alias to a subspan of interface binding data}];
+ let description = [{
+    Returns a subview of a tensor or memref-like type from a binding. The same
+ binding may have multiple subviews at different byte offsets.
+ }];
+
+ let arguments = (ins
+ Stream_AnyBinding:$binding,
+ Stream_Offset:$byte_offset,
+ Stream_ShapeDynamicDims:$dynamic_dims
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ $binding `` `[` $byte_offset `]`
+ attr-dict `:` type($binding) `->` type($result) (`{` $dynamic_dims^ `}`)?
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return ValueRange{}; }
+ ValueRange getResultDynamicDims(unsigned idx) { return dynamic_dims(); }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Misc
+//===----------------------------------------------------------------------===//
+
+def Stream_YieldOp : Stream_Op<"yield", [
+ ParentOneOf<[
+ "IREE::Stream::AsyncExecuteOp",
+ "IREE::Stream::AsyncConcurrentOp",
+ "IREE::Stream::CmdExecuteOp",
+ "IREE::Stream::CmdSerialOp",
+ "IREE::Stream::CmdConcurrentOp",
+ ]>,
+ NoSideEffect,
+ DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface, [
+ "getMutableSuccessorOperands",
+ ]>,
+ Terminator,
+ SameVariadicOperandSize,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{yields stream values from an execution region}];
+ let description = [{
+    The values returned represent the asynchronous values at the point in time
+    their SSA values are defined (or tied).
+ }];
+
+ let arguments = (ins
+ Variadic<AnyTypeOf<[
+ Stream_AnyStreamResource,
+ Stream_StagingResource,
+ ]>>:$operands,
+ Variadic<Index>:$operand_sizes
+ );
+
+ let assemblyFormat = [{
+ attr-dict
+ ($operands^ `:` custom<SizeAwareTypeList>(type($operands), $operand_sizes))?
+ }];
+
+ let builders = [
+ OpBuilder<(ins),
+ [{
+ build($_builder, $_state, ValueRange{}, ValueRange{});
+ }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) {
+ return findValueSizeInList(idx, getOperands(), operand_sizes());
+ }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+}
+
+#endif // IREE_DIALECT_STREAM_OPS
diff --git a/iree/compiler/Dialect/Stream/IR/StreamTraits.h b/iree/compiler/Dialect/Stream/IR/StreamTraits.h
new file mode 100644
index 0000000..cc79a6d
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamTraits.h
@@ -0,0 +1,40 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMTRAITS_H_
+#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace IREE {
+namespace Stream {
+
+template <typename ConcreteType>
+class TensorPhaseOp : public OpTrait::TraitBase<ConcreteType, TensorPhaseOp> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) { return success(); }
+};
+
+template <typename ConcreteType>
+class AsyncPhaseOp : public OpTrait::TraitBase<ConcreteType, AsyncPhaseOp> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) { return success(); }
+};
+
+template <typename ConcreteType>
+class CmdPhaseOp : public OpTrait::TraitBase<ConcreteType, CmdPhaseOp> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) { return success(); }
+};
+
+} // namespace Stream
+} // namespace IREE
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAMTRAITS_H_
diff --git a/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
new file mode 100644
index 0000000..011ad85
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -0,0 +1,351 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/DialectImplementation.h"
+
+// clang-format off: must be included after all LLVM/MLIR headers.
+#define GET_ATTRDEF_CLASSES
+#include "iree/compiler/Dialect/Stream/IR/StreamAttrs.cpp.inc" // IWYU pragma: keep
+#include "iree/compiler/Dialect/Stream/IR/StreamEnums.cpp.inc" // IWYU pragma: keep
+#define GET_TYPEDEF_CLASSES
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.cpp.inc" // IWYU pragma: keep
+// clang-format on
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+//===----------------------------------------------------------------------===//
+// #stream.resource_config<...>
+//===----------------------------------------------------------------------===//
+
+// static
+Attribute ResourceConfigAttr::parse(mlir::DialectAsmParser &p,
+ mlir::Type type) {
+ if (failed(p.parseLBrace())) return {};
+
+ int64_t maxAllocationSize = 0;
+ int64_t minBufferOffsetAlignment = 0;
+ int64_t maxBufferRange = 0;
+ int64_t minBufferRangeAlignment = 0;
+ while (failed(p.parseOptionalRBrace())) {
+ StringRef key;
+ int64_t value = 0;
+ if (failed(p.parseKeyword(&key)) || failed(p.parseEqual()) ||
+ failed(p.parseInteger(value))) {
+ return {};
+ }
+ if (key == "max_allocation_size") {
+ maxAllocationSize = value;
+ } else if (key == "min_buffer_offset_alignment") {
+ minBufferOffsetAlignment = value;
+ } else if (key == "max_buffer_range") {
+ maxBufferRange = value;
+ } else if (key == "min_buffer_range_alignment") {
+ minBufferRangeAlignment = value;
+ }
+ (void)p.parseOptionalComma();
+ }
+
+ return ResourceConfigAttr::get(p.getContext(), maxAllocationSize,
+ minBufferOffsetAlignment, maxBufferRange,
+ minBufferRangeAlignment);
+}
+
+void ResourceConfigAttr::print(mlir::DialectAsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << getMnemonic() << "{";
+ os << "max_allocation_size = " << getMaxAllocationSize() << ", ";
+ os << "min_buffer_offset_alignment = " << getMinBufferOffsetAlignment()
+ << ", ";
+ os << "max_buffer_range = " << getMaxBufferRange() << ", ";
+ os << "min_buffer_range_alignment = " << getMinBufferRangeAlignment();
+ os << "}";
+}
+
+// static
+ResourceConfigAttr ResourceConfigAttr::intersectBufferConstraints(
+ ResourceConfigAttr lhs, ResourceConfigAttr rhs) {
+ if (!lhs) return rhs;
+ if (!rhs) return lhs;
+ Builder b(lhs.getContext());
+ return ResourceConfigAttr::get(
+ b.getContext(),
+ std::min(lhs.getMaxAllocationSize(), rhs.getMaxAllocationSize()),
+ std::max(lhs.getMinBufferOffsetAlignment(),
+ rhs.getMinBufferOffsetAlignment()),
+ std::min(lhs.getMaxBufferRange(), rhs.getMaxBufferRange()),
+ std::max(lhs.getMinBufferRangeAlignment(),
+ rhs.getMinBufferRangeAlignment()));
+}
+
+// static
+ResourceConfigAttr ResourceConfigAttr::getDefaultHostConstraints(
+ MLIRContext *context) {
+ // Picked to represent what we kind of want on CPU today.
+ uint64_t maxAllocationSize = UINT32_MAX;
+ uint64_t minBufferOffsetAlignment = 32ull;
+ uint64_t maxBufferRange = UINT32_MAX;
+ uint64_t minBufferRangeAlignment = 32ull;
+ return ResourceConfigAttr::get(context, maxAllocationSize,
+ minBufferOffsetAlignment, maxBufferRange,
+ minBufferRangeAlignment);
+}
+
+// TODO(benvanik): find a way to go affinity -> resource config.
+// For now we just always fall back to the conservative host config.
+static ResourceConfigAttr inferResourceConfigFromAffinity(
+ AffinityAttr affinityAttr) {
+ return {};
+}
+
+// static
+ResourceConfigAttr ResourceConfigAttr::lookup(Operation *op) {
+ auto *context = op->getContext();
+ auto attrId = mlir::Identifier::get("stream.resources", context);
+ while (op) {
+ if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(op)) {
+ auto affinityAttr = affinityOp.getAffinity();
+ if (affinityAttr) {
+ auto attr = inferResourceConfigFromAffinity(affinityAttr);
+ if (attr) return attr;
+ }
+ }
+ auto attr = op->getAttrOfType<ResourceConfigAttr>(attrId);
+ if (attr) return attr;
+ op = op->getParentOp();
+ }
+ // No config found; use conservative host config.
+ return getDefaultHostConstraints(context);
+}
+
+//===----------------------------------------------------------------------===//
+// #stream.timepoint<...>
+//===----------------------------------------------------------------------===//
+
+Attribute TimepointAttr::parse(mlir::DialectAsmParser &p, mlir::Type type) {
+ StringRef timeStr;
+ if (failed(p.parseLess())) return {};
+ if (failed(p.parseKeyword(&timeStr))) {
+ return {};
+ }
+ if (failed(p.parseGreater())) return {};
+ if (timeStr != "immediate") {
+ p.emitError(p.getCurrentLocation(),
+ "only immediate timepoint attrs are supported");
+ return {};
+ }
+ return TimepointAttr::get(p.getContext(), TimepointType::get(p.getContext()));
+}
+
+void TimepointAttr::print(mlir::DialectAsmPrinter &p) const {
+ p << getMnemonic() << "<";
+ p << "immediate";
+ p << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// #stream.affinity
+//===----------------------------------------------------------------------===//
+
+AffinityAttr AffinityAttr::lookup(Operation *op) {
+ auto attrId = mlir::Identifier::get("stream.affinity", op->getContext());
+ while (op) {
+ if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(op)) {
+ auto affinity = affinityOp.getAffinity();
+ if (affinity) return affinity;
+ }
+ auto attr = op->getAttrOfType<AffinityAttr>(attrId);
+ if (attr) return attr;
+ op = op->getParentOp();
+ }
+ return {}; // No affinity found; let caller decide what to do.
+}
+
+// static
+bool AffinityAttr::areCompatible(AffinityAttr desiredAffinity,
+ AffinityAttr requiredAffinity) {
+ // We could do a fuzzier match here (interface isCompatible() etc).
+ return desiredAffinity == requiredAffinity;
+}
+
+//===----------------------------------------------------------------------===//
+// #stream.partitioning_config
+//===----------------------------------------------------------------------===//
+
+void PartitioningConfigAttr::walkImmediateSubElements(
+ function_ref<void(Attribute)> walkAttrsFn,
+ function_ref<void(Type)> walkTypesFn) const {
+ walkAttrsFn(getFavor());
+}
+
+Attribute PartitioningConfigAttr::parse(mlir::DialectAsmParser &p,
+ mlir::Type type) {
+ std::string favorStr;
+ if (failed(p.parseLess())) return {};
+ if (succeeded(p.parseOptionalStar())) {
+ favorStr = "size";
+ } else if (failed(p.parseString(&favorStr))) {
+ return {};
+ }
+ if (failed(p.parseGreater())) return {};
+ auto favor = symbolizeFavor(favorStr);
+ if (!favor.hasValue()) {
+ p.emitError(p.getNameLoc(), "unknown favor value: ") << favorStr;
+ return {};
+ }
+ return PartitioningConfigAttr::get(
+ FavorAttr::get(p.getContext(), favor.getValue()));
+}
+
+void PartitioningConfigAttr::print(mlir::DialectAsmPrinter &p) const {
+ p << getMnemonic() << "<";
+ p << "favor-";
+ p << stringifyFavor(getFavor().getValue());
+ p << ">";
+}
+
+PartitioningConfigAttr PartitioningConfigAttr::lookup(Operation *op) {
+ auto attrId = mlir::Identifier::get("stream.partitioning", op->getContext());
+ while (op) {
+ auto attr = op->getAttrOfType<PartitioningConfigAttr>(attrId);
+ if (attr) return attr;
+ op = op->getParentOp();
+ }
+ return {}; // No config found; let caller decide what to do.
+}
+
+//===----------------------------------------------------------------------===//
+// !stream.resource<lifetime>
+//===----------------------------------------------------------------------===//
+
+static llvm::Optional<Lifetime> parseLifetime(StringRef str) {
+ if (str == "*") {
+ return Lifetime::Unknown;
+ } else if (str == "external") {
+ return Lifetime::External;
+ } else if (str == "staging") {
+ return Lifetime::Staging;
+ } else if (str == "transient") {
+ return Lifetime::Transient;
+ } else if (str == "variable") {
+ return Lifetime::Variable;
+ } else if (str == "constant") {
+ return Lifetime::Constant;
+ } else {
+ return llvm::None;
+ }
+}
+
+static void printLifetime(Lifetime lifetime, llvm::raw_ostream &os) {
+ if (lifetime == Lifetime::Unknown) {
+ os << "*";
+ } else {
+ os << stringifyLifetime(lifetime).lower();
+ }
+}
+
+Type ResourceType::parse(mlir::DialectAsmParser &p) {
+ StringRef lifetimeStr;
+ if (failed(p.parseLess())) return {};
+ if (succeeded(p.parseOptionalStar())) {
+ lifetimeStr = "*";
+ } else if (failed(p.parseKeyword(&lifetimeStr))) {
+ return {};
+ }
+ if (failed(p.parseGreater())) return {};
+ auto lifetime = parseLifetime(lifetimeStr);
+ if (!lifetime.hasValue()) {
+ p.emitError(p.getNameLoc(), "unknown lifetime value: ") << lifetimeStr;
+ return {};
+ }
+ return ResourceType::get(p.getContext(), lifetime.getValue());
+}
+
+void ResourceType::print(mlir::DialectAsmPrinter &p) const {
+ p << getMnemonic() << "<";
+ printLifetime(getLifetime(), p.getStream());
+ p << ">";
+}
+
+bool ResourceType::isAccessStorageCompatible(Type accessType) const {
+ if (auto resourceType = accessType.dyn_cast<ResourceType>()) {
+ // We could allow widening loads or stores here but today we require
+ // transfers to accomplish that.
+ return accessType == resourceType;
+ }
+ return accessType.isa<ShapedType>();
+}
+
+//===----------------------------------------------------------------------===//
+// Dialect registration
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Dialect/Stream/IR/StreamOpInterfaces.cpp.inc" // IWYU pragma: keep
+#include "iree/compiler/Dialect/Stream/IR/StreamTypeInterfaces.cpp.inc" // IWYU pragma: keep
+
+void StreamDialect::registerAttributes() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "iree/compiler/Dialect/Stream/IR/StreamAttrs.cpp.inc" // IWYU pragma: keep
+ >();
+}
+
+void StreamDialect::registerTypes() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.cpp.inc" // IWYU pragma: keep
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// Type printing and parsing
+//===----------------------------------------------------------------------===//
+
+Attribute StreamDialect::parseAttribute(DialectAsmParser &parser,
+ Type type) const {
+ StringRef mnemonic;
+ if (failed(parser.parseKeyword(&mnemonic))) return {};
+ Attribute attr;
+ auto parseResult = generatedAttributeParser(parser, mnemonic, type, attr);
+ if (parseResult.hasValue()) return attr;
+ parser.emitError(parser.getCurrentLocation())
+ << "unknown Stream attribute: " << mnemonic;
+ return {};
+}
+
+void StreamDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
+ if (failed(generatedAttributePrinter(attr, p))) {
+ llvm_unreachable("unknown Stream attribute");
+ }
+}
+
+Type StreamDialect::parseType(DialectAsmParser &parser) const {
+ StringRef mnemonic;
+ if (failed(parser.parseKeyword(&mnemonic))) return {};
+ Type type;
+ OptionalParseResult parseResult = generatedTypeParser(parser, mnemonic, type);
+ if (parseResult.hasValue()) return type;
+ parser.emitError(parser.getCurrentLocation())
+ << "unknown Stream type: " << mnemonic;
+ return {};
+}
+
+void StreamDialect::printType(Type type, DialectAsmPrinter &p) const {
+ if (failed(generatedTypePrinter(type, p))) {
+ llvm_unreachable("unknown Stream type");
+ }
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/iree/compiler/Dialect/Stream/IR/StreamTypes.h
new file mode 100644
index 0000000..50e272c
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/StreamTypes.h
@@ -0,0 +1,65 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMTYPES_H_
+#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMTYPES_H_
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+
+// clang-format off: must be included after all LLVM/MLIR headers.
+#include "iree/compiler/Dialect/Stream/IR/StreamEnums.h.inc" // IWYU pragma: export
+// clang-format on
+
+// clang-format off: must be included after all LLVM/MLIR headers.
+#define GET_ATTRDEF_CLASSES
+#include "iree/compiler/Dialect/Stream/IR/StreamAttrs.h.inc" // IWYU pragma: keep
+// clang-format on
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+class AffinityAttr;
+
+#include "iree/compiler/Dialect/Stream/IR/StreamTypeInterfaces.h.inc" // IWYU pragma: export
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+// clang-format off: must be included after all LLVM/MLIR headers.
+#define GET_TYPEDEF_CLASSES
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h.inc" // IWYU pragma: keep
+// clang-format on
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+#include "iree/compiler/Dialect/Stream/IR/StreamOpInterfaces.h.inc" // IWYU pragma: export
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_IR_STREAMTYPES_H_
diff --git a/iree/compiler/Dialect/Stream/IR/test/BUILD b/iree/compiler/Dialect/Stream/IR/test/BUILD
new file mode 100644
index 0000000..7e5c7d6
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/BUILD
@@ -0,0 +1,38 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "async_folding.mlir",
+ "async_ops.mlir",
+ "cmd_folding.mlir",
+ "cmd_ops.mlir",
+ "executable_ops.mlir",
+ "resource_folding.mlir",
+ "resource_ops.mlir",
+ "tensor_folding.mlir",
+ "tensor_ops.mlir",
+ "timepoint_folding.mlir",
+ "timepoint_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
new file mode 100644
index 0000000..8c95432
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt
@@ -0,0 +1,33 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/IR/test/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "async_folding.mlir"
+ "async_ops.mlir"
+ "cmd_folding.mlir"
+ "cmd_ops.mlir"
+ "executable_ops.mlir"
+ "resource_folding.mlir"
+ "resource_ops.mlir"
+ "tensor_folding.mlir"
+ "tensor_ops.mlir"
+ "timepoint_folding.mlir"
+ "timepoint_ops.mlir"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
new file mode 100644
index 0000000..64710ac
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
@@ -0,0 +1,296 @@
+// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// Ensures that the splat moves to the first common dominator of bb2/bb3.
+// We likely want to clone instead to reduce the lifetime of the splats.
+
+// CHECK-LABEL: @SinkSplatsToConsumers
+func @SinkSplatsToConsumers(
+ %arg0: i1, %arg1: i1,
+ %arg2: !stream.resource<*>,
+ %arg3: !stream.resource<*>,
+ %arg4: !stream.resource<*>
+) -> !stream.resource<*> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c100 = arith.constant 100 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK-NOT: stream.async.splat
+ %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%c100}
+ // CHECK: cond_br %arg0, ^bb1, ^bb4
+ cond_br %arg0, ^bb1, ^bb4(%arg4 : !stream.resource<*>)
+// CHECK: ^bb1:
+^bb1:
+ // CHECK: %[[SPLAT:.+]] = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%c100}
+ // CHECK-NEXT: cond_br %arg1, ^bb2, ^bb3
+ cond_br %arg1, ^bb2, ^bb3
+// CHECK: ^bb2:
+^bb2:
+ // CHECK: = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%[[SPLAT]])
+ %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c2, %c3](%0) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
+ br ^bb4(%2 : !stream.resource<*>)
+// CHECK: ^bb3:
+^bb3:
+ // CHECK: = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%[[SPLAT]])
+ %3 = stream.async.dispatch @executable::@dispatch1[%c1, %c2, %c3](%0) : (!stream.resource<*>{%c100}) -> !stream.resource<*>{%c100}
+ br ^bb4(%3 : !stream.resource<*>)
+// CHECK: ^bb4(
+^bb4(%arg6: !stream.resource<*>):
+ return %arg6 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @PropagateClonableOps
+func @PropagateClonableOps(%arg0: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: %[[T:.+]] = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%arg0}
+ %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%arg0}
+ // CHECK-NOT: stream.async.clone
+ %1 = stream.async.clone %0 : !stream.resource<*>{%arg0} -> !stream.resource<*>{%arg0}
+ // CHECK: return %[[T]]
+ return %1 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @FoldAsyncSliceOp
+func @FoldAsyncSliceOp(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: stream.async.slice
+ %0 = stream.async.slice %arg0[%c0 to %arg1] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%arg1}
+ // CHECK: return %arg0
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @PropagateSplatsThroughSlices
+func @PropagateSplatsThroughSlices(%arg0: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: %[[T:.+]] = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%c128}
+ %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%arg0}
+ // CHECK-NOT: stream.async.slice
+ %1 = stream.async.slice %0[%c0 to %c128] : !stream.resource<*>{%arg0} -> !stream.resource<*>{%c128}
+ // CHECK: return %[[T]]
+ return %1 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @FlattenFullFillToSplat
+func @FlattenFullFillToSplat(%arg0: !stream.resource<*>, %arg1: index, %arg2: f32) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[T:.+]] = stream.async.splat %arg2 : f32 -> !stream.resource<*>{%arg1}
+ %0 = stream.async.fill %arg2, %arg0[%c0 to %arg1 for %arg1] : f32 -> %arg0 as !stream.resource<*>{%arg1}
+ // CHECK: return %[[T]]
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @FoldAsyncUpdateOp
+func @FoldAsyncUpdateOp(%arg0: !stream.resource<*>, %arg1: !stream.resource<*>, %arg2: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: stream.async.update
+ %0 = stream.async.update %arg1, %arg0[%c0 to %arg2] : !stream.resource<*>{%arg2} -> %arg0 as !stream.resource<*>{%arg2}
+ // CHECK: return %arg1
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @CombineSplatUpdateFromToFill
+func @CombineSplatUpdateFromToFill(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK-NOT: stream.async.splat
+ %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%c128}
+ // CHECK: %[[T:.+]] = stream.async.fill %c123_i32, %arg0[%c0 to %c128 for %c128] : i32 -> %arg0 as !stream.resource<*>{%arg1}
+ %1 = stream.async.update %0, %arg0[%c0 to %c128] : !stream.resource<*>{%c128} -> %arg0 as !stream.resource<*>{%arg1}
+ // CHECK: return %[[T]]
+ return %1 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @CombineSliceUpdateFromToCopy
+func @CombineSliceUpdateFromToCopy(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK-NOT: stream.async.slice
+ %0 = stream.async.slice %arg0[%c0 to %c128] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c128}
+ // CHECK: %[[T:.+]] = stream.async.copy %arg0[%c0 to %c128], %arg2[%c0 to %c128], %c128 : !stream.resource<*>{%arg1} -> %arg2 as !stream.resource<*>{%arg3}
+ %1 = stream.async.update %0, %arg2[%c0 to %c128] : !stream.resource<*>{%c128} -> %arg2 as !stream.resource<*>{%arg3}
+ // CHECK: return %[[T]]
+ return %1 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @AsyncCopyFullSourceToUpdate
+func @AsyncCopyFullSourceToUpdate(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.async.update %arg2, %arg0[%c0 to %arg3] : !stream.resource<*>{%arg3} -> %arg0 as !stream.resource<*>{%arg1}
+ %0 = stream.async.copy %arg2[%c0 to %arg3], %arg0[%c0 to %arg3], %arg3 : !stream.resource<*>{%arg3} -> %arg0 as !stream.resource<*>{%arg1}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @FoldAsyncTransferOp
+func @FoldAsyncTransferOp(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.resource<transient> {
+ // CHECK-NOT: stream.async.transfer
+ %0 = stream.async.transfer %arg0 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg1}
+ %1 = stream.async.transfer %0 : !stream.resource<staging>{%arg1} -> !stream.resource<transient>{%arg1}
+ return %1 : !stream.resource<transient>
+}
+
+// -----
+
+// CHECK-LABEL: @RedundantTransferElision
+func @RedundantTransferElision(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.resource<transient> {
+ // CHECK-NOT: stream.async.transfer
+ %0 = stream.async.transfer %arg0 : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%arg1}
+ return %0 : !stream.resource<transient>
+}
+
+// -----
+
+// CHECK-LABEL: @ElideImmediateAsyncExecuteWaits
+func @ElideImmediateAsyncExecuteWaits(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.timepoint) {
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: stream.timepoint.immediate
+ %imm = stream.timepoint.immediate => !stream.timepoint
+ // CHECK: stream.async.execute with
+ %0:2 = stream.async.execute await(%imm) => with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> %arg0 as !stream.resource<*>{%arg1} {
+ // CHECK: stream.async.dispatch
+ %1 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg2) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}
+ // CHECK: stream.yield
+ stream.yield %1 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ChainAsyncExecuteWaits
+func @ChainAsyncExecuteWaits(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.timepoint) -> (!stream.resource<*>, !stream.timepoint) {
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: stream.timepoint.await
+ %0 = stream.timepoint.await %arg2 => %arg0 : !stream.resource<*>{%arg1}
+ // CHECK: stream.async.execute await(%arg2) => with
+ %1:2 = stream.async.execute with(%0 as %arg3: !stream.resource<*>{%arg1}) -> %0 as !stream.resource<*>{%arg1} {
+ // CHECK: stream.async.dispatch
+ %1 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg3) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}
+ // CHECK: stream.yield
+ stream.yield %1 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ return %1#0, %1#1 : !stream.resource<*>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @CloneCapturedAsyncExecuteSubviewOps
+func @CloneCapturedAsyncExecuteSubviewOps(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.timepoint) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c128 = arith.constant 128 : index
+ // CHECK-NOT: stream.resource.subview
+ %0 = stream.resource.subview %arg0[%c0] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c128}
+ // CHECK: = stream.async.execute with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> %arg0{%c128}
+ %1:2 = stream.async.execute with(%0 as %arg3: !stream.resource<*>{%c128}) -> %0{%c128} {
+ // CHECK: %[[T:.+]] = stream.resource.subview %arg2[%c0] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c128}
+ // CHECK: stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%[[T]])
+ %1 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg3) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}
+ // CHECK: stream.yield
+ stream.yield %1 : !stream.resource<*>{%c128}
+ } => !stream.timepoint
+ return %1#0, %1#1 : !stream.resource<*>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideNoOpAsyncExecuteOp
+func @ElideNoOpAsyncExecuteOp(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.timepoint) -> (!stream.resource<*>, !stream.timepoint) {
+ // CHECK-NOT: stream.async.execute
+ %1:2 = stream.async.execute await(%arg2) => with(%arg0 as %arg3: !stream.resource<*>{%arg1}) -> %arg0{%arg1} {
+ stream.yield %arg3 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ // CHECK: %[[IMM:.+]] = stream.timepoint.immediate
+ // CHECK: return %arg0, %[[IMM]]
+ return %1#0, %1#1 : !stream.resource<*>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @TieRegionResultsAsyncExecuteOp
+func @TieRegionResultsAsyncExecuteOp(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.timepoint) {
+ %c1 = arith.constant 1 : index
+ // CHECK: = stream.async.execute with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> %arg0{%arg1}
+ %0:2 = stream.async.execute with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} {
+ // CHECK: %[[T:.+]] = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg2)
+ %1 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg2) : (!stream.resource<*>{%arg1}) -> %arg2{%arg1}
+ // CHECK: stream.yield %[[T]]
+ stream.yield %1 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideUnusedAsyncExecuteOp
+func @ElideUnusedAsyncExecuteOp(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.timepoint) {
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: stream.async.execute
+ %0:2 = stream.async.execute await(%arg2) => with(%arg0 as %arg3: !stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} {
+ %1 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg3) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}
+ stream.yield %1 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @TieRegionResultsAsyncConcurrentOp
+func @TieRegionResultsAsyncConcurrentOp(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.timepoint) {
+ %c1 = arith.constant 1 : index
+ // CHECK: = stream.async.execute with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> %arg0{%arg1}
+ %0:2 = stream.async.execute with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} {
+ // CHECK: %[[EXEC_T:.+]] = stream.async.concurrent with(%arg2 as %arg3: !stream.resource<*>{%arg1}) -> %arg2{%arg1}
+ %1 = stream.async.concurrent with(%arg2 as %arg3: !stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} {
+ // CHECK: %[[WAVE_T:.+]] = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg3) : (!stream.resource<*>{%arg1}) -> %arg3{%arg1}
+ %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg3) : (!stream.resource<*>{%arg1}) -> %arg3{%arg1}
+ // CHECK: stream.yield %[[WAVE_T]]
+ stream.yield %2 : !stream.resource<*>{%arg1}
+ }
+ // CHECK: stream.yield %[[EXEC_T]]
+ stream.yield %1 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideUnusedAsyncConcurrentOp
+func @ElideUnusedAsyncConcurrentOp(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.timepoint) -> (!stream.resource<*>, !stream.timepoint) {
+ %c1 = arith.constant 1 : index
+ // CHECK: stream.async.execute
+ %0:2 = stream.async.execute await(%arg2) => with(%arg0 as %arg3: !stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} {
+ // CHECK: stream.async.dispatch @executable::@dispatch0
+ %1 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg3) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}
+ // CHECK-NOT: stream.async.concurrent
+ %2 = stream.async.concurrent with(%arg3 as %arg4: !stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} {
+ // CHECK-NOT: stream.async.dispatch @executable::@dispatch1
+ %3 = stream.async.dispatch @executable::@dispatch1[%c1, %c1, %c1](%arg4) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}
+ stream.yield %3 : !stream.resource<*>{%arg1}
+ }
+ stream.yield %1 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
new file mode 100644
index 0000000..352a521
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
@@ -0,0 +1,132 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @asyncAlloca
+func @asyncAlloca(%arg0: index) -> !stream.resource<transient> {
+ // CHECK: = stream.async.alloca : !stream.resource<transient>{%arg0}
+ %0 = stream.async.alloca : !stream.resource<transient>{%arg0}
+ return %0 : !stream.resource<transient>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncConstant
+func @asyncConstant(%arg0: index) -> !stream.resource<transient> {
+ // CHECK: = stream.async.constant : !stream.resource<transient>{%arg0} = dense<3> : tensor<8xi32>
+ %0 = stream.async.constant : !stream.resource<transient>{%arg0} = dense<3> : tensor<8xi32>
+ return %0 : !stream.resource<transient>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncSplat
+func @asyncSplat(%arg0: index, %arg1: f32) -> !stream.resource<*> {
+ // CHECK: = stream.async.splat %arg1 : f32 -> !stream.resource<*>{%arg0}
+ %0 = stream.async.splat %arg1 : f32 -> !stream.resource<*>{%arg0}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncClone
+func @asyncClone(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ // CHECK: = stream.async.clone %arg0 : !stream.resource<*>{%arg1} -> !stream.resource<*>{%arg1}
+ %0 = stream.async.clone %arg0 : !stream.resource<*>{%arg1} -> !stream.resource<*>{%arg1}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncSlice
+func @asyncSlice(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK: = stream.async.slice %arg0[%c0 to %c128] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c128}
+ %0 = stream.async.slice %arg0[%c0 to %c128] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c128}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncFill
+func @asyncFill(%arg0: !stream.resource<*>, %arg1: index, %arg2: f32) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK: = stream.async.fill %arg2, %arg0[%c0 to %c128 for %c128] : f32 -> %arg0 as !stream.resource<*>{%arg1}
+ %0 = stream.async.fill %arg2, %arg0[%c0 to %c128 for %c128] : f32 -> %arg0 as !stream.resource<*>{%arg1}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncUpdate
+func @asyncUpdate(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK: = stream.async.update %arg2, %arg0[%c0 to %c128] : !stream.resource<*>{%arg3} -> %arg0 as !stream.resource<*>{%arg1}
+ %0 = stream.async.update %arg2, %arg0[%c0 to %c128] : !stream.resource<*>{%arg3} -> %arg0 as !stream.resource<*>{%arg1}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncCopy
+func @asyncCopy(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK: = stream.async.copy %arg2[%c0 to %c128], %arg0[%c0 to %c128], %c128 : !stream.resource<*>{%arg3} -> %arg0 as !stream.resource<*>{%arg1}
+ %0 = stream.async.copy %arg2[%c0 to %c128], %arg0[%c0 to %c128], %c128 : !stream.resource<*>{%arg3} -> %arg0 as !stream.resource<*>{%arg1}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncTransfer
+func @asyncTransfer(%arg0: !stream.resource<constant>, %arg1: index) -> !stream.resource<staging> {
+ // CHECK: = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} -> !stream.resource<staging>{%arg1}
+ %0 = stream.async.transfer %arg0 : !stream.resource<constant>{%arg1} -> !stream.resource<staging>{%arg1}
+ return %0 : !stream.resource<staging>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncDispatch
+func @asyncDispatch(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ // CHECK: = stream.async.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (!stream.resource<*>{%arg1}, index) -> %arg0{%arg1}
+ %0 = stream.async.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (!stream.resource<*>{%arg1}, index) -> %arg0{%arg1}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @asyncExecute
+func @asyncExecute(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.timepoint) -> (!stream.resource<*>, !stream.timepoint) {
+ // CHECK: = stream.async.execute await(%arg2) => with(%arg0 as %arg3: !stream.resource<*>{%arg1}) -> %arg0{%arg1} {
+ %0:2 = stream.async.execute await(%arg2) => with(%arg0 as %arg3: !stream.resource<*>{%arg1}) -> %arg0 as !stream.resource<*>{%arg1} {
+ // CHECK: %[[W:.+]] = stream.async.concurrent with(%arg3 as %arg4: !stream.resource<*>{%arg1}) -> %arg3{%arg1} {
+ %1 = stream.async.concurrent with(%arg3 as %arg4: !stream.resource<*>{%arg1}) -> %arg3 as !stream.resource<*>{%arg1} {
+ // CHECK: stream.yield %arg4 : !stream.resource<*>{%arg1}
+ stream.yield %arg4 : !stream.resource<*>{%arg1}
+ }
+ // CHECK: stream.yield %[[W]] : !stream.resource<*>{%arg1}
+ stream.yield %1 : !stream.resource<*>{%arg1}
+ } => !stream.timepoint
+ return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @asyncExecuteNoCaptures
+func @asyncExecuteNoCaptures(%arg0: index, %arg1: f32) -> (!stream.resource<*>, !stream.timepoint) {
+ // CHECK: = stream.async.execute with() -> !stream.resource<*>{%arg0} {
+ %0:2 = stream.async.execute with() -> !stream.resource<*>{%arg0} {
+ // CHECK: %[[T:.+]] = stream.async.splat
+ %1 = stream.async.splat %arg1 : f32 -> !stream.resource<*>{%arg0}
+ // CHECK: stream.yield %[[T]] : !stream.resource<*>{%arg0}
+ stream.yield %1 : !stream.resource<*>{%arg0}
+ } => !stream.timepoint
+ return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/cmd_folding.mlir b/iree/compiler/Dialect/Stream/IR/test/cmd_folding.mlir
new file mode 100644
index 0000000..a271972
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/cmd_folding.mlir
@@ -0,0 +1,144 @@
+// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @FoldSubviewsIntoCmdTOp
+func @FoldSubviewsIntoCmdTOp(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c1000 = arith.constant 1000 : index
+ %c2000 = arith.constant 2000 : index
+ %c3000 = arith.constant 3000 : index
+ %cst = arith.constant 4.2 : f32
+ %0 = stream.resource.subview %arg0[%c64] : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%c3000}
+ %1 = stream.cmd.execute with(%0 as %arg2: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.flush %arg2[%c1064 for %c2000] : !stream.resource<transient>{%arg1}
+ stream.cmd.flush %arg2[%c1000 for %c2000] : !stream.resource<transient>{%c3000}
+ // CHECK: stream.cmd.invalidate %arg2[%c1064 for %c2000] : !stream.resource<transient>{%arg1}
+ stream.cmd.invalidate %arg2[%c1000 for %c2000] : !stream.resource<transient>{%c3000}
+ // CHECK: stream.cmd.discard %arg2[%c1064 for %c2000] : !stream.resource<transient>{%arg1}
+ stream.cmd.discard %arg2[%c1000 for %c2000] : !stream.resource<transient>{%c3000}
+ // CHECK: stream.cmd.fill %cst, %arg2[%c1064 for %c2000] : f32 -> !stream.resource<transient>{%arg1}
+ stream.cmd.fill %cst, %arg2[%c1000 for %c2000] : f32 -> !stream.resource<transient>{%c3000}
+ } => !stream.timepoint
+ return %1 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubviewsIntoCmdCopyOp
+func @FoldSubviewsIntoCmdCopyOp(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c1000 = arith.constant 1000 : index
+ %c2000 = arith.constant 2000 : index
+ %c3000 = arith.constant 3000 : index
+ %c4000 = arith.constant 4000 : index
+ %0 = stream.resource.subview %arg0[%c64] : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%c3000}
+ %1 = stream.resource.subview %arg0[%c128] : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%c4000}
+ %2 = stream.cmd.execute with(%0 as %arg2: !stream.resource<transient>{%c3000}, %1 as %arg3: !stream.resource<transient>{%c4000}) {
+ // CHECK: stream.cmd.copy %arg2[%c1064], %arg2[%c2128], %c1000 : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%arg1}
+ stream.cmd.copy %arg2[%c1000], %arg3[%c2000], %c1000 : !stream.resource<transient>{%c3000} -> !stream.resource<transient>{%c4000}
+ } => !stream.timepoint
+ return %2 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubviewsIntoCmdDispatchOp
+func @FoldSubviewsIntoCmdDispatchOp(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c1000 = arith.constant 1000 : index
+ %c2000 = arith.constant 2000 : index
+ %c3000 = arith.constant 3000 : index
+ %c4000 = arith.constant 4000 : index
+ %0 = stream.resource.subview %arg0[%c64] : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%c3000}
+ %1 = stream.resource.subview %arg0[%c128] : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%c4000}
+ %2 = stream.cmd.execute with(%0 as %arg2: !stream.resource<transient>{%c3000}, %1 as %arg3: !stream.resource<transient>{%c4000}) {
+ // CHECK: stream.cmd.dispatch
+ stream.cmd.dispatch @executable::@dispatch[%c1, %c1, %c1] {
+ // CHECK-NEXT: ro %arg2[%c1064 for %c1000] : !stream.resource<transient>{%arg1}
+ ro %arg2[%c1000 for %c1000] : !stream.resource<transient>{%c3000},
+ // CHECK-NEXT: wo %arg2[%c2128 for %c1000] : !stream.resource<transient>{%arg1}
+ wo %arg3[%c2000 for %c1000] : !stream.resource<transient>{%c4000}
+ }
+ } => !stream.timepoint
+ return %2 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideImmediateCmdExecuteWaits
+func @ElideImmediateCmdExecuteWaits(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: stream.timepoint.immediate
+ %imm = stream.timepoint.immediate => !stream.timepoint
+ // CHECK: stream.cmd.execute with
+ %0 = stream.cmd.execute await(%imm) => with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ stream.cmd.discard %arg2[%c0 for %arg1] : !stream.resource<transient>{%arg1}
+ } => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ChainCmdExecuteWaits
+func @ChainCmdExecuteWaits(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.timepoint) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK-NOT: stream.timepoint.await
+ %0 = stream.timepoint.await %arg2 => %arg0 : !stream.resource<transient>{%arg1}
+ // CHECK: stream.cmd.execute await(%arg2) => with
+ %1 = stream.cmd.execute with(%0 as %arg3: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.discard
+ stream.cmd.discard %arg3[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ } => !stream.timepoint
+ return %1 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @CloneCapturedCmdExecuteSubviewOps
+func @CloneCapturedCmdExecuteSubviewOps(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c1000 = arith.constant 1000 : index
+ %c2000 = arith.constant 2000 : index
+ %c3000 = arith.constant 3000 : index
+ // CHECK-NOT: stream.resource.subview
+ %0 = stream.resource.subview %arg0[%c64] : !stream.resource<transient>{%arg1} -> !stream.resource<transient>{%c3000}
+ // CHECK: = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1})
+ %1 = stream.cmd.execute with(%0 as %arg3: !stream.resource<transient>{%c3000}) {
+ // CHECK: stream.cmd.discard %arg2[%c1064 for %c2000] : !stream.resource<transient>{%arg1}
+ stream.cmd.discard %arg3[%c1000 for %c2000] : !stream.resource<transient>{%arg1}
+ } => !stream.timepoint
+ return %1 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideNoOpCmdExecuteOp
+func @ElideNoOpCmdExecuteOp(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.timepoint) -> !stream.timepoint {
+ // CHECK-NOT: stream.cmd.execute
+ %0 = stream.cmd.execute await(%arg2) => with(%arg0 as %arg3: !stream.resource<transient>{%arg1}) {
+ } => !stream.timepoint
+ // CHECK: %[[IMM:.+]] = stream.timepoint.immediate
+ // CHECK: return %[[IMM]]
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideUnusedCmdExecuteOp
+func @ElideUnusedCmdExecuteOp(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.timepoint) {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK-NOT: stream.cmd.execute
+ %0 = stream.cmd.execute await(%arg2) => with(%arg0 as %arg3: !stream.resource<transient>{%arg1}) {
+ stream.cmd.discard %arg3[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ } => !stream.timepoint
+ return
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/cmd_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/cmd_ops.mlir
new file mode 100644
index 0000000..b0afc28
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/cmd_ops.mlir
@@ -0,0 +1,94 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @cmdMemoryControl
+func @cmdMemoryControl(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.flush %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ stream.cmd.flush %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ // CHECK: stream.cmd.invalidate %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ stream.cmd.invalidate %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ // CHECK: stream.cmd.discard %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ stream.cmd.discard %arg2[%c0 for %c128] : !stream.resource<transient>{%arg1}
+ } => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @cmdFill
+func @cmdFill(%arg0: !stream.resource<transient>, %arg1: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 4.2 : f32
+ %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource<transient>{%arg1}) {
+ // CHECK: stream.cmd.fill %cst, %arg2[%c0 for %c128] : f32 -> !stream.resource<transient>{%arg1}
+ stream.cmd.fill %cst, %arg2[%c0 for %c128] : f32 -> !stream.resource<transient>{%arg1}
+ } => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @cmdCopy
+func @cmdCopy(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<staging>{%arg3}) {
+ // CHECK: stream.cmd.copy %arg4[%c0], %arg5[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
+ stream.cmd.copy %arg4[%c0], %arg5[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
+ } => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @cmdDispatch
+func @cmdDispatch(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<external>, %arg3: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c5 = arith.constant 5 : index
+ %c128 = arith.constant 128 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) {
+ // CHECK: stream.cmd.dispatch @executable::@dispatch[%c1, %c2, %c3](%c4, %c5 : index, index) {
+ // CHECK-NEXT: ro %arg4[%c0 for %c128] : !stream.resource<transient>{%arg1},
+ // CHECK-NEXT: wo %arg5[%c0 for %c128] : !stream.resource<external>{%arg3}
+ // CHECK-NEXT: }
+ stream.cmd.dispatch @executable::@dispatch[%c1, %c2, %c3](%c4, %c5 : index, index) {
+ ro %arg4[%c0 for %c128] : !stream.resource<transient>{%arg1},
+ wo %arg5[%c0 for %c128] : !stream.resource<external>{%arg3}
+ }
+ } => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @cmdExecute
+func @cmdExecute(%arg0: !stream.resource<transient>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK: = stream.cmd.execute await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) {
+ %0 = stream.cmd.execute await(%arg4) => with(%arg0 as %arg5: !stream.resource<transient>{%arg1}, %arg2 as %arg6: !stream.resource<staging>{%arg3}) {
+ // CHECK: stream.cmd.concurrent {
+ stream.cmd.concurrent {
+ // CHECK-NEXT: stream.cmd.copy
+ stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
+ // CHECK-NEXT: stream.cmd.copy
+ stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
+ // CHECK: stream.cmd.serial {
+ stream.cmd.serial {
+ // CHECK-NEXT: stream.cmd.copy
+ stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
+ // CHECK-NEXT: stream.cmd.copy
+ stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource<transient>{%arg1} -> !stream.resource<staging>{%arg3}
+ }
+ }
+ // CHECK: } => !stream.timepoint
+ } => !stream.timepoint
+ return %0 : !stream.timepoint
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
new file mode 100644
index 0000000..6d607c2
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
@@ -0,0 +1,20 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: stream.executable private @executable
+stream.executable private @executable {
+ // CHECK-NEXT: stream.executable.export public @dispatch
+ stream.executable.export public @dispatch
+ // CHECK-NEXT: builtin.module
+ builtin.module {
+ // CHECK-NEXT: func @dispatch(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index) {
+ func @dispatch(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ // CHECK-DAG: = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:?x5x64xf32>{%arg2}
+ %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:?x5x64xf32>{%arg2}
+ // CHECK-DAG: = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:?x5x4xf32>{%arg2}
+ %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:?x5x4xf32>{%arg2}
+ // CHECK: return
+ return
+ }
+ }
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir b/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
new file mode 100644
index 0000000..0c7b333
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
@@ -0,0 +1,162 @@
+// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @FoldResourceSizeOp
+func @FoldResourceSizeOp(%arg0: !stream.resource<staging>, %arg1: index) -> (index, i32) {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: stream.resource.size
+ %0 = stream.resource.size %arg0 : !stream.resource<staging>
+ // CHECK: %[[LOAD:.+]] = stream.resource.load
+ %1 = stream.resource.load %arg0[%c0] : !stream.resource<staging>{%arg1} -> i32
+ // CHECK: return %arg1, %[[LOAD]]
+ return %0, %1 : index, i32
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubviewIntoLoadOp
+func @FoldSubviewIntoLoadOp(%arg0: !stream.resource<staging>, %arg1: index) -> i32 {
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ // CHECK-NOT: stream.resource.subview
+ %0 = stream.resource.subview %arg0[%c128] : !stream.resource<staging>{%arg1} -> !stream.resource<staging>{%c256}
+ // CHECK: = stream.resource.load %arg0[%c192] : !stream.resource<staging>{%arg1} -> i32
+ %1 = stream.resource.load %0[%c64] : !stream.resource<staging>{%c256} -> i32
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @FoldSubviewIntoStoreOp
+func @FoldSubviewIntoStoreOp(%arg0: !stream.resource<staging>, %arg1: index) {
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK-NOT: stream.resource.subview
+ %0 = stream.resource.subview %arg0[%c128] : !stream.resource<staging>{%arg1} -> !stream.resource<staging>{%c256}
+ // CHECK: stream.resource.store %c123_i32, %arg0[%c192] : i32 -> !stream.resource<staging>{%arg1}
+ stream.resource.store %c123_i32, %0[%c64] : i32 -> !stream.resource<staging>{%c256}
+ return
+}
+
+// -----
+
+// A pack with no slices folds to a zero-length slab.
+
+// CHECK-LABEL: @FoldResourcePackOpEmpty
+func @FoldResourcePackOpEmpty(%allocator: !hal.allocator) -> index {
+ // CHECK-NEXT: %[[ZERO_LENGTH:.+]] = arith.constant 0
+ %total_length = stream.resource.pack slices({}) : index
+ // CHECK-NEXT: return %[[ZERO_LENGTH]]
+ return %total_length : index
+}
+
+// -----
+
+// A pack with a single slice folds to just that slice.
+
+// CHECK-LABEL: @FoldResourcePackOpOneSlice
+// CHECK-SAME: %[[OFFSET:.+]]: index,
+// CHECK-SAME: %[[SIZE:.+]]: index
+func @FoldResourcePackOpOneSlice(%offset: index, %size: index) -> (index, index) {
+ // CHECK-NOT: stream.resource.pack
+ %total_length, %offset_0 =
+ stream.resource.pack
+ offset(%offset)
+ slices({
+ [0, 4] = %size
+ }) : index
+ // CHECK: return %[[SIZE]], %[[OFFSET]]
+ return %total_length, %offset_0 : index, index
+}
+
+// -----
+
+// A constant zero offset operand gets dropped.
+
+// CHECK-LABEL: @PropagateResourcePackZeroOffset
+func @PropagateResourcePackZeroOffset(%size : index) -> (index, index, index) {
+ // CHECK-NOT: constant 0
+ // CHECK-NEXT: = stream.resource.pack slices({
+ %base_offset = arith.constant 0 : index
+ %total_length, %offset_0, %offset_1 =
+ stream.resource.pack
+ offset(%base_offset)
+ slices({
+ [0, 4] = %size,
+ [1, 2] = %size,
+ }) : index
+ return %total_length, %offset_0, %offset_1 : index, index, index
+}
+
+// -----
+
+// A base offset operand gets propagated to returned values.
+
+// CHECK-LABEL: @PropagateResourcePackBaseOffset
+// CHECK-SAME: %[[BASE_OFFSET:.+]]: index,
+// CHECK-SAME: %[[SIZE:.+]]: index
+func @PropagateResourcePackBaseOffset(%base_offset: index, %size : index) -> (index, index, index) {
+ // CHECK-NEXT: %[[PACKED:.+]]:3 =
+ %total_length, %offset_0, %offset_1 =
+ // CHECK-SAME: stream.resource.pack slices({
+ stream.resource.pack
+ offset(%base_offset)
+ slices({
+ [0, 4] = %size,
+ [1, 2] = %size,
+ }) : index
+ // CHECK: %[[ADJUSTED_0:.+]] = arith.addi %[[BASE_OFFSET]], %[[PACKED]]#1
+ // CHECK-NEXT: %[[ADJUSTED_1:.+]] = arith.addi %[[BASE_OFFSET]], %[[PACKED]]#2
+ // CHECK-NEXT: return %[[PACKED]]#0, %[[ADJUSTED_0]], %[[ADJUSTED_1]]
+ return %total_length, %offset_0, %offset_1 : index, index, index
+}
+
+// -----
+
+// Intervals should be sorted.
+
+// CHECK-LABEL: @CanonicalizeResourcePackIntervals
+// CHECK-SAME: %[[SIZE:.+]]: index
+func @CanonicalizeResourcePackIntervals(%size : index) -> (index, index, index) {
+ // CHECK-NEXT: %[[PACKED:.+]]:3 =
+ %total_length, %offset_0, %offset_1 =
+ // CHECK-SAME: stream.resource.pack slices({
+ stream.resource.pack
+ slices({
+ // CHECK-NEXT: [0, 4] = %[[SIZE]],
+ // CHECK-NEXT: [1, 2] = %[[SIZE]]
+ [1, 2] = %size,
+ [0, 4] = %size,
+ }) : index
+ // CHECK: return %[[PACKED]]#0, %[[PACKED]]#2, %[[PACKED]]#1
+ return %total_length, %offset_0, %offset_1 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @FoldResourceSubviewOp
+func @FoldResourceSubviewOp(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: stream.resource.subview
+ %0 = stream.resource.subview %arg0[%c0] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%arg1}
+ // CHECK: return %arg0
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @FoldResourceSubviewOps
+func @FoldResourceSubviewOps(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ %c100 = arith.constant 100 : index
+ %c300 = arith.constant 300 : index
+ %c400 = arith.constant 400 : index
+ %c500 = arith.constant 500 : index
+ // CHECK: %[[RET:.+]] = stream.resource.subview %arg0[%c300] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c300}
+ %0 = stream.resource.subview %arg0[%c100] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c500}
+ %1 = stream.resource.subview %0[%c100] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c400}
+ %2 = stream.resource.subview %1[%c100] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c300}
+ // CHECK-NEXT: return %[[RET]]
+ return %2 : !stream.resource<*>
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir
new file mode 100644
index 0000000..3e486b6
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir
@@ -0,0 +1,122 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @resourceAlloc
+func @resourceAlloc(%arg0: index, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>) {
+ // CHECK: = stream.resource.alloc uninitialized : !stream.resource<*>{%arg0}, !stream.resource<*>{%arg1}
+ %0:2 = stream.resource.alloc uninitialized : !stream.resource<*>{%arg0}, !stream.resource<*>{%arg1}
+ return %0#0, %0#1 : !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @resourceAlloca
+func @resourceAlloca(%arg0: index) -> (!stream.resource<staging>, !stream.timepoint) {
+ // CHECK: = stream.resource.alloca uninitialized : !stream.resource<staging>{%arg0} => !stream.timepoint
+ %0:2 = stream.resource.alloca uninitialized : !stream.resource<staging>{%arg0} => !stream.timepoint
+ return %0#0, %0#1 : !stream.resource<staging>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @resourceDealloca
+func @resourceDealloca(%arg0: index, %arg1: !stream.resource<staging>, %arg2: !stream.timepoint) {
+ // CHECK: stream.resource.dealloca await(%arg2) => %arg1 : !stream.resource<staging>{%arg0}
+ stream.resource.dealloca await(%arg2) => %arg1 : !stream.resource<staging>{%arg0}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @resourceSize
+func @resourceSize(%arg0: !stream.resource<*>) -> index {
+ // CHECK: = stream.resource.size %arg0 : !stream.resource<*>
+ %0 = stream.resource.size %arg0 : !stream.resource<*>
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @resourceMap
+func @resourceMap(%arg0: !util.byte_buffer) -> !stream.resource<staging> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK: = stream.resource.map %arg0[%c0] : !util.byte_buffer -> !stream.resource<staging>{%c128}
+ %0 = stream.resource.map %arg0[%c0] : !util.byte_buffer -> !stream.resource<staging>{%c128}
+ return %0 : !stream.resource<staging>
+}
+
+// -----
+
+// CHECK-LABEL: @resourceTryMap
+func @resourceTryMap(%arg0: !util.byte_buffer) -> (i1, !stream.resource<constant>) {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ // CHECK: = stream.resource.try_map %arg0[%c0] : !util.byte_buffer -> i1, !stream.resource<constant>{%c128}
+ %0:2 = stream.resource.try_map %arg0[%c0] : !util.byte_buffer -> i1, !stream.resource<constant>{%c128}
+ return %0#0, %0#1 : i1, !stream.resource<constant>
+}
+
+// -----
+
+// CHECK-LABEL: @resourceLoad
+func @resourceLoad(%arg0: !stream.resource<staging>, %arg1: index) -> i32 {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.resource.load %arg0[%c0] : !stream.resource<staging>{%arg1} -> i32
+ %0 = stream.resource.load %arg0[%c0] : !stream.resource<staging>{%arg1} -> i32
+ return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @resourceStore
+func @resourceStore(%arg0: !stream.resource<staging>, %arg1: index) {
+ %c0 = arith.constant 0 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: stream.resource.store %c123_i32, %arg0[%c0] : i32 -> !stream.resource<staging>{%arg1}
+ stream.resource.store %c123_i32, %arg0[%c0] : i32 -> !stream.resource<staging>{%arg1}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @resourcePack
+func @resourcePack(%arg0: index, %arg1: index) -> (index, index, index) {
+ %c128 = arith.constant 128 : index
+ // CHECK: stream.resource.pack offset(%c128) slices({
+ // CHECK-NEXT: [0, 9] = %arg0,
+ // CHECK-NEXT: [3, 8] = %arg1
+ // CHECK-NEXT: })
+ %0:3 = stream.resource.pack offset(%c128) slices({
+ [0, 9] = %arg0,
+ [3, 8] = %arg1,
+ }) : index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @resourceConstants
+func @resourceConstants() -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ // CHECK: = stream.resource.constants :
+ // CHECK-NEXT: !stream.resource<constant>{%c4} = dense<100> : tensor<1xi32>,
+ // CHECK-NEXT: !stream.resource<constant>{%c8} = dense<[101, 102]> : tensor<2xi32>
+ // CHECK-NEXT: => !stream.timepoint
+ %0:3 = stream.resource.constants :
+ !stream.resource<constant>{%c4} = dense<100> : tensor<1xi32>,
+ !stream.resource<constant>{%c8} = dense<[101, 102]> : tensor<2xi32>
+ => !stream.timepoint
+ return %0#0, %0#1, %0#2 : !stream.resource<constant>, !stream.resource<constant>, !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @resourceSubview
+func @resourceSubview(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ // CHECK: = stream.resource.subview %arg0[%c128] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c256}
+ %0 = stream.resource.subview %arg0[%c128] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c256}
+ return %0 : !stream.resource<*>
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
new file mode 100644
index 0000000..b5692ec
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @TensorConstantToSplat
+func @TensorConstantToSplat() -> !stream.resource<constant> {
+ // CHECK-DAG: %[[CST:.+]] = arith.constant 1.000000e+00 : f32
+ // CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<2x2xf32> : index
+ // CHECK: = stream.tensor.splat %[[CST]] : f32 -> tensor<2x2xf32> in !stream.resource<*>{%[[SIZE]]}
+ %cst = stream.tensor.constant : tensor<2x2xf32> in !stream.resource<constant> = dense<1.000000e+00> : tensor<2x2xf32>
+ return %cst : !stream.resource<constant>
+}
+
+// -----
+
+// CHECK-LABEL: @FoldTensorCloneOp
+func @FoldTensorCloneOp(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
+ // CHECK-NOT: stream.tensor.clone
+ %0 = stream.tensor.clone %arg0 : tensor<2x2xf32> in !stream.resource<*>{%arg1} -> tensor<2x2xf32> in !stream.resource<*>{%arg1}
+ // CHECK: return %arg0
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @ElideUnneededTensorClones
+func @ElideUnneededTensorClones(%arg0: !stream.resource<*>, %arg1: index) -> f32 {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: stream.tensor.clone
+ %0 = stream.tensor.clone %arg0 : tensor<2x2xf32> in !stream.resource<*>{%arg1} -> tensor<2x2xf32> in !stream.resource<*>{%arg1}
+ // CHECK: %[[T0:.+]] = stream.async.transfer %arg0 : !stream.resource<*>{%arg1} -> !stream.resource<staging>{%arg1}
+ %1 = stream.async.transfer %0 : !stream.resource<*>{%arg1} -> !stream.resource<staging>{%arg1}
+ // CHECK: %[[T1:.+]] = stream.tensor.load %[[T0]][%c0, %c0] : tensor<2x2xf32> in !stream.resource<staging>{%arg1} -> f32
+ %2 = stream.tensor.load %1[%c0, %c0] : tensor<2x2xf32> in !stream.resource<staging>{%arg1} -> f32
+ // CHECK: return %[[T1]]
+ return %2 : f32
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir
new file mode 100644
index 0000000..8524d80
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir
@@ -0,0 +1,128 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @tensorImport
+func @tensorImport(%arg0: !hal.buffer_view, %arg1: index) -> !stream.resource<external> {
+ %c20 = arith.constant 20 : index
+ // CHECK: = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20}
+ %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x5xf32>{%arg1} in !stream.resource<external>{%c20}
+ return %0 : !stream.resource<external>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorExport
+func @tensorExport(%arg0: !stream.resource<external>, %arg1: index) -> !hal.buffer_view {
+ %c200 = arith.constant 200 : index
+ // CHECK: = stream.tensor.export %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer_view
+ %0 = stream.tensor.export %arg0 : tensor<?x1x10xf32>{%arg1} in !stream.resource<external>{%c200} -> !hal.buffer_view
+ return %0 : !hal.buffer_view
+}
+
+// -----
+
+// CHECK-LABEL: @tensorSizeOf
+func @tensorSizeOf(%arg0: index) -> index {
+ // CHECK: = stream.tensor.sizeof tensor<?x5xf32>{%arg0} : index
+ %0 = stream.tensor.sizeof tensor<?x5xf32>{%arg0} : index
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @tensorConstant
+func @tensorConstant(%arg0: index) -> !stream.resource<constant> {
+ // CHECK: = stream.tensor.constant : tensor<?x5x64xf32>{%arg0} in !stream.resource<constant> = dense<0.000000e+00> : tensor<1x5x64xf32>
+ %0 = stream.tensor.constant : tensor<?x5x64xf32>{%arg0} in !stream.resource<constant> = dense<0.000000e+00> : tensor<1x5x64xf32>
+ return %0 : !stream.resource<constant>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorSplat
+func @tensorSplat(%arg0: f32, %arg1: index, %arg2: index) -> !stream.resource<*> {
+ // CHECK: = stream.tensor.splat %arg0 : f32 -> tensor<?x1x10xf32>{%arg1} in !stream.resource<*>{%arg2}
+ %0 = stream.tensor.splat %arg0 : f32 -> tensor<?x1x10xf32>{%arg1} in !stream.resource<*>{%arg2}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorClone
+func @tensorClone(%arg0: !stream.resource<*>, %arg1: index, %arg2: index) -> !stream.resource<*> {
+ // CHECK: = stream.tensor.clone %arg0 : tensor<?x4xf32>{%arg1} in !stream.resource<*>{%arg2} -> tensor<?x4xf32>{%arg1} in !stream.resource<*>{%arg2}
+ %0 = stream.tensor.clone %arg0 : tensor<?x4xf32>{%arg1} in !stream.resource<*>{%arg2} -> tensor<?x4xf32>{%arg1} in !stream.resource<*>{%arg2}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorSlice
+func @tensorSlice(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: = stream.tensor.slice %arg0[%c0, %c1 for %arg3, %c1] : tensor<?x4xf32>{%arg1} in !stream.resource<*>{%arg2} -> tensor<?x1xf32>{%arg3} in !stream.resource<*>{%arg4}
+ %0 = stream.tensor.slice %arg0[%c0, %c1 for %arg3, %c1] : tensor<?x4xf32>{%arg1} in !stream.resource<*>{%arg2} -> tensor<?x1xf32>{%arg3} in !stream.resource<*>{%arg4}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorFill
+func @tensorFill(%arg0: f32, %arg1: !stream.resource<*>, %arg2: index, %arg3: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: = stream.tensor.fill %arg0, %arg1[%c0, %c0 for %c1, %c1] : f32 -> tensor<?x4xf32>{%arg2} in %arg1 as !stream.resource<*>{%arg3}
+ %0 = stream.tensor.fill %arg0, %arg1[%c0, %c0 for %c1, %c1] : f32 -> tensor<?x4xf32>{%arg2} in %arg1 as !stream.resource<*>{%arg3}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorUpdate
+func @tensorUpdate(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: = stream.tensor.update %arg0, %arg2[%c0, %c0] : tensor<2x2xf32> in !stream.resource<*>{%arg1} -> tensor<?x4xf32>{%arg3} in %arg2 as !stream.resource<*>{%arg4}
+ %0 = stream.tensor.update %arg0, %arg2[%c0, %c0] : tensor<2x2xf32> in !stream.resource<*>{%arg1} -> tensor<?x4xf32>{%arg3} in %arg2 as !stream.resource<*>{%arg4}
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorLoad
+func @tensorLoad(%arg0: !stream.resource<staging>, %arg1: index, %arg2: index) -> f32 {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.tensor.load %arg0[%c0] : tensor<?xf32>{%arg1} in !stream.resource<staging>{%arg2} -> f32
+ %0 = stream.tensor.load %arg0[%c0] : tensor<?xf32>{%arg1} in !stream.resource<staging>{%arg2} -> f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @tensorLoadRank0
+func @tensorLoadRank0(%arg0: !stream.resource<staging>, %arg1: index) -> f32 {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.tensor.load %arg0 : tensor<f32> in !stream.resource<staging>{%arg1} -> f32
+ %0 = stream.tensor.load %arg0 : tensor<f32> in !stream.resource<staging>{%arg1} -> f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @tensorStore
+func @tensorStore(%arg0: !stream.resource<staging>, %arg1: index, %arg2: index, %arg3: f32) -> !stream.resource<staging> {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.tensor.store %arg3, %arg0[%c0] : f32 -> tensor<?xf32>{%arg1} in %arg0 as !stream.resource<staging>{%arg2}
+ %0 = stream.tensor.store %arg3, %arg0[%c0] : f32 -> tensor<?xf32>{%arg1} in %arg0 as !stream.resource<staging>{%arg2}
+ return %0 : !stream.resource<staging>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorStoreRank0
+func @tensorStoreRank0(%arg0: !stream.resource<staging>, %arg1: index, %arg2: f32) -> !stream.resource<staging> {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.tensor.store %arg2, %arg0 : f32 -> tensor<f32> in %arg0 as !stream.resource<staging>{%arg1}
+ %0 = stream.tensor.store %arg2, %arg0 : f32 -> tensor<f32> in %arg0 as !stream.resource<staging>{%arg1}
+ return %0 : !stream.resource<staging>
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir b/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
new file mode 100644
index 0000000..0efb739
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
@@ -0,0 +1,163 @@
+// RUN: iree-opt -split-input-file -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @FoldTimepointJoinOp
+func @FoldTimepointJoinOp(%arg0: !stream.timepoint) -> !stream.timepoint {
+ // CHECK-NOT: stream.timepoint.join
+ %0 = stream.timepoint.join max(%arg0) => !stream.timepoint
+ // CHECK: return %arg0
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideImmediateTimepointJoinOperands
+func @ElideImmediateTimepointJoinOperands(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> !stream.timepoint {
+ %0 = stream.timepoint.immediate => !stream.timepoint
+ %1 = stream.timepoint.immediate => !stream.timepoint
+ // CHECK: = stream.timepoint.join max(%arg0, %arg1)
+ %2 = stream.timepoint.join max(%arg0, %0, %1, %arg1) => !stream.timepoint
+ return %2 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideImmediateTimepointJoinOperandsAll
+func @ElideImmediateTimepointJoinOperandsAll() -> !stream.timepoint {
+ %0 = stream.timepoint.immediate => !stream.timepoint
+ %1 = stream.timepoint.immediate => !stream.timepoint
+ // CHECK-NOT: stream.timepoint.join
+ %2 = stream.timepoint.join max(%0, %1) => !stream.timepoint
+ // CHECK: %[[IMM:.+]] = stream.timepoint.immediate
+ // CHECK: return %[[IMM]]
+ return %2 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @FoldDuplicateTimepointJoinOperands
+func @FoldDuplicateTimepointJoinOperands(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> !stream.timepoint {
+ // CHECK: = stream.timepoint.join max(%arg0, %arg1)
+ %0 = stream.timepoint.join max(%arg0, %arg1, %arg0, %arg1) => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @ElideImmediateAwaits
+func @ElideImmediateAwaits(%arg0: !stream.resource<staging>) -> !stream.resource<staging> {
+ %c100 = arith.constant 100 : index
+ // CHECK-NOT: stream.timepoint.immediate
+ %0 = stream.timepoint.immediate => !stream.timepoint
+ // CHECK-NOT: stream.timepoint.await
+ %1 = stream.timepoint.await %0 => %arg0 : !stream.resource<staging>{%c100}
+ // CHECK: return %arg0
+ return %1 : !stream.resource<staging>
+}
+
+// -----
+
+// Ensures that the await sinks to the first common dominator of ^bb2/^bb3,
+// the blocks that use the awaited resources.
+
+// CHECK-LABEL: @SinkAwaitToFirstConsumer
+func @SinkAwaitToFirstConsumer(
+ %arg0: i1, %arg1: i1,
+ %arg2: !stream.resource<constant>,
+ %arg3: !stream.resource<staging>,
+ %arg4: !stream.resource<external>,
+ %arg5: !stream.timepoint
+) -> !stream.resource<external> {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-NOT: stream.timepoint.await
+ %0:2 = stream.timepoint.await %arg5 => %arg2, %arg3 : !stream.resource<constant>{%c100}, !stream.resource<staging>{%c200}
+ // CHECK: cond_br %arg0, ^bb1, ^bb4
+ cond_br %arg0, ^bb1, ^bb4(%arg4 : !stream.resource<external>)
+// CHECK: ^bb1:
+^bb1:
+ // CHECK: %[[READY:.+]]:2 = stream.timepoint.await %arg5 => %arg2, %arg3 : !stream.resource<constant>{%c100}, !stream.resource<staging>{%c200}
+ // CHECK-NEXT: cond_br %arg1, ^bb2, ^bb3
+ cond_br %arg1, ^bb2, ^bb3
+// CHECK: ^bb2:
+^bb2:
+ // CHECK: = stream.async.transfer %[[READY]]#0
+ %1 = stream.async.transfer %0#0 : !stream.resource<constant>{%c100} -> !stream.resource<external>{%c100}
+ br ^bb4(%1 : !stream.resource<external>)
+// CHECK: ^bb3:
+^bb3:
+ // CHECK: = stream.async.transfer %[[READY]]#1
+ %2 = stream.async.transfer %0#1 : !stream.resource<staging>{%c200} -> !stream.resource<external>{%c200}
+ br ^bb4(%2 : !stream.resource<external>)
+// CHECK: ^bb4(
+^bb4(%arg6: !stream.resource<external>):
+ return %arg6 : !stream.resource<external>
+}
+
+// -----
+
+// CHECK-LABEL: @SinkSubviewsAcrossAwaits
+func @SinkSubviewsAcrossAwaits(
+ %arg0: !stream.resource<*>, %arg1: index,
+ %arg2: !stream.timepoint
+) -> !stream.resource<*> {
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ // CHECK: %[[READY:.+]] = stream.timepoint.await %arg2 => %arg0 : !stream.resource<*>{%arg1}
+ // CHECK: %[[RET:.+]] = stream.resource.subview %[[READY]][%c128] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c256}
+ %0 = stream.resource.subview %arg0[%c128] : !stream.resource<*>{%arg1} -> !stream.resource<*>{%c256}
+ %1 = stream.timepoint.await %arg2 => %0 : !stream.resource<*>{%c256}
+ // CHECK: return %[[RET]]
+ return %1 : !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @GroupAwaitsByTimepoint
+func @GroupAwaitsByTimepoint(
+ %arg0: !stream.timepoint,
+ %arg1: !stream.resource<*>,
+ %arg2: !stream.resource<*>,
+ %arg3: !stream.resource<*>,
+ %arg4: !stream.resource<*>
+) -> (!stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.resource<*>) {
+ %c100 = arith.constant 100 : index
+ %c101 = arith.constant 101 : index
+ %c102 = arith.constant 102 : index
+ %c103 = arith.constant 103 : index
+ // CHECK: %[[RET:.+]]:4 = stream.timepoint.await %arg0 => %arg1, %arg2, %arg3, %arg4 :
+ // CHECK-SAME: !stream.resource<*>{%c100}, !stream.resource<*>{%c101}, !stream.resource<*>{%c102}, !stream.resource<*>{%c103}
+ %0 = stream.timepoint.await %arg0 => %arg1 : !stream.resource<*>{%c100}
+ %1 = stream.timepoint.await %arg0 => %arg2 : !stream.resource<*>{%c101}
+ %2:2 = stream.timepoint.await %arg0 => %arg3, %arg4 : !stream.resource<*>{%c102}, !stream.resource<*>{%c103}
+ // CHECK-NEXT: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2, %[[RET]]#3
+ return %0, %1, %2#0, %2#1 : !stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+// CHECK-LABEL: @FoldDuplicateAwaitResources
+func @FoldDuplicateAwaitResources(
+ %arg0: !stream.timepoint,
+ %arg1: !stream.resource<staging>, %arg2: !stream.resource<*>
+) -> (!stream.resource<staging>, !stream.resource<*>, !stream.resource<staging>, !stream.resource<staging>) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: %[[RET:.+]]:2 = stream.timepoint.await %arg0 => %arg1, %arg2 : !stream.resource<staging>{%c100}, !stream.resource<*>{%c200}
+ %0:4 = stream.timepoint.await %arg0 => %arg1, %arg2, %arg1, %arg1 : !stream.resource<staging>{%c100}, !stream.resource<*>{%c200}, !stream.resource<staging>{%c100}, !stream.resource<staging>{%c100}
+ // CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#0, %[[RET]]#0
+ return %0#0, %0#1, %0#2, %0#3 : !stream.resource<staging>, !stream.resource<*>, !stream.resource<staging>, !stream.resource<staging>
+}
+
+// -----
+
+// CHECK-LABEL: @ElideUnusedTimepointAwaitOp
+func @ElideUnusedTimepointAwaitOp(
+ %arg0: !stream.timepoint,
+ %arg1: !stream.resource<staging>, %arg2: !stream.resource<*>
+) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-NOT: stream.timepoint.await
+ %0:2 = stream.timepoint.await %arg0 => %arg1, %arg2 : !stream.resource<staging>{%c100}, !stream.resource<*>{%c200}
+ return
+}
diff --git a/iree/compiler/Dialect/Stream/IR/test/timepoint_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/timepoint_ops.mlir
new file mode 100644
index 0000000..fcc96d8
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/IR/test/timepoint_ops.mlir
@@ -0,0 +1,28 @@
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: @timepointImmediate
+func @timepointImmediate() -> !stream.timepoint {
+ // CHECK: = stream.timepoint.immediate => !stream.timepoint
+ %0 = stream.timepoint.immediate => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @timepointJoin
+func @timepointJoin(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> !stream.timepoint {
+ // CHECK: = stream.timepoint.join max(%arg0, %arg1) => !stream.timepoint
+ %0 = stream.timepoint.join max(%arg0, %arg1) => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @timepointAwait
+func @timepointAwait(%arg0: !stream.timepoint, %arg1: !stream.resource<staging>, %arg2: !stream.resource<*>) -> (!stream.resource<staging>, !stream.resource<*>) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK: = stream.timepoint.await %arg0 => %arg1, %arg2 : !stream.resource<staging>{%c100}, !stream.resource<*>{%c200}
+ %0:2 = stream.timepoint.await %arg0 => %arg1, %arg2 : !stream.resource<staging>{%c100}, !stream.resource<*>{%c200}
+ return %0#0, %0#1 : !stream.resource<staging>, !stream.resource<*>
+}
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 5de83d9..2723b38 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -111,6 +111,7 @@
"//iree/compiler/Dialect/Modules/VMVX/Transforms",
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
+ "//iree/compiler/Dialect/Stream/IR",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/Util/Transforms",
"//iree/compiler/Dialect/VM/Analysis",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index a884884..fc29931 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -208,6 +208,7 @@
iree::compiler::Dialect::Modules::VMVX::Transforms
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Shape::Transforms
+ iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Dialect::VM::Analysis
diff --git a/iree/tools/init_iree_dialects.h b/iree/tools/init_iree_dialects.h
index c99dd22..b374934 100644
--- a/iree/tools/init_iree_dialects.h
+++ b/iree/tools/init_iree_dialects.h
@@ -20,6 +20,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h"
#include "iree/compiler/Dialect/Modules/VMVX/IR/VMVXDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
@@ -34,11 +35,12 @@
registry.insert<IREE::Flow::FlowDialect,
IREE::HAL::HALDialect,
ShapeDialect,
+ IREE::Stream::StreamDialect,
IREE::Util::UtilDialect,
- linalg_ext::LinalgExtDialect,
IREE::VM::VMDialect,
IREE::VMVX::VMVXDialect,
IREE::Vulkan::VulkanDialect,
+ linalg_ext::LinalgExtDialect,
mlir::iree::IREEDialect,
mlir::iree_pydm::IREEPyDMDialect>();
// clang-format on