Adding the head of the stream dialect conversion pipeline. (#7526)
Fixing the breakages from upstream dialect conversion changes.
See #7439 / #7520.
This reverts commit 80a3c85474aafaacb8a2e2c6faf698e9c63154c6.
diff --git a/iree/compiler/Dialect/Stream/Conversion/BUILD b/iree/compiler/Dialect/Stream/Conversion/BUILD
new file mode 100644
index 0000000..0a187e7
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/BUILD
@@ -0,0 +1,27 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "Conversion",
+ srcs = [
+ "PatternUtils.cpp",
+ ],
+ hdrs = [
+ "PatternUtils.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/Stream/IR",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
new file mode 100644
index 0000000..ed6cbea
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
@@ -0,0 +1,28 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Conversion/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Conversion
+ HDRS
+ "PatternUtils.h"
+ SRCS
+ "PatternUtils.cpp"
+ DEPS
+ MLIRIR
+ MLIRStandard
+ MLIRTransforms
+ iree::compiler::Dialect::Stream::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD
new file mode 100644
index 0000000..d317206
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "FlowToStream",
+ srcs = [
+ "ConvertFlowToStream.cpp",
+ ],
+ hdrs = [
+ "ConvertFlowToStream.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/Flow/IR",
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/Stream/Conversion",
+ "//iree/compiler/Dialect/Stream/IR",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:TensorDialect",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/CMakeLists.txt
new file mode 100644
index 0000000..d3b2d39
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/CMakeLists.txt
@@ -0,0 +1,32 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ FlowToStream
+ HDRS
+ "ConvertFlowToStream.h"
+ SRCS
+ "ConvertFlowToStream.cpp"
+ DEPS
+ MLIRIR
+ MLIRStandard
+ MLIRTensor
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::Stream::Conversion
+ iree::compiler::Dialect::Stream::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
new file mode 100644
index 0000000..bcb1fb8
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
@@ -0,0 +1,467 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.h"
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Shape/IR/Builders.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Inserts a sizeof calculation for the given tensor value type and dims.
+// This should only be used to produce sizes for values produced by an op; the
+// size of operands must be queried from the input resource.
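+//
+// Produces IR like the following (as checked in test/cast_ops.mlir):
+//   %sz = stream.tensor.sizeof tensor<?x?x4xf32>{%dim0, %dim1} : index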
+static Value buildResultSizeOf(Location loc, Value tensorValue,
+ ValueRange dynamicDims,
+ ConversionPatternRewriter &rewriter) {
+ // TODO(benvanik): see if we can stash this on the side to avoid expensive
+ // materialization of a bunch of redundant IR.
+ return rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ loc, rewriter.getIndexType(), TypeAttr::get(tensorValue.getType()),
+ dynamicDims, /*affinity=*/nullptr);
+}
+
+// hal.tensor.cast is inserted by frontends to ensure that ABI types are HAL
+// buffer views. We need to map those to the stream import/export equivalents
+// as the cast has special meaning when dealing with asynchronous values.
+//
+// %1 = hal.tensor.cast %0 : !hal.buffer_view -> tensor<4xf32>
+// ->
+// %1 = stream.tensor.import %0 : !hal.buffer_view ->
+// tensor<4xf32> in !stream.resource<*>
+struct ConvertHALTensorCastOp
+ : public OpConversionPattern<IREE::HAL::TensorCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::HAL::TensorCastOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::HAL::TensorCastOpAdaptor operands(newOperands,
+ op->getAttrDictionary());
+ if (op.source().getType().isa<IREE::HAL::BufferViewType>()) {
+ // Import (buffer view to stream resource).
+ auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ auto resultSize = buildResultSizeOf(op.getLoc(), op.target(),
+ operands.target_dims(), rewriter);
+ auto newOp = rewriter.create<IREE::Stream::TensorImportOp>(
+ op.getLoc(), resultType, operands.source(),
+ TypeAttr::get(op.target().getType()), operands.target_dims(),
+ resultSize,
+ /*affinity=*/nullptr);
+
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+ op, unknownType, newOp.result(), resultSize, resultSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+
+ } else if (op.target().getType().isa<IREE::HAL::BufferViewType>()) {
+ auto source =
+ consumeTensorOperand(op.getLoc(), operands.source(), rewriter);
+ auto externalType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ auto exportSource = operands.source();
+ if (source.resource.getType() != externalType) {
+ exportSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
+ op.getLoc(), externalType, source.resource, source.resourceSize,
+ source.resourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ }
+
+ // Export (stream resource to buffer view).
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorExportOp>(
+ op, op.target().getType(), exportSource,
+ TypeAttr::get(op.source().getType()), operands.source_dims(),
+ source.resourceSize,
+ /*affinity=*/nullptr);
+ } else {
+ return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
+ }
+ return success();
+ }
+};
+
+// Reshapes become clones here to preserve shape information (and may become
+// actual transfers depending on source/target shape); they'll be elided later
+// if not needed.
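+//
+// For example (as checked in test/tensor_ops.mlir):
+//   %1 = flow.tensor.reshape %0 : tensor<5x24x48xf32> -> tensor<30x2x96xf32>
+//  ->
+//   %sz = stream.tensor.sizeof tensor<30x2x96xf32> : index
+//   %1 = stream.tensor.clone %0 : tensor<5x24x48xf32> in
+//        !stream.resource<*>{%in_sz} -> tensor<30x2x96xf32> in
+//        !stream.resource<*>{%sz}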
+struct ConvertTensorReshapeOp
+ : public OpConversionPattern<IREE::Flow::TensorReshapeOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorReshapeOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::TensorReshapeOpAdaptor operands(newOperands,
+ op->getAttrDictionary());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ auto source =
+ consumeTensorOperand(op.getLoc(), operands.source(), rewriter);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.result(), op.result_dims(), rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>(
+ op, unknownType, source.resource, op.source().getType(),
+ op.source_dims(), source.resourceSize, op.result().getType(),
+ operands.result_dims(), resultSize,
+ /*affinity=*/nullptr);
+ return success();
+ }
+};
+
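+// flow.tensor.splat -> stream.tensor.splat, e.g.
+// (as checked in test/tensor_ops.mlir):
+//   %0 = flow.tensor.splat %value : tensor<?x128xi8>{%dim0}
+//  ->
+//   %sz = stream.tensor.sizeof tensor<?x128xi8>{%dim0} : index
+//   %0 = stream.tensor.splat %value : i8 -> tensor<?x128xi8>{%dim0} in
+//        !stream.resource<*>{%sz}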
+struct ConvertTensorSplatOp
+ : public OpConversionPattern<IREE::Flow::TensorSplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorSplatOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::TensorSplatOpAdaptor operands(newOperands);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.result(), op.result_dims(), rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorSplatOp>(
+ op, unknownType, operands.value(), op.result().getType(),
+ operands.result_dims(), resultSize,
+ /*affinity=*/nullptr);
+ return success();
+ }
+};
+
+struct ConvertTensorCloneOp
+ : public OpConversionPattern<IREE::Flow::TensorCloneOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorCloneOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::TensorCloneOpAdaptor operands(newOperands);
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ auto operand =
+ consumeTensorOperand(op.getLoc(), operands.operand(), rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorCloneOp>(
+ op, unknownType, operand.resource, op.operand().getType(),
+ op.operand_dims(), operand.resourceSize, op.result().getType(),
+ operands.operand_dims(), operand.resourceSize,
+ /*affinity=*/nullptr);
+ return success();
+ }
+};
+
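+// flow.tensor.slice -> stream.tensor.slice with explicit resource sizes, e.g.
+// (as checked in test/tensor_ops.mlir):
+//   %0 = flow.tensor.slice %src[%c2, %c0, %c0 for %c3, %c24, %c48] :
+//        tensor<5x24x48xf32> -> tensor<3x24x48xf32>
+//  ->
+//   %0 = stream.tensor.slice %src[%c2, %c0, %c0 for %c3, %c24, %c48] :
+//        tensor<5x24x48xf32> in !stream.resource<*>{%src_sz} ->
+//        tensor<3x24x48xf32> in !stream.resource<*>{%sz}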
+struct ConvertTensorSliceOp
+ : public OpConversionPattern<IREE::Flow::TensorSliceOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorSliceOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::TensorSliceOpAdaptor operands(newOperands,
+ op->getAttrDictionary());
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ auto source =
+ consumeTensorOperand(op.getLoc(), operands.source(), rewriter);
+ auto resultSize =
+ buildResultSizeOf(op.getLoc(), op.result(), op.result_dims(), rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorSliceOp>(
+ op, unknownType, source.resource, op.source().getType(),
+ op.source_dims(), source.resourceSize, operands.start_indices(),
+ operands.lengths(), op.result().getType(), operands.result_dims(),
+ resultSize,
+ /*affinity=*/nullptr);
+ return success();
+ }
+};
+
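+// flow.tensor.update -> stream.tensor.update tied to the target resource, e.g.
+// (as checked in test/tensor_ops.mlir):
+//   %0 = flow.tensor.update %update, %target[%c4, %c1, %c1] :
+//        tensor<1x1x10xf32> -> %target as tensor<5x1x10xf32>
+//  ->
+//   %0 = stream.tensor.update %update, %target[%c4, %c1, %c1] :
+//        tensor<1x1x10xf32> in !stream.resource<*>{%update_sz} ->
+//        tensor<5x1x10xf32> in %target as !stream.resource<*>{%target_sz}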
+struct ConvertTensorUpdateOp
+ : public OpConversionPattern<IREE::Flow::TensorUpdateOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorUpdateOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::TensorUpdateOpAdaptor operands(newOperands,
+ op->getAttrDictionary());
+ auto update =
+ consumeTensorOperand(op.getLoc(), operands.update(), rewriter);
+ auto target =
+ consumeTensorOperand(op.getLoc(), operands.target(), rewriter);
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorUpdateOp>(
+ op, target.resource.getType(), target.resource, op.target().getType(),
+ operands.target_dims(), target.resourceSize, operands.start_indices(),
+ update.resource, op.update().getType(), op.update_dims(),
+ update.resourceSize,
+ /*affinity=*/nullptr);
+ return success();
+ }
+};
+
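+// flow.tensor.load loads from a staging resource; if the source is not already
+// staging it is transferred first, e.g. (as checked in test/tensor_ops.mlir):
+//   %staged = stream.async.transfer %src : !stream.resource<*>{%sz} ->
+//             !stream.resource<staging>{%sz}
+//   %0 = stream.tensor.load %staged[%c0, %c1] : tensor<2x3xi32> in
+//        !stream.resource<staging>{%sz} -> i32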
+struct ConvertTensorLoadOp
+ : public OpConversionPattern<IREE::Flow::TensorLoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorLoadOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::TensorLoadOpAdaptor operands(newOperands,
+ op->getAttrDictionary());
+ auto resultType = getTypeConverter()->convertType(op.result().getType());
+ auto source =
+ consumeTensorOperand(op.getLoc(), operands.source(), rewriter);
+
+ auto stagingType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Staging);
+ auto loadSource = source.resource;
+ if (source.resource.getType() != stagingType) {
+ loadSource = rewriter.createOrFold<IREE::Stream::AsyncTransferOp>(
+ op.getLoc(), stagingType, source.resource, source.resourceSize,
+ source.resourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorLoadOp>(
+ op, resultType, loadSource, op.source().getType(), op.source_dims(),
+ source.resourceSize, operands.indices());
+ return success();
+ }
+};
+
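+// flow.tensor.store stores into a staging resource and transfers the result
+// back to the original lifetime if needed, e.g.
+// (as checked in test/tensor_ops.mlir):
+//   %staged = stream.async.transfer %target : !stream.resource<*>{%sz} ->
+//             !stream.resource<staging>{%sz}
+//   %stored = stream.tensor.store %c9_i32, %staged[%c0, %c1] : i32 ->
+//             tensor<2x3xi32> in %staged as !stream.resource<staging>{%sz}
+//   %0 = stream.async.transfer %stored : !stream.resource<staging>{%sz} ->
+//        !stream.resource<*>{%sz}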
+struct ConvertTensorStoreOp
+ : public OpConversionPattern<IREE::Flow::TensorStoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::TensorStoreOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::TensorStoreOpAdaptor operands(newOperands,
+ op->getAttrDictionary());
+ auto target =
+ consumeTensorOperand(op.getLoc(), operands.target(), rewriter);
+
+ auto stagingType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::Staging);
+ auto storeTarget = target.resource;
+ if (target.resource.getType() != stagingType) {
+ storeTarget = rewriter.createOrFold<IREE::Stream::AsyncTransferOp>(
+ op.getLoc(), stagingType, storeTarget, target.resourceSize,
+ target.resourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ }
+
+ auto newOp = rewriter.create<IREE::Stream::TensorStoreOp>(
+ op.getLoc(), storeTarget.getType(), storeTarget, op.target().getType(),
+ operands.target_dims(), target.resourceSize, operands.indices(),
+ operands.value());
+
+ Value newResult = newOp.result();
+ if (target.resource.getType() != stagingType) {
+ newResult = rewriter.createOrFold<IREE::Stream::AsyncTransferOp>(
+ op.getLoc(), target.resource.getType(), newResult,
+ target.resourceSize, target.resourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ }
+ rewriter.replaceOp(op, {newResult});
+
+ return success();
+ }
+};
+
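+// flow.dispatch -> stream.async.dispatch with explicit operand/result sizes;
+// results tied to operands reuse the operand resource and size, e.g.
+// (as checked in test/dispatch_ops.mlir):
+//   %0 = flow.dispatch @ex::@entry[%c1, %c2, %c3](%input) :
+//        (tensor<7x?x24x?xf32>{%d1, %d3}) -> tensor<?x?x1024xf32>{%d1, %d3}
+//  ->
+//   %sz = stream.tensor.sizeof tensor<?x?x1024xf32>{%d1, %d3}
+//   %0 = stream.async.dispatch @ex::@entry[%c1, %c2, %c3](%input) :
+//        (!stream.resource<*>{%input_sz}) -> !stream.resource<*>{%sz}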
+struct ConvertDispatchOp : public OpConversionPattern<IREE::Flow::DispatchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::DispatchOp op, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Flow::DispatchOpAdaptor operands(newOperands,
+ op->getAttrDictionary());
+
+ // Query and resolve all operands and their sizes.
+ SmallVector<Value> dispatchOperands;
+ SmallVector<Value> dispatchOperandSizes;
+ for (auto oldNewOperand : llvm::zip(op.operands(), operands.operands())) {
+ auto oldOperand = std::get<0>(oldNewOperand);
+ auto newOperand = std::get<1>(oldNewOperand);
+ if (oldOperand.getType().isa<ShapedType>()) {
+ auto newOperandCast =
+ consumeTensorOperand(op.getLoc(), newOperand, rewriter);
+ newOperand = newOperandCast.resource;
+ dispatchOperandSizes.push_back(newOperandCast.resourceSize);
+ }
+ dispatchOperands.push_back(newOperand);
+ }
+
+ // Construct result sizes or reuse tied operand sizes from above.
+ SmallVector<Value> resultSizes;
+ SmallVector<Type> resultTypes;
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ auto tiedOperandBase = op.getTiedOperandsIndexAndLength().first;
+ for (auto result : llvm::enumerate(op.results())) {
+ auto oldResultType = result.value().getType();
+ if (!oldResultType.isa<ShapedType>()) {
+ resultTypes.push_back(getTypeConverter()->convertType(oldResultType));
+ continue;
+ }
+ auto tiedOperand = op.getTiedResultOperandIndex(result.index());
+ if (tiedOperand.hasValue()) {
+ auto operandIndex = tiedOperand.getValue() - tiedOperandBase;
+ resultSizes.push_back(dispatchOperandSizes[operandIndex]);
+ resultTypes.push_back(dispatchOperands[operandIndex].getType());
+ } else {
+ auto resultDynamicDims = Shape::buildOrFindDynamicDimsForValue(
+ op.getLoc(), result.value(), rewriter);
+ resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
+ resultDynamicDims, rewriter));
+ resultTypes.push_back(unknownType);
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncDispatchOp>(
+ op, resultTypes, operands.workgroup_count(), operands.entry_point(),
+ dispatchOperands, dispatchOperandSizes, resultSizes,
+ operands.tied_operands(),
+ /*affinity=*/nullptr);
+ return success();
+ }
+};
+
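+// Returns the SSA values carrying the dynamic dimensions of |arg|, recovered
+// from the flow.dispatch.tie_shape op that consumes it (see the HACK note
+// below). Returns an empty vector when |tensorType| is fully static.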
+static SmallVector<Value> makeBindingDynamicDims(
+ Location loc, IREE::Flow::DispatchTensorType tensorType, BlockArgument arg,
+ OpBuilder &builder) {
+ if (tensorType.hasStaticShape()) return {};
+
+  // We expect the argument's first user to be a tie_shape op that associates
+  // the concrete dimension values. Originally this information was maintained
+  // on the flow ops handling dynamic tensors, but during flow executable
+  // outlining it is transferred to tie_shape ops.
+ //
+ // HACK: this is disgusting - we should carry this from the flow level in the
+ // right way such that we don't need to make this assumption.
+ auto tieShapeOp = dyn_cast<IREE::Flow::DispatchTieShapeOp>(*arg.user_begin());
+ assert(tieShapeOp && "missing flow tie shape for dynamic value");
+ builder.setInsertionPointAfter(tieShapeOp.shape().getDefiningOp());
+
+ // Get the SSA values for all dynamic dimensions.
+ SmallVector<Value> dynamicDims;
+ dynamicDims.reserve(tensorType.getNumDynamicDims());
+ for (int i = 0; i < tensorType.getRank(); ++i) {
+ if (!tensorType.isDynamicDim(i)) continue;
+ dynamicDims.push_back(builder.create<Shape::RankedDimOp>(
+ tieShapeOp.getLoc(), tieShapeOp.shape(), i));
+ }
+ assert(dynamicDims.size() == tensorType.getNumDynamicDims());
+
+ return dynamicDims;
+}
+
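+// flow.executable -> stream.executable: entry points are re-created on the new
+// executable, the inner module is moved over, and public function tensor I/O
+// becomes !stream.binding arguments accessed via stream.binding.subspan.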
+struct ConvertExecutableOp
+ : public OpConversionPattern<IREE::Flow::ExecutableOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Flow::ExecutableOp flowOp, ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ // flow.executable -> stream.executable
+ auto streamOp = rewriter.create<IREE::Stream::ExecutableOp>(
+ flowOp.getLoc(), flowOp.sym_name());
+ streamOp.setVisibility(flowOp.getVisibility());
+ streamOp->setDialectAttrs(flowOp->getDialectAttrs());
+ rewriter.setInsertionPointToStart(&streamOp.body().front());
+
+    // flow.dispatch.entry -> stream.executable.export
+ for (auto entryOp : flowOp.getOps<IREE::Flow::DispatchEntryOp>()) {
+ auto newOp = rewriter.create<IREE::Stream::ExecutableExportOp>(
+ entryOp.getLoc(), entryOp.sym_name(), entryOp.function_refAttr());
+ newOp->setDialectAttrs(entryOp->getDialectAttrs());
+ }
+
+ // Move the original nested module body into the new executable directly.
+ auto moduleOp = rewriter.cloneWithoutRegions(flowOp.getInnerModule());
+ streamOp.getInnerModule().body().takeBody(flowOp.getInnerModule().body());
+
+ // Update the entry point signatures in the module.
+ // Dispatch tensor arguments become bindings and all others are preserved as
+ // operands. Note that we only touch public (exported) functions.
+ for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+ if (!funcOp.isPublic()) continue;
+
+ SmallVector<Type> newTypes;
+ newTypes.reserve(funcOp.getNumArguments());
+ assert(funcOp.getNumResults() == 0 && "flow dispatches have no results");
+
+ rewriter.setInsertionPointToStart(&funcOp.front());
+ auto zero = rewriter.create<arith::ConstantIndexOp>(funcOp.getLoc(), 0);
+ for (auto arg : funcOp.front().getArguments()) {
+ auto oldType = arg.getType();
+ if (auto tensorType =
+ oldType.dyn_cast<IREE::Flow::DispatchTensorType>()) {
+ // Now a binding.
+ auto newType = rewriter.getType<IREE::Stream::BindingType>();
+ newTypes.push_back(newType);
+ auto dynamicDims =
+ makeBindingDynamicDims(arg.getLoc(), tensorType, arg, rewriter);
+ auto subspanOp = rewriter.create<IREE::Stream::BindingSubspanOp>(
+ arg.getLoc(), tensorType, arg, zero, dynamicDims);
+ arg.replaceAllUsesExcept(subspanOp.result(), subspanOp);
+ arg.setType(newType);
+ } else {
+          // Preserved - will eventually become push constants.
+ newTypes.push_back(oldType);
+ }
+ }
+
+ funcOp.setType(rewriter.getFunctionType(newTypes, {}));
+ }
+
+ rewriter.replaceOp(flowOp, {});
+ return success();
+ }
+};
+
+} // namespace
+
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ typeConverter.addConversion(
+ [](IREE::HAL::BufferViewType type) { return type; });
+ patterns.insert<ConvertHALTensorCastOp>(typeConverter, context);
+
+ patterns
+ .insert<ConvertTensorReshapeOp, ConvertTensorSplatOp,
+ ConvertTensorCloneOp, ConvertTensorSliceOp, ConvertTensorUpdateOp,
+ ConvertTensorLoadOp, ConvertTensorStoreOp>(typeConverter,
+ context);
+ patterns.insert<ConvertDispatchOp>(typeConverter, context);
+ patterns.insert<ConvertExecutableOp>(typeConverter, context);
+}
+
+void populateFlowToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+  // Disallow all flow ops besides the ones we pass through (today).
+  // We don't have stream equivalents for several of the dispatch-level flow
+  // ops because the codegen backends touch them directly, and so long as both
+  // paths exist we can't cut over. Once we convert a flow.executable to a
+  // stream.executable we ignore its contents and cross our fingers.
+ conversionTarget.addIllegalDialect<IREE::Flow::FlowDialect>();
+ conversionTarget.addLegalOp<IREE::Stream::ExecutableOp>();
+ conversionTarget.markOpRecursivelyLegal<IREE::Stream::ExecutableOp>();
+
+ conversionTarget.addDynamicallyLegalOp<IREE::HAL::TensorCastOp>(
+ [&](IREE::HAL::TensorCastOp op) {
+ return typeConverter.isLegal(op.source().getType()) &&
+ typeConverter.isLegal(op.target().getType());
+ });
+
+ populateFlowToStreamConversionPatterns(context, typeConverter, patterns);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.h b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.h
new file mode 100644
index 0000000..0ea90e6
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.h
@@ -0,0 +1,31 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_FLOWTOSTREAM_CONVERTFLOWTOSTREAM_H_
+#define IREE_COMPILER_DIALECT_STREAM_CONVERSION_FLOWTOSTREAM_CONVERTFLOWTOSTREAM_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Populates conversion patterns that perform flow->stream conversion.
+// These patterns ensure that nested types are run through the provided
+// |typeConverter|.
+void populateFlowToStreamConversionPatterns(MLIRContext *context,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+void populateFlowToStreamConversionPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_CONVERSION_FLOWTOSTREAM_CONVERTFLOWTOSTREAM_H_
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
new file mode 100644
index 0000000..1f3d342
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
@@ -0,0 +1,30 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "cast_ops.mlir",
+ "dispatch_ops.mlir",
+ "tensor_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
new file mode 100644
index 0000000..1974566
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
@@ -0,0 +1,25 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "cast_ops.mlir"
+ "dispatch_ops.mlir"
+ "tensor_ops.mlir"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/cast_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/cast_ops.mlir
new file mode 100644
index 0000000..17f9a6a
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/cast_ops.mlir
@@ -0,0 +1,34 @@
+// RUN: iree-opt -split-input-file -iree-stream-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @importBufferView
+// CHECK-SAME: (%[[VIEW:.+]]: !hal.buffer_view)
+// CHECK-SAME: -> (!stream.resource<*>, index)
+func @importBufferView(%view: !hal.buffer_view) -> tensor<?x?x4xf32> {
+ // CHECK-DAG: %[[DIM0:.+]] = hal.buffer_view.dim{{.+}}[0]
+ %dim0 = hal.buffer_view.dim<%view : !hal.buffer_view>[0] : index
+ // CHECK-DAG: %[[DIM1:.+]] = hal.buffer_view.dim{{.+}}[1]
+ %dim1 = hal.buffer_view.dim<%view : !hal.buffer_view>[1] : index
+ // CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} : index
+ // CHECK: %[[RESOURCE:.+]] = stream.tensor.import %[[VIEW]] : !hal.buffer_view ->
+ // CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
+ // CHECK-NEXT: %[[RESULT:.+]] = stream.async.transfer %[[RESOURCE]] :
+ // CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+ %0 = hal.tensor.cast %view : !hal.buffer_view -> tensor<?x?x4xf32>{%dim0, %dim1}
+ // CHECK: return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
+ return %0 : tensor<?x?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @exportBufferView
+// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func @exportBufferView(%tensor: tensor<?x?x4xf32>, %dim0: index, %dim1: index) -> !hal.buffer_view {
+ // CHECK: %[[VIEW:.+]] = stream.async.transfer %[[TENSOR]] :
+ // CHECK-SAME: !stream.resource<*>{%[[SIZE]]} -> !stream.resource<external>{%[[SIZE]]}
+ // CHECK-NEXT: %[[RESULT:.+]] = stream.tensor.export %[[VIEW]] :
+ // CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
+ // CHECK-SAME: -> !hal.buffer_view
+ %0 = hal.tensor.cast %tensor : tensor<?x?x4xf32>{%dim0, %dim1} -> !hal.buffer_view
+ // CHECK: return %[[RESULT]]
+ return %0 : !hal.buffer_view
+}
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
new file mode 100644
index 0000000..af4204b
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir
@@ -0,0 +1,32 @@
+// RUN: iree-opt -split-input-file -iree-stream-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @dispatch
+// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
+func @dispatch(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> tensor<?x?x1024xf32> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ // CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
+ // CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry[%c1, %c2, %c3](%[[INPUT]]) :
+ // CHECK-SAME: (!stream.resource<*>{%[[INPUT_SIZE]]}) -> !stream.resource<*>{%[[RESULT_SIZE]]}
+ %0 = flow.dispatch @ex::@entry[%c1, %c2, %c3](%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
+  // CHECK: return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
+ return %0 : tensor<?x?x1024xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tiedDispatch
+// CHECK-SAME: (%[[INPUT0:.+]]: !stream.resource<*>, %[[INPUT0_SIZE:.+]]: index, %[[INPUT1:.+]]: !stream.resource<*>, %[[INPUT1_SIZE:.+]]: index)
+func @tiedDispatch(%input0: tensor<i32>, %input1: tensor<2x3xi32>) -> tensor<3x9xi32> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ // CHECK: %[[T_SIZE:.+]] = stream.tensor.sizeof tensor<3x9xi32> : index
+ // CHECK: %[[T:.+]] = stream.async.dispatch @ex::@entry0[%c1, %c2, %c3](%[[INPUT0]]) : (!stream.resource<*>{%[[INPUT0_SIZE]]}) -> !stream.resource<*>{%[[T_SIZE]]}
+ %0 = flow.dispatch @ex::@entry0[%c1, %c2, %c3](%input0) : (tensor<i32>) -> tensor<3x9xi32>
+ // CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry1[%c1, %c2, %c3](%[[INPUT1]], %[[T]]) : (!stream.resource<*>{%[[INPUT1_SIZE]]}, !stream.resource<*>{%[[T_SIZE]]}) -> %[[T]]{%[[T_SIZE]]}
+ %1 = flow.dispatch @ex::@entry1[%c1, %c2, %c3](%input1, %0) : (tensor<2x3xi32>, tensor<3x9xi32>) -> %0
+ // CHECK: return %[[RESULT]], %[[T_SIZE]] : !stream.resource<*>, index
+ return %1 : tensor<3x9xi32>
+}
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
new file mode 100644
index 0000000..629232f
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -0,0 +1,117 @@
+// RUN: iree-opt -split-input-file -iree-stream-conversion %s | IreeFileCheck %s
+
+// CHECK-LABEL: @tensorReshapePassThrough
+// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
+func @tensorReshapePassThrough(%input: tensor<5x24x48xf32>) -> tensor<30x2x96xf32> {
+ // CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<30x2x96xf32> : index
+ // CHECK: %[[RESULT:.+]] = stream.tensor.clone %[[INPUT]] : tensor<5x24x48xf32> in !stream.resource<*>{%[[INPUT_SIZE]]} -> tensor<30x2x96xf32> in !stream.resource<*>{%[[RESULT_SIZE]]}
+ %0 = flow.tensor.reshape %input : tensor<5x24x48xf32> -> tensor<30x2x96xf32>
+ // CHECK: return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
+ return %0 : tensor<30x2x96xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorReshapeWithSingleUse
+// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
+func @tensorReshapeWithSingleUse(%input: tensor<5x24x48xf32>) -> tensor<30x2x96xf32> {
+ // CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<30x2x96xf32> : index
+ // CHECK: %[[RESHAPE:.+]] = stream.tensor.clone %[[INPUT]] : tensor<5x24x48xf32> in !stream.resource<*>{%[[INPUT_SIZE]]} -> tensor<30x2x96xf32> in !stream.resource<*>{%[[RESULT_SIZE]]}
+ %0 = flow.tensor.reshape %input : tensor<5x24x48xf32> -> tensor<30x2x96xf32>
+ // CHECK: %[[RESULT:.+]] = stream.tensor.clone %[[RESHAPE]] : tensor<30x2x96xf32> in !stream.resource<*>{%[[RESULT_SIZE]]} -> tensor<30x2x96xf32> in !stream.resource<*>{%[[RESULT_SIZE]]}
+ %1 = flow.tensor.clone %0 : tensor<30x2x96xf32>
+ // CHECK: return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
+ return %1 : tensor<30x2x96xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorReshapeWithMultipleUses
+// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
+func @tensorReshapeWithMultipleUses(%input: tensor<5x24x48xf32>)
+ -> (tensor<60x2x48xf32>, tensor<30x2x96xf32>) {
+ // CHECK: %[[T0:.+]] = stream.tensor.clone %[[INPUT]] : tensor<5x24x48xf32> in !stream.resource<*>{%[[INPUT_SIZE]]} -> tensor<5x24x48xf32> in !stream.resource<*>{%[[INPUT_SIZE]]}
+ %1 = flow.tensor.clone %input : tensor<5x24x48xf32>
+ // CHECK: %[[T1_SIZE:.+]] = stream.tensor.sizeof tensor<60x2x48xf32> : index
+ // CHECK: %[[T1:.+]] = stream.tensor.clone %[[INPUT]] : tensor<5x24x48xf32> in !stream.resource<*>{%[[INPUT_SIZE]]} -> tensor<60x2x48xf32> in !stream.resource<*>{%[[T1_SIZE]]}
+ %2 = flow.tensor.reshape %input : tensor<5x24x48xf32> -> tensor<60x2x48xf32>
+ // CHECK: %[[T2:.+]] = stream.tensor.clone %[[T1]] : tensor<60x2x48xf32> in !stream.resource<*>{%[[T1_SIZE]]} -> tensor<60x2x48xf32> in !stream.resource<*>{%[[T1_SIZE]]}
+ %3 = flow.tensor.clone %2 : tensor<60x2x48xf32>
+ // CHECK: %[[T3_SIZE:.+]] = stream.tensor.sizeof tensor<30x2x96xf32> : index
+ // CHECK: %[[T3:.+]] = stream.tensor.clone %[[T0]] : tensor<5x24x48xf32> in !stream.resource<*>{%[[INPUT_SIZE]]} -> tensor<30x2x96xf32> in !stream.resource<*>{%[[T3_SIZE]]}
+ %4 = flow.tensor.reshape %1 : tensor<5x24x48xf32> -> tensor<30x2x96xf32>
+ // CHECK: return %[[T2]], %[[T1_SIZE]], %[[T3]], %[[T3_SIZE]] : !stream.resource<*>, index, !stream.resource<*>, index
+ return %3, %4 : tensor<60x2x48xf32>, tensor<30x2x96xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorSplat
+// CHECK-SAME: (%[[VALUE:.+]]: i8, %[[DIM0:.+]]: index)
+func @tensorSplat(%value: i8, %dim0: index) -> tensor<?x128xi8> {
+ // CHECK: %[[T_SIZE:.+]] = stream.tensor.sizeof tensor<?x128xi8>{%[[DIM0]]} : index
+ // CHECK: %[[T:.+]] = stream.tensor.splat %[[VALUE]] : i8 -> tensor<?x128xi8>{%[[DIM0]]} in !stream.resource<*>{%[[T_SIZE]]}
+ %0 = flow.tensor.splat %value : tensor<?x128xi8>{%dim0}
+ // CHECK: return %[[T]], %[[T_SIZE]]
+ return %0 : tensor<?x128xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorSlice
+// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index)
+func @tensorSlice(%input : tensor<5x24x48xf32>) -> tensor<3x24x48xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c24 = arith.constant 24 : index
+ %c48 = arith.constant 48 : index
+ // CHECK: %[[T_SIZE:.+]] = stream.tensor.sizeof tensor<3x24x48xf32> : index
+ // CHECK: %[[T:.+]] = stream.tensor.slice %[[INPUT]][%c2, %c0, %c0 for %c3, %c24, %c48] : tensor<5x24x48xf32> in !stream.resource<*>{%[[INPUT_SIZE]]} -> tensor<3x24x48xf32> in !stream.resource<*>{%[[T_SIZE]]}
+ %0 = flow.tensor.slice %input[%c2, %c0, %c0 for %c3, %c24, %c48] : tensor<5x24x48xf32> -> tensor<3x24x48xf32>
+ // CHECK: return %[[T]], %[[T_SIZE]] : !stream.resource<*>, index
+ return %0 : tensor<3x24x48xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorUpdate
+// CHECK-SAME: (%[[UPDATE:.+]]: !stream.resource<*>, %[[UPDATE_SIZE:.+]]: index, %[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index)
+func @tensorUpdate(%update : tensor<1x1x10xf32>, %target : tensor<5x1x10xf32>) -> tensor<5x1x10xf32> {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ // CHECK: %[[T:.+]] = stream.tensor.update %[[UPDATE]], %[[TARGET]][%c4, %c1, %c1] : tensor<1x1x10xf32> in !stream.resource<*>{%[[UPDATE_SIZE]]} -> tensor<5x1x10xf32> in %[[TARGET]] as !stream.resource<*>{%[[TARGET_SIZE]]}
+ %0 = flow.tensor.update %update, %target[%c4, %c1, %c1] : tensor<1x1x10xf32> -> %target as tensor<5x1x10xf32>
+ // CHECK: return %[[T]], %[[TARGET_SIZE]] : !stream.resource<*>, index
+ return %0 : tensor<5x1x10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensorLoad
+// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<*>, %[[SOURCE_SIZE:.+]]: index)
+func @tensorLoad(%source : tensor<2x3xi32>) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[T0:.+]] = stream.async.transfer %[[SOURCE]] : !stream.resource<*>{%[[SOURCE_SIZE]]} -> !stream.resource<staging>{%[[SOURCE_SIZE]]}
+ // CHECK: %[[T1:.+]] = stream.tensor.load %[[T0]][%c0, %c1] : tensor<2x3xi32> in !stream.resource<staging>{%[[SOURCE_SIZE]]} -> i32
+ %0 = flow.tensor.load %source[%c0, %c1] : tensor<2x3xi32>
+ // CHECK: return %[[T1]]
+ return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @tensorStore
+// CHECK-SAME: (%[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index)
+func @tensorStore(%target : tensor<2x3xi32>) -> tensor<2x3xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c9 = arith.constant 9 : i32
+ // CHECK: %[[T0:.+]] = stream.async.transfer %[[TARGET]] : !stream.resource<*>{%[[TARGET_SIZE]]} -> !stream.resource<staging>{%[[TARGET_SIZE]]}
+ // CHECK: %[[T1:.+]] = stream.tensor.store %c9_i32, %[[T0]][%c0, %c1] : i32 -> tensor<2x3xi32> in %[[T0]] as !stream.resource<staging>{%[[TARGET_SIZE]]}
+ // CHECK: %[[T2:.+]] = stream.async.transfer %[[T1]] : !stream.resource<staging>{%[[TARGET_SIZE]]} -> !stream.resource<*>{%[[TARGET_SIZE]]}
+ %0 = flow.tensor.store %c9, %target[%c0, %c1] : tensor<2x3xi32>
+ // CHECK: return %[[T2]]
+ return %0 : tensor<2x3xi32>
+}
diff --git a/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
new file mode 100644
index 0000000..202040f
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
@@ -0,0 +1,46 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
+
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+ConvertedTensor consumeTensorOperand(Location loc, Value operand,
+ OpBuilder &builder) {
+ auto operandType = operand.getType();
+ if (operandType.isa<IREE::Stream::ResourceType>()) {
+ // Prior to https://reviews.llvm.org/D111620 this is the path we'd take;
+ // the tensor operands would be remapped into their new resource types.
+ // This is still possible during rewriting if we ourselves produce a new
+ // resource type, but the automatic materialization will go down the
+ // unrealized_conversion_cast path below.
+ return {
+ operand,
+ builder.createOrFold<IREE::Stream::ResourceSizeOp>(
+ loc, builder.getIndexType(), operand),
+ };
+ } else if (auto castOp =
+ operand.getDefiningOp<mlir::UnrealizedConversionCastOp>()) {
+ // We only have a single tensor type conversion and it expands to (resource,
+ // size) so that's all we look for here.
+ assert(castOp.getNumOperands() == 2 && "expected (resource, size)");
+ return {
+ castOp.getOperand(0),
+ castOp.getOperand(1),
+ };
+ }
+ llvm_unreachable(
+ "unexpected operand; should have either been converted from tensor or "
+ "not");
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h b/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
new file mode 100644
index 0000000..996efd2
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
@@ -0,0 +1,35 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_
+#define IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// https://reviews.llvm.org/D111620 broke 1->N type expansion during dialect
+// conversion. It inserts unrealized_conversion_casts but then passes the
+// illegal source dialect types for pattern operands, meaning that even though
+// we say tensors are illegal the patterns get the new remapped values as
+// tensors. This, naturally, breaks everything. To work around this we have this
+// helper that tries to peek through the unrealized_conversion_casts and get out
+// the actual values we expected to see from the conversion (and did before that
+// change).
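+//
+// The tensor expansion used by the stream conversion is always:
+//   tensor<...> -> (!stream.resource<*>, index)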
+struct ConvertedTensor {
+ Value resource;
+ Value resourceSize;
+};
+ConvertedTensor consumeTensorOperand(Location loc, Value operand,
+ OpBuilder &builder);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD
new file mode 100644
index 0000000..cf4978f
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD
@@ -0,0 +1,36 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "StandardToStream",
+ srcs = [
+ "ConvertConstantOps.cpp",
+ "ConvertStandardToStream.cpp",
+ "ConvertStructuralOps.cpp",
+ ],
+ hdrs = [
+ "ConvertStandardToStream.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/Stream/Conversion",
+ "//iree/compiler/Dialect/Stream/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Shape",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
new file mode 100644
index 0000000..3cea8dd
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt
@@ -0,0 +1,37 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ StandardToStream
+ HDRS
+ "ConvertStandardToStream.h"
+ SRCS
+ "ConvertConstantOps.cpp"
+ "ConvertStandardToStream.cpp"
+ "ConvertStructuralOps.cpp"
+ DEPS
+ LLVMSupport
+ MLIRArithmetic
+ MLIRIR
+ MLIRMemRef
+ MLIRPass
+ MLIRShape
+ MLIRStandard
+ MLIRTensor
+ MLIRTransforms
+ iree::compiler::Dialect::Stream::Conversion
+ iree::compiler::Dialect::Stream::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
new file mode 100644
index 0000000..2a69d27
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
@@ -0,0 +1,63 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
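+// arith.constant tensors become stream.tensor.constant ops in constant
+// lifetime resources that are then transferred to the unknown lifetime for
+// general use; roughly (attribute and exact syntax elided):
+//   %cst = arith.constant dense<...> : tensor<4xf32>
+//  ->
+//   %0 = stream.tensor.constant ... : tensor<4xf32> in
+//        !stream.resource<constant>
+//   %sz = stream.resource.size %0 : !stream.resource<constant>
+//   %cst = stream.async.transfer %0 : !stream.resource<constant>{%sz} ->
+//          !stream.resource<*>{%sz}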
+struct ConvertTensorConstantOp : public OpConversionPattern<arith::ConstantOp> {
+ public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ arith::ConstantOp constantOp, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle tensor types - other arith.constant types (like i32) are
+ // ignored.
+ if (!constantOp.getType().isa<TensorType>()) return failure();
+
+ Type constantType = IREE::Stream::ResourceType::get(
+ getContext(), IREE::Stream::Lifetime::Constant);
+ auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
+ constantOp.getLoc(), constantType,
+ constantOp.value().cast<ElementsAttr>(),
+ TypeAttr::get(constantOp.getType()),
+ /*result_encoding_dims=*/ValueRange{},
+ /*affinity=*/nullptr);
+
+ Type unknownType = IREE::Stream::ResourceType::get(getContext());
+ auto constantSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+ constantOp.getLoc(), rewriter.getIndexType(), newOp.result());
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+ constantOp, unknownType, newOp.result(), constantSize, constantSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ return success();
+ }
+};
+
+} // namespace
+
+void populateStandardConstantToStreamPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ conversionTarget.addDynamicallyLegalOp<arith::ConstantOp>(
+ [](arith::ConstantOp op) { return !op.getType().isa<TensorType>(); });
+
+ patterns.insert<ConvertTensorConstantOp>(typeConverter, context);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.cpp b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.cpp
new file mode 100644
index 0000000..9534d53
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.cpp
@@ -0,0 +1,48 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h"
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+void populateStandardConstantToStreamPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns);
+
+void populateStandardStructuralToStreamPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns);
+
+void populateStandardToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ typeConverter.addConversion([](IndexType type) { return type; });
+ typeConverter.addConversion([](IntegerType type) { return type; });
+ typeConverter.addConversion([](FloatType type) { return type; });
+
+  // Ensure all shape-related ops are fully converted: after this conversion
+  // there should no longer be any types on which they are valid.
+ conversionTarget.addIllegalOp<memref::DimOp>();
+ conversionTarget.addIllegalOp<mlir::RankOp>();
+ conversionTarget.addIllegalOp<tensor::DimOp>();
+
+ populateStandardConstantToStreamPatterns(context, conversionTarget,
+ typeConverter, patterns);
+ populateStandardStructuralToStreamPatterns(context, conversionTarget,
+ typeConverter, patterns);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h
new file mode 100644
index 0000000..12997c6
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h
@@ -0,0 +1,26 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_STANDARDTOSTREAM_CONVERTSTANDARDTOSTREAM_H_
+#define IREE_COMPILER_DIALECT_STREAM_CONVERSION_STANDARDTOSTREAM_CONVERTSTANDARDTOSTREAM_H_
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Populates conversion patterns that perform standard/builtin->stream
+// conversion. These patterns ensure that nested types are run through the
+// provided |typeConverter|.
+void populateStandardToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_CONVERSION_STANDARDTOSTREAM_CONVERTSTANDARDTOSTREAM_H_
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp
new file mode 100644
index 0000000..6a050c5
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp
@@ -0,0 +1,263 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
+#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
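+// Rewrites function signatures using the type converter; tensor arguments and
+// results are expanded into (!stream.resource<*>, index) pairs.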
+struct FuncOpSignatureConversion : public OpConversionPattern<mlir::FuncOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::FuncOp funcOp, llvm::ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto &typeConverter = *getTypeConverter();
+
+ // Convert the input signature types.
+ // TODO(benvanik): dynamic shapes by passing in tensor dynamic dims.
+ auto originalType = funcOp.getType();
+ TypeConverter::SignatureConversion newSignature(
+ originalType.getNumInputs());
+ for (auto argType : llvm::enumerate(originalType.getInputs())) {
+ if (failed(typeConverter.convertSignatureArg(
+ argType.index(), argType.value(), newSignature))) {
+ return failure();
+ }
+ }
+ SmallVector<Type, 4> newResultTypes;
+ if (failed(typeConverter.convertTypes(originalType.getResults(),
+ newResultTypes))) {
+ return failure();
+ }
+
+ // Replace function.
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ newFuncOp.getBlocks().clear();
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(rewriter.getFunctionType(newSignature.getConvertedTypes(),
+ newResultTypes));
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+ &newSignature))) {
+ return failure();
+ }
+
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
+
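+// Expands each tensor/resource operand into a (resource, size) pair so that
+// structural ops (calls, branches, returns) carry explicit resource sizes.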
+static SmallVector<Value> expandResourceOperands(
+ Location loc, ValueRange operands, ConversionPatternRewriter &rewriter) {
+ SmallVector<Value> expandedOperands;
+ expandedOperands.reserve(operands.size());
+ for (auto operand : operands) {
+ if (operand.getType().isa<TensorType>()) {
+ auto value = consumeTensorOperand(loc, operand, rewriter);
+ expandedOperands.push_back(value.resource);
+ expandedOperands.push_back(value.resourceSize);
+ } else if (operand.getType().isa<IREE::Stream::ResourceType>()) {
+ expandedOperands.push_back(operand);
+ expandedOperands.push_back(
+ rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(loc, operand));
+ } else {
+ expandedOperands.push_back(operand);
+ }
+ }
+ return expandedOperands;
+}
+
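+// Expands call operands and results to (resource, size) pairs. Since the
+// result counts differ, the new results are rejoined with
+// unrealized_conversion_casts so users still see the original 1:1 mapping.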
+struct CallOpConversion : public OpConversionPattern<mlir::CallOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::CallOp op, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), newOperands, rewriter);
+
+ // Expand any resource results to resource + size.
+ SmallVector<Type> expandedTypes;
+ struct Result {
+ size_t originalIndex;
+ size_t newIndex;
+ Type newType;
+ };
+ SmallVector<Result> resultMap;
+ for (auto originalType : llvm::enumerate(op.getResultTypes())) {
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(originalType.value(),
+ newTypes))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to convert result types");
+ }
+ resultMap.push_back(
+ Result{originalType.index(), expandedTypes.size(), newTypes.front()});
+ expandedTypes.append(newTypes);
+ }
+
+ // Create a new call that takes the expanded input operands and returns the
+ // expanded output results. We can't directly replace the original call as
+ // the result counts differ.
+ auto callOp = rewriter.create<mlir::CallOp>(op.getLoc(), expandedTypes,
+ op.callee(), expandedOperands);
+
+ // Tie all resource results together so we end up with 1:1 results with the
+ // original op.
+ SmallVector<Value> results;
+ for (auto result : resultMap) {
+ if (result.newType.isa<IREE::Stream::ResourceType>()) {
+ auto oldType = op.getResult(result.originalIndex).getType();
+ auto resource = callOp.getResult(result.newIndex + 0);
+ auto resourceSize = callOp.getResult(result.newIndex + 1);
+ results.push_back(rewriter
+ .create<mlir::UnrealizedConversionCastOp>(
+ op.getLoc(), TypeRange{oldType},
+ ValueRange{resource, resourceSize})
+ .getResult(0));
+ } else {
+ results.push_back(callOp.getResult(result.newIndex));
+ }
+ }
+ rewriter.replaceOp(op, results);
+
+ return success();
+ }
+};
+
+struct ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::ReturnOp op, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Expand any resource operands to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), newOperands, rewriter);
+ rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op, expandedOperands);
+ return success();
+ }
+};
+
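+// Converts std.br ops by expanding any resource successor operands to
+// (resource, size) pairs.
+//
+// Example (illustrative; SSA value names are placeholders):
+//   br ^bb1(%t : tensor<1xf32>)
+// ->
+//   br ^bb1(%t_r, %t_sz : !stream.resource<*>, index)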
+struct BranchOpConversion : public OpConversionPattern<mlir::BranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::BranchOp op, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::BranchOpAdaptor operands(newOperands);
+ // Expand any resource operands to resource + size.
+ auto expandedOperands =
+ expandResourceOperands(op.getLoc(), operands.destOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::BranchOp>(op, op.dest(),
+ expandedOperands);
+ return success();
+ }
+};
+
+struct CondBranchOpConversion : public OpConversionPattern<mlir::CondBranchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::CondBranchOp op, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::CondBranchOpAdaptor operands(newOperands,
+ op.getOperation()->getAttrDictionary());
+ // Expand any resource operands to resource + size.
+ auto trueDestOperands = expandResourceOperands(
+ op.getLoc(), operands.trueDestOperands(), rewriter);
+ auto falseDestOperands = expandResourceOperands(
+ op.getLoc(), operands.falseDestOperands(), rewriter);
+ rewriter.replaceOpWithNewOp<mlir::CondBranchOp>(
+ op, operands.condition(), op.trueDest(), trueDestOperands,
+ op.falseDest(), falseDestOperands);
+ return success();
+ }
+};
+
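+// Converts std.select ops on tensors by selecting the expanded resource and
+// resource size independently.
+//
+// Example (illustrative; SSA value names are placeholders):
+//   %0 = select %cond, %a, %b : tensor<1xf32>
+// ->
+//   %r = select %cond, %a_r, %b_r : !stream.resource<*>
+//   %sz = select %cond, %a_sz, %b_sz : index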
+struct SelectOpConversion : public OpConversionPattern<mlir::SelectOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ mlir::SelectOp op, mlir::SelectOp::Adaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle selects where the operands are tensors (resources).
+ if (!op.true_value().getType().isa<TensorType>()) return failure();
+ auto trueOperand =
+ consumeTensorOperand(op.getLoc(), operands.true_value(), rewriter);
+ auto falseOperand =
+ consumeTensorOperand(op.getLoc(), operands.false_value(), rewriter);
+ auto resourceSelectOp = rewriter.create<mlir::SelectOp>(
+ op.getLoc(), operands.condition(), trueOperand.resource,
+ falseOperand.resource);
+ auto sizeSelectOp = rewriter.create<mlir::SelectOp>(
+ op.getLoc(), operands.condition(), trueOperand.resourceSize,
+ falseOperand.resourceSize);
+ rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
+ op, operands.true_value().getType(),
+ ValueRange{resourceSelectOp.result(), sizeSelectOp.result()});
+ return success();
+ }
+};
+
+} // namespace
+
+void populateStandardStructuralToStreamPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ conversionTarget.addLegalOp<mlir::ModuleOp>();
+
+ // We need to rewrite certain types on operands/results, so we use dynamic
+ // legality callbacks to force any ops using such types to run through our
+ // patterns.
+
+ conversionTarget.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ conversionTarget.addDynamicallyLegalOp<mlir::CallOp>([&](mlir::CallOp op) {
+ return llvm::all_of(
+ op.getOperandTypes(),
+ [&](Type type) { return typeConverter.isLegal(type); }) &&
+ llvm::all_of(op.getResultTypes(),
+ [&](Type type) { return typeConverter.isLegal(type); });
+ });
+ conversionTarget.addDynamicallyLegalOp<mlir::ReturnOp>(
+ [&](mlir::ReturnOp op) {
+ return llvm::all_of(op.getOperandTypes(), [&](Type type) {
+ return typeConverter.isLegal(type);
+ });
+ });
+
+ conversionTarget.addDynamicallyLegalOp<mlir::BranchOp>(
+ [&](mlir::BranchOp op) {
+ return llvm::all_of(op.getOperandTypes(), [&](Type type) {
+ return typeConverter.isLegal(type);
+ });
+ });
+ conversionTarget.addDynamicallyLegalOp<mlir::CondBranchOp>(
+ [&](mlir::CondBranchOp op) {
+ return llvm::all_of(op.getOperandTypes(), [&](Type type) {
+ return typeConverter.isLegal(type);
+ });
+ });
+
+ patterns
+ .insert<FuncOpSignatureConversion, CallOpConversion, ReturnOpConversion,
+ BranchOpConversion, CondBranchOpConversion, SelectOpConversion>(
+ typeConverter, context);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD
new file mode 100644
index 0000000..ca9fdbc
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "constant_ops.mlir",
+ "structural_ops.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/CMakeLists.txt
new file mode 100644
index 0000000..32bd14d
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "constant_ops.mlir"
+ "structural_ops.mlir"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/constant_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/constant_ops.mlir
new file mode 100644
index 0000000..ef2e038
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/constant_ops.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-opt -split-input-file -iree-stream-conversion %s | IreeFileCheck %s
+
+// CHECK-LABEL: @constantTensor
+func @constantTensor() {
+ // CHECK: %[[CST:.+]] = stream.tensor.constant : tensor<2xi32> in !stream.resource<constant> = dense<[1, 2]> : tensor<2xi32>
+ // CHECK: %[[SIZE:.+]] = stream.resource.size %[[CST]] : !stream.resource<constant>
+ // CHECK: %[[T:.+]] = stream.async.transfer %[[CST]] : !stream.resource<constant>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+ %0 = arith.constant dense<[1, 2]> : tensor<2xi32>
+ return
+}
diff --git a/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/structural_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/structural_ops.mlir
new file mode 100644
index 0000000..4f94561
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/structural_ops.mlir
@@ -0,0 +1,64 @@
+// RUN: iree-opt -split-input-file -iree-stream-conversion %s | IreeFileCheck %s
+
+// CHECK-LABEL: @functionExpansion
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<*>, %[[ARG0_SIZE:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: i1,
+// CHECK-SAME: %[[ARG2:.+]]: !stream.resource<*>, %[[ARG2_SIZE:.+]]: index)
+// CHECK-SAME: -> (!stream.resource<*>, index, i1, !stream.resource<*>, index)
+func @functionExpansion(%arg0: tensor<4x?xf32>, %arg1: i1, %arg2: tensor<i32>)
+ -> (tensor<4x?xf32>, i1, tensor<i32>) {
+ // CHECK-NEXT: %[[RET:.+]]:5 = call @callee(%[[ARG0]], %[[ARG0_SIZE]], %[[ARG1]], %[[ARG2]], %[[ARG2_SIZE]])
+ // CHECK-SAME: : (!stream.resource<*>, index, i1, !stream.resource<*>, index) -> (!stream.resource<*>, index, i1, !stream.resource<*>, index)
+ %0:3 = call @callee(%arg0, %arg1, %arg2) : (tensor<4x?xf32>, i1, tensor<i32>) -> (tensor<4x?xf32>, i1, tensor<i32>)
+ // CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2, %[[RET]]#3, %[[RET]]#4 : !stream.resource<*>, index, i1, !stream.resource<*>, index
+ return %0#0, %0#1, %0#2 : tensor<4x?xf32>, i1, tensor<i32>
+}
+
+// CHECK: func private @callee
+func private @callee(%arg0: tensor<4x?xf32>, %arg1: i1, %arg2: tensor<i32>)
+ -> (tensor<4x?xf32>, i1, tensor<i32>)
+
+// -----
+
+// CHECK-LABEL: @brExpansion
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<*>, %[[ARG0_SIZE:.+]]: index, %arg2: i1)
+// CHECK-SAME: -> (!stream.resource<*>, index, i1)
+func @brExpansion(%arg0: tensor<1xf32>, %arg1: i1) -> (tensor<1xf32>, i1) {
+ // CHECK: br ^bb1(%[[ARG0]], %[[ARG0_SIZE]], %arg2 : !stream.resource<*>, index, i1)
+ br ^bb1(%arg0, %arg1 : tensor<1xf32>, i1)
+// CHECK: ^bb1(%[[BB_ARG0:.+]]: !stream.resource<*>, %[[BB_ARG1:.+]]: index, %[[BB_ARG2:.+]]: i1):
+^bb1(%0: tensor<1xf32>, %1: i1):
+ // CHECK: return %[[BB_ARG0]], %[[BB_ARG1]], %[[BB_ARG2]] : !stream.resource<*>, index, i1
+ return %0, %1 : tensor<1xf32>, i1
+}
+
+// -----
+
+// CHECK-LABEL: @condBrExpansion
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<*>, %[[ARG0_SIZE:.+]]: index,
+// CHECK-SAME: %[[ARG1:.+]]: !stream.resource<*>, %[[ARG1_SIZE:.+]]: index)
+// CHECK-SAME: -> (!stream.resource<*>, index)
+func @condBrExpansion(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+ %true = arith.constant 1 : i1
+ // CHECK: cond_br %true,
+ // CHECK-SAME: ^bb1(%[[ARG0]], %[[ARG0_SIZE]] : !stream.resource<*>, index),
+ // CHECK-SAME: ^bb1(%[[ARG1]], %[[ARG1_SIZE]] : !stream.resource<*>, index)
+ cond_br %true, ^bb1(%arg0 : tensor<1xf32>), ^bb1(%arg1 : tensor<1xf32>)
+^bb1(%0: tensor<1xf32>):
+ return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @selectExpansion
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<*>, %[[ARG0_SIZE:.+]]: index,
+// CHECK-SAME: %[[COND:.+]]: i1,
+// CHECK-SAME: %[[ARG1:.+]]: !stream.resource<*>, %[[ARG1_SIZE:.+]]: index)
+// CHECK-SAME: -> (!stream.resource<*>, index)
+func @selectExpansion(%arg0: tensor<1xf32>, %cond: i1, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+ // CHECK-DAG: %[[RET:.+]] = select %[[COND]], %[[ARG0]], %[[ARG1]] : !stream.resource<*>
+ // CHECK-DAG: %[[RET_SIZE:.+]] = select %[[COND]], %[[ARG0_SIZE]], %[[ARG1_SIZE]] : index
+ %0 = select %cond, %arg0, %arg1 : tensor<1xf32>
+ // CHECK: return %[[RET]], %[[RET_SIZE]] : !stream.resource<*>, index
+ return %0 : tensor<1xf32>
+}
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD
new file mode 100644
index 0000000..e880f5b
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "UtilToStream",
+ srcs = [
+ "ConvertUtilToStream.cpp",
+ ],
+ hdrs = [
+ "ConvertUtilToStream.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/Stream/Conversion",
+ "//iree/compiler/Dialect/Stream/IR",
+ "//iree/compiler/Dialect/Util/IR",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
new file mode 100644
index 0000000..274da6a
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/CMakeLists.txt
@@ -0,0 +1,30 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Conversion/UtilToStream/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ UtilToStream
+ HDRS
+ "ConvertUtilToStream.h"
+ SRCS
+ "ConvertUtilToStream.cpp"
+ DEPS
+ MLIRIR
+ MLIRStandard
+ MLIRTransforms
+ iree::compiler::Dialect::Stream::Conversion
+ iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.cpp b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.cpp
new file mode 100644
index 0000000..d1a1c13
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.cpp
@@ -0,0 +1,279 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.h"
+
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Globals
+//===----------------------------------------------------------------------===//
+
+struct ExpandedGlobalResource {
+ IREE::Util::GlobalOp resourceOp;
+ IREE::Util::GlobalOp resourceSizeOp;
+};
+
+struct GlobalExpansionState {
+ // A map of original symbol name to one new global for each expanded type.
+ DenseMap<StringRef, ExpandedGlobalResource> globalMap;
+};
+
+static bool isExpandedType(Type type) {
+ if (type.isa<TensorType>()) return true;
+ if (auto ptrType = type.dyn_cast<IREE::Util::PtrType>()) {
+ return isExpandedType(ptrType.getTargetType());
+ }
+ return false;
+}
+
+template <typename T>
+class BaseGlobalConversionPattern : public OpConversionPattern<T> {
+ public:
+ BaseGlobalConversionPattern(
+ std::shared_ptr<GlobalExpansionState> expansionState,
+ TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<T>(typeConverter, context, benefit),
+ expansionState(std::move(expansionState)) {}
+
+ protected:
+ mutable std::shared_ptr<GlobalExpansionState> expansionState;
+};
+
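+// Expands a global of an expandable type (tensor/etc) into a resource global
+// paired with an index global tracking its size.
+//
+// Example (tensor initializers additionally materialize a util.initializer):
+//   util.global mutable @foo : tensor<i32>
+// ->
+//   util.global mutable @foo : !stream.resource<variable>
+//   util.global mutable @foo__size : index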
+struct GlobalOpExpansion
+ : public BaseGlobalConversionPattern<IREE::Util::GlobalOp> {
+ using BaseGlobalConversionPattern::BaseGlobalConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::GlobalOp globalOp, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only apply to expanded types (tensors/etc).
+ if (!isExpandedType(globalOp.type())) return failure();
+
+ SmallVector<Type> newTypes;
+ if (failed(getTypeConverter()->convertType(globalOp.type(), newTypes))) {
+ return rewriter.notifyMatchFailure(globalOp,
+ "failed to convert global type");
+ }
+ if (newTypes.size() == 1 && newTypes.front() == globalOp.type()) {
+ return rewriter.notifyMatchFailure(globalOp, "no conversion needed");
+ }
+
+ // Start with the appropriate type. Lifetime refinement will use this as a
+ // seed. Note that what was a constant in earlier dialects becomes a mutable
+ // global holding a resource that may have constant contents.
+ bool hasConstantUsage = !globalOp.isMutable();
+ auto resourceType = IREE::Stream::ResourceType::get(
+ rewriter.getContext(), hasConstantUsage
+ ? IREE::Stream::Lifetime::Constant
+ : IREE::Stream::Lifetime::Variable);
+
+ // Special handling of the initial value: if it's a tensor then we need to
+ // materialize an initializer and initialization ops. This allows the
+ // current conversion to pick up the expanded initialization ops.
+ auto initialValue = globalOp.initial_valueAttr();
+ bool tensorInitializerRequired =
+ initialValue ? initialValue.getType().isa<TensorType>() : false;
+
+ // New global holding the initial value only if it is not a tensor type.
+ auto resourceOp = rewriter.replaceOpWithNewOp<IREE::Util::GlobalOp>(
+ globalOp, globalOp.getName(), globalOp.is_mutable(), resourceType,
+ initialValue && !tensorInitializerRequired
+ ? llvm::Optional<Attribute>{initialValue}
+ : llvm::None);
+ resourceOp.setVisibility(globalOp.getVisibility());
+
+ // NOTE: we ignore noinline here, possibly to our peril. In earlier dialects
+ // noinline indicates that the constant value should not be inlined, while
+ // here it would indicate that the reference to the constant value should
+ // not be inlined (and that's weird).
+
+ // Also create a global for tracking the resource size. In many cases this
+ // is constant and will fold throughout the program. Global optimizations
+ // such as same-value deduplication will also take effect.
+ auto indexType = rewriter.getIndexType();
+ auto resourceSizeOp = rewriter.create<IREE::Util::GlobalOp>(
+ globalOp.getLoc(), (globalOp.getName() + "__size").str(),
+ globalOp.is_mutable(), indexType, Optional<Attribute>{});
+ resourceSizeOp.setVisibility(globalOp.getVisibility());
+
+ // Materialize the initializer if we need to set up a tensor-like constant.
+ if (tensorInitializerRequired) {
+ auto initializerOp =
+ rewriter.create<IREE::Util::InitializerOp>(globalOp.getLoc());
+ auto *entryBlock = rewriter.createBlock(&initializerOp.getBody());
+ rewriter.setInsertionPointToStart(entryBlock);
+ auto constantOp = rewriter.create<IREE::Stream::TensorConstantOp>(
+ globalOp.getLoc(), resourceOp.type(),
+ initialValue.cast<ElementsAttr>(), TypeAttr::get(globalOp.type()),
+ /*result_encoding_dims=*/ValueRange{}, /*affinity=*/nullptr);
+ auto constantSizeOp = rewriter.create<IREE::Stream::ResourceSizeOp>(
+ globalOp.getLoc(), indexType, constantOp.result());
+ rewriter.create<IREE::Util::GlobalStoreOp>(
+ globalOp.getLoc(), constantOp.result(), resourceOp.getSymbolName());
+ rewriter.create<IREE::Util::GlobalStoreOp>(
+ globalOp.getLoc(), constantSizeOp.result(),
+ resourceSizeOp.getSymbolName());
+ rewriter.create<IREE::Util::InitializerReturnOp>(globalOp.getLoc());
+ }
+
+ expansionState->globalMap[globalOp.getSymbolName()] =
+ ExpandedGlobalResource{
+ resourceOp,
+ resourceSizeOp,
+ };
+
+ return success();
+ }
+};
+
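+// Expands loads from expanded globals into loads of both the resource and its
+// size followed by a transfer to the unknown lifetime.
+//
+// Example (illustrative; SSA value names are placeholders):
+//   %0 = util.global.load @foo : tensor<i32>
+// ->
+//   %r = util.global.load @foo : !stream.resource<variable>
+//   %sz = util.global.load @foo__size : index
+//   %0 = stream.async.transfer %r : !stream.resource<variable>{%sz} ->
+//        !stream.resource<*>{%sz}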
+struct GlobalLoadOpExpansion
+ : public BaseGlobalConversionPattern<IREE::Util::GlobalLoadOp> {
+ using BaseGlobalConversionPattern::BaseGlobalConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::GlobalLoadOp loadOp, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Util::GlobalLoadOpAdaptor operands(newOperands,
+ loadOp->getAttrDictionary());
+
+ // Only apply to expanded types (tensors/etc).
+ if (!isExpandedType(loadOp.getType())) return failure();
+ auto &expandedGlobal =
+ expansionState->globalMap[operands.global().getValue()];
+
+ // Insert a load/transfer to the unknown resource lifetime.
+ auto unknownType = IREE::Stream::ResourceType::get(rewriter.getContext());
+ auto resource = rewriter
+ .create<IREE::Util::GlobalLoadOp>(
+ loadOp.getLoc(), expandedGlobal.resourceOp.type(),
+ expandedGlobal.resourceOp.getSymbolName())
+ .result();
+ auto resourceSize = rewriter
+ .create<IREE::Util::GlobalLoadOp>(
+ loadOp.getLoc(), rewriter.getIndexType(),
+ expandedGlobal.resourceSizeOp.getSymbolName())
+ .result();
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+ loadOp, unknownType, resource, resourceSize, resourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+
+ return success();
+ }
+};
+
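+// Expands stores to expanded globals into a transfer to the variable lifetime
+// followed by stores of both the resource and its size.
+//
+// Example (illustrative; SSA value names are placeholders):
+//   util.global.store %0, @foo : tensor<i32>
+// ->
+//   %t = stream.async.transfer %0 : !stream.resource<*>{%sz} ->
+//        !stream.resource<variable>{%sz}
+//   util.global.store %t, @foo : !stream.resource<variable>
+//   util.global.store %sz, @foo__size : index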
+struct GlobalStoreOpExpansion
+ : public BaseGlobalConversionPattern<IREE::Util::GlobalStoreOp> {
+ using BaseGlobalConversionPattern::BaseGlobalConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Util::GlobalStoreOp storeOp, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ IREE::Util::GlobalStoreOpAdaptor operands(newOperands,
+ storeOp->getAttrDictionary());
+
+ // Only apply to expanded types (tensors/etc).
+ if (!isExpandedType(storeOp.value().getType())) return failure();
+ auto &expandedGlobal =
+ expansionState->globalMap[operands.global().getValue()];
+
+ // Insert a transfer/store to the global with unknown lifetime. Lifetime
+ // refinement will make this go away if possible.
+ auto value =
+ consumeTensorOperand(storeOp.getLoc(), operands.value(), rewriter);
+ auto transferOp = rewriter.create<IREE::Stream::AsyncTransferOp>(
+ storeOp.getLoc(), expandedGlobal.resourceOp.type(), value.resource,
+ value.resourceSize, value.resourceSize, /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ rewriter.replaceOpWithNewOp<IREE::Util::GlobalStoreOp>(
+ storeOp, transferOp.result(),
+ expandedGlobal.resourceOp.getSymbolName());
+ rewriter.create<IREE::Util::GlobalStoreOp>(
+ storeOp.getLoc(), value.resourceSize,
+ expandedGlobal.resourceSizeOp.getSymbolName());
+
+ return success();
+ }
+};
+
+} // namespace
+
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ auto expansionState = std::make_shared<GlobalExpansionState>();
+ // TODO(#7432): add indirect global expansion support to streams.
+ patterns
+ .insert<GlobalOpExpansion, GlobalLoadOpExpansion, GlobalStoreOpExpansion>(
+ expansionState, typeConverter, context);
+}
+
+void populateUtilToStreamConversionPatterns(
+ MLIRContext *context, ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ typeConverter.addConversion([=](IREE::Util::PtrType type,
+ SmallVectorImpl<Type> &resultTypes) {
+ // Expand pointers to tensors to [ptr<resource>, ptr<sizeof resource>].
+ if (!isExpandedType(type)) return failure();
+ resultTypes.push_back(
+ IREE::Util::PtrType::get(IREE::Stream::ResourceType::get(context)));
+ resultTypes.push_back(IREE::Util::PtrType::get(IndexType::get(context)));
+ return success();
+ });
+
+ typeConverter.addConversion(
+ [=](IREE::Util::PtrType type, SmallVectorImpl<Type> &resultTypes) {
+ // Expand pointers to tensors to [resource, sizeof resource].
+ if (!isExpandedType(type.getTargetType())) return failure();
+ resultTypes.push_back(IREE::Stream::ResourceType::get(context));
+ resultTypes.push_back(IndexType::get(context));
+ return success();
+ });
+
+ conversionTarget
+ .addLegalOp<IREE::Util::InitializerOp, IREE::Util::InitializerReturnOp>();
+ conversionTarget.addDynamicallyLegalOp<IREE::Util::GlobalOp>(
+ [&](IREE::Util::GlobalOp op) {
+ return typeConverter.isLegal(op.type()) &&
+ (!op.initial_valueAttr() ||
+ !op.initial_valueAttr().getType().isa<TensorType>());
+ });
+ conversionTarget.addDynamicallyLegalOp<IREE::Util::GlobalAddressOp>(
+ [&](IREE::Util::GlobalAddressOp op) {
+ return typeConverter.isLegal(op.result().getType());
+ });
+ conversionTarget.addDynamicallyLegalOp<IREE::Util::GlobalLoadOp>(
+ [&](IREE::Util::GlobalLoadOp op) {
+ return typeConverter.isLegal(op.result().getType());
+ });
+ conversionTarget.addDynamicallyLegalOp<IREE::Util::GlobalLoadIndirectOp>(
+ [&](IREE::Util::GlobalLoadIndirectOp op) {
+ return typeConverter.isLegal(op.result().getType());
+ });
+ conversionTarget.addDynamicallyLegalOp<IREE::Util::GlobalStoreOp>(
+ [&](IREE::Util::GlobalStoreOp op) {
+ return typeConverter.isLegal(op.value().getType());
+ });
+ conversionTarget.addDynamicallyLegalOp<IREE::Util::GlobalStoreIndirectOp>(
+ [&](IREE::Util::GlobalStoreIndirectOp op) {
+ return typeConverter.isLegal(op.value().getType());
+ });
+
+ populateUtilToStreamConversionPatterns(context, typeConverter, patterns);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.h b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.h
new file mode 100644
index 0000000..c43912d
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.h
@@ -0,0 +1,31 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_UTILTOSTREAM_CONVERTUTILTOSTREAM_H_
+#define IREE_COMPILER_DIALECT_STREAM_CONVERSION_UTILTOSTREAM_CONVERTUTILTOSTREAM_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Populates conversion patterns that perform util->stream conversion.
+// These patterns ensure that nested types are run through the provided
+// |typeConverter|.
+void populateUtilToStreamConversionPatterns(MLIRContext *context,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+void populateUtilToStreamConversionPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_CONVERSION_UTILTOSTREAM_CONVERTUTILTOSTREAM_H_
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD
new file mode 100644
index 0000000..4478764
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD
@@ -0,0 +1,26 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ ["global_ops.mlir"],
+ include = ["*.mlir"],
+ ),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/CMakeLists.txt
new file mode 100644
index 0000000..c588368
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "global_ops.mlir"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir
new file mode 100644
index 0000000..cc8ee40
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir
@@ -0,0 +1,86 @@
+// RUN: iree-opt -split-input-file -iree-stream-conversion %s | IreeFileCheck %s
+
+// CHECK: util.global public mutable @var_i32 : !stream.resource<variable>
+// CHECK: util.global public mutable @var_i32__size : index
+util.global public mutable @var_i32 : tensor<i32>
+// CHECK-LABEL: @mutableGlobal
+func @mutableGlobal() {
+ // CHECK-DAG: %[[VAR:.+]] = util.global.load @var_i32 : !stream.resource<variable>
+ // CHECK-DAG: %[[SIZE:.+]] = util.global.load @var_i32__size : index
+ // CHECK: %[[LOAD_T:.+]] = stream.async.transfer %[[VAR]] : !stream.resource<variable>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+ %0 = util.global.load @var_i32 : tensor<i32>
+ // CHECK: %[[STORE_T:.+]] = stream.async.transfer %[[LOAD_T]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource<variable>{%[[SIZE]]}
+ // CHECK-DAG: util.global.store %[[STORE_T]], @var_i32 : !stream.resource<variable>
+ // CHECK-DAG: util.global.store %[[SIZE]], @var_i32__size : index
+ util.global.store %0, @var_i32 : tensor<i32>
+ return
+}
+
+// -----
+
+// TODO(#7432): add indirect global expansion support to streams.
+// util.global public mutable @var_indirect : tensor<i32>
+// func @mutableGlobalIndirect() {
+// %0 = util.global.address @var_indirect : !util.ptr<tensor<i32>>
+// %1 = util.global.load.indirect %0 : !util.ptr<tensor<i32>> -> tensor<i32>
+// util.global.store.indirect %1, %0 : tensor<i32> -> !util.ptr<tensor<i32>>
+// return
+// }
+
+// -----
+
+// CHECK-DAG: util.global public mutable @var_with_tensor_initializer : !stream.resource<variable>
+// CHECK-DAG: util.global public mutable @var_with_tensor_initializer__size : index
+// CHECK-NEXT: util.initializer {
+// CHECK-NEXT: %[[CST:.+]] = stream.tensor.constant : tensor<f32> in !stream.resource<variable> = dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: %[[SIZE:.+]] = stream.resource.size %[[CST]] : !stream.resource<variable>
+// CHECK-DAG: util.global.store %[[CST]], @var_with_tensor_initializer : !stream.resource<variable>
+// CHECK-DAG: util.global.store %[[SIZE]], @var_with_tensor_initializer__size : index
+util.global public mutable @var_with_tensor_initializer = dense<0.000000e+00> : tensor<f32>
+// CHECK-LABEL: @initializedGlobal
+func @initializedGlobal() {
+ // CHECK-DAG: = util.global.load @var_with_tensor_initializer : !stream.resource<variable>
+ // CHECK-DAG: = util.global.load @var_with_tensor_initializer__size : index
+ %0 = util.global.load @var_with_tensor_initializer : tensor<f32>
+ // CHECK-DAG: util.global.store %{{.+}}, @var_with_tensor_initializer : !stream.resource<variable>
+ // CHECK-DAG: util.global.store %{{.+}}, @var_with_tensor_initializer__size : index
+ util.global.store %0, @var_with_tensor_initializer : tensor<f32>
+ return
+}
+
+// -----
+
+// Checks that the implicit cast allowing a buffer_view to be stored into a
+// variable that maps to a buffer is permitted.
+
+// CHECK-DAG: util.global public mutable @var_with_buffer_view_store : !stream.resource<variable>
+// CHECK-DAG: util.global public mutable @var_with_buffer_view_store__size : index
+util.global public mutable @var_with_buffer_view_store : tensor<?x4xf32>
+// CHECK-LABEL: @globalStoreFromExternal
+func @globalStoreFromExternal(%arg0: !hal.buffer_view) {
+ // CHECK: %[[DIM0:.+]] = hal.buffer_view.dim
+ %dim0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
+ // CHECK: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
+ // CHECK: %[[IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[SIZE]]}
+ // CHECK: %[[T:.+]] = stream.async.transfer %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+ %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
+ // CHECK: %[[VAR:.+]] = stream.async.transfer %[[T]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource<variable>{%[[SIZE]]}
+ // CHECK: util.global.store %[[VAR]], @var_with_buffer_view_store : !stream.resource<variable>
+ // CHECK: util.global.store %[[SIZE]], @var_with_buffer_view_store__size : index
+ util.global.store %0, @var_with_buffer_view_store : tensor<?x4xf32>
+ return
+}
+
+// -----
+
+// Checks that the implicit cast allowing a buffer_view to be stored indirectly
+// into a variable that maps to a buffer is permitted.
+
+// TODO(#7432): add indirect global expansion support to streams.
+// util.global public mutable @var_indirect_with_buffer_view_store : tensor<i32>
+// func @globalStoreFromExternalIndirect(%arg0: !hal.buffer_view) {
+// %0 = util.global.address @var_indirect_with_buffer_view_store : !util.ptr<tensor<i32>>
+// %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<i32>
+// util.global.store.indirect %1, %0 : tensor<i32> -> !util.ptr<tensor<i32>>
+// return
+// }
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 84a1931..5a3906f 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -311,6 +311,42 @@
Block::iterator(op));
}
+namespace {
+
+// Propagates resource sizes through select ops by selecting on the sizes of the
+// select operands.
+//
+// Example:
+// %a = stream... : !stream.resource<*>{%a_sz}
+// %b = stream... : !stream.resource<*>{%b_sz}
+// %c = select %cond, %a, %b : !stream.resource<*>
+// %c_sz = stream.resource.size %c : !stream.resource<*>
+// ->
+// %c = select %cond, %a, %b : !stream.resource<*>
+// %c_sz = select %cond, %a_sz, %b_sz : index
+struct SelectResourceSizeOp : public OpRewritePattern<ResourceSizeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ResourceSizeOp op,
+ PatternRewriter &rewriter) const override {
+ auto selectOp = op.operand().getDefiningOp<mlir::SelectOp>();
+ if (!selectOp) return failure();
+ auto trueSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+ op.getLoc(), selectOp.true_value(), op.affinityAttr());
+ auto falseSize = rewriter.createOrFold<IREE::Stream::ResourceSizeOp>(
+ op.getLoc(), selectOp.false_value(), op.affinityAttr());
+ rewriter.replaceOpWithNewOp<mlir::SelectOp>(op, selectOp.condition(),
+ trueSize, falseSize);
+ return success();
+ }
+};
+
+} // namespace
+
+void ResourceSizeOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<SelectResourceSizeOp>(context);
+}
+
//===----------------------------------------------------------------------===//
// stream.resource.map
//===----------------------------------------------------------------------===//
@@ -468,8 +504,7 @@
// Zero offsets don't do anything and can just be removed so we can avoid
// inserting a bunch of additional IR.
- if (auto constantOp = dyn_cast_or_null<arith::ConstantIndexOp>(
- baseOffset.getDefiningOp())) {
+ if (auto constantOp = baseOffset.getDefiningOp<arith::ConstantIndexOp>()) {
if (constantOp.value() == 0) {
return success();
}
@@ -916,8 +951,8 @@
LogicalResult matchAndRewrite(AsyncCloneOp cloneOp,
PatternRewriter &rewriter) const override {
if (cloneOp.use_empty()) return failure();
- auto sourceOp = dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(
- cloneOp.source().getDefiningOp());
+ auto sourceOp =
+ cloneOp.source().getDefiningOp<IREE::Stream::StreamableOpInterface>();
if (!sourceOp || !sourceOp.preferCloneToConsumers()) return failure();
for (auto &use : llvm::make_early_inc_range(cloneOp.result().getUses())) {
rewriter.setInsertionPoint(use.getOwner());
@@ -968,8 +1003,7 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AsyncSliceOp sliceOp,
PatternRewriter &rewriter) const override {
- auto splatOp = dyn_cast_or_null<IREE::Stream::AsyncSplatOp>(
- sliceOp.source().getDefiningOp());
+ auto splatOp = sliceOp.source().getDefiningOp<IREE::Stream::AsyncSplatOp>();
if (!splatOp) return failure();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncSplatOp>(
sliceOp, sliceOp.result().getType(), splatOp.value(),
@@ -1053,14 +1087,13 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AsyncUpdateOp updateOp,
PatternRewriter &rewriter) const override {
- auto splatOp = dyn_cast_or_null<IREE::Stream::AsyncSplatOp>(
- updateOp.update().getDefiningOp());
+ auto splatOp =
+ updateOp.update().getDefiningOp<IREE::Stream::AsyncSplatOp>();
if (!splatOp) return failure();
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncFillOp>(
updateOp, updateOp.result().getType(), updateOp.target(),
updateOp.target_size(), updateOp.target_offset(), updateOp.target_end(),
- updateOp.update_size(), splatOp.value(), updateOp.tied_operandsAttr(),
- updateOp.affinityAttr());
+ updateOp.update_size(), splatOp.value(), updateOp.affinityAttr());
return success();
}
};
@@ -1086,8 +1119,8 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AsyncUpdateOp updateOp,
PatternRewriter &rewriter) const override {
- auto sliceOp = dyn_cast_or_null<IREE::Stream::AsyncSliceOp>(
- updateOp.update().getDefiningOp());
+ auto sliceOp =
+ updateOp.update().getDefiningOp<IREE::Stream::AsyncSliceOp>();
if (!sliceOp || sliceOp->getBlock() != updateOp->getBlock()) {
// Source is not a slice or a slice from out-of-block. We don't want to
// grow memory usage by sinking the slice here (we may slice into the
@@ -1098,8 +1131,7 @@
updateOp, updateOp.result().getType(), updateOp.target(),
updateOp.target_size(), updateOp.target_offset(), updateOp.target_end(),
sliceOp.source(), sliceOp.source_size(), sliceOp.source_offset(),
- sliceOp.source_end(), sliceOp.result_size(),
- updateOp.tied_operandsAttr(), updateOp.affinityAttr());
+ sliceOp.source_end(), sliceOp.result_size(), updateOp.affinityAttr());
return success();
}
};
@@ -1140,8 +1172,7 @@
rewriter.replaceOpWithNewOp<IREE::Stream::AsyncUpdateOp>(
copyOp, copyOp.result().getType(), copyOp.target(),
copyOp.target_size(), copyOp.target_offset(), copyOp.target_end(),
- copyOp.source(), copyOp.source_size(), copyOp.tied_operandsAttr(),
- copyOp.affinityAttr());
+ copyOp.source(), copyOp.source_size(), copyOp.affinityAttr());
return success();
}
return failure();
@@ -1162,8 +1193,7 @@
//===----------------------------------------------------------------------===//
OpFoldResult AsyncTransferOp::fold(ArrayRef<Attribute> operands) {
- if (auto sourceTransferOp =
- dyn_cast_or_null<AsyncTransferOp>(source().getDefiningOp())) {
+ if (auto sourceTransferOp = source().getDefiningOp<AsyncTransferOp>()) {
if (sourceTransferOp.source().getType() == result().getType() &&
sourceTransferOp.source_affinity() == result_affinity()) {
return sourceTransferOp.source();
@@ -1201,6 +1231,29 @@
}
//===----------------------------------------------------------------------===//
+// stream.async.load
+//===----------------------------------------------------------------------===//
+
+void AsyncLoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ // TODO(benvanik): splat + load -> splat value.
+ // TODO(benvanik): clone + ex load -> slice (ranged) + load.
+ // TODO(benvanik): slice + ex load -> slice (ranged) + load.
+ // TODO(benvanik): value->transfer->load -> value->slice->transfer->load?
+ // TODO(benvanik): combine multiple loads from the same target if contiguous.
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.store
+//===----------------------------------------------------------------------===//
+
+void AsyncStoreOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ // TODO(benvanik): if value is a constant splat then turn into fill.
+ // TODO(benvanik): combine multiple stores to the same target if contiguous.
+}
+
+//===----------------------------------------------------------------------===//
// stream.async.dispatch
//===----------------------------------------------------------------------===//
@@ -1282,8 +1335,7 @@
SmallVector<Value> newTimepoints;
SmallVector<std::pair<unsigned, Value>> replacements;
for (auto operand : llvm::enumerate(op.operands())) {
- if (auto awaitOp = dyn_cast_or_null<TimepointAwaitOp>(
- operand.value().getDefiningOp())) {
+ if (auto awaitOp = operand.value().getDefiningOp<TimepointAwaitOp>()) {
newTimepoints.push_back(awaitOp.timepoint());
replacements.push_back(std::make_pair(
operand.index(), awaitOp.getTiedResultOperand(operand.value())));
@@ -1738,8 +1790,7 @@
SmallVector<Value> newTimepoints;
SmallVector<std::pair<unsigned, Value>> replacements;
for (auto operand : llvm::enumerate(op.operands())) {
- if (auto awaitOp = dyn_cast_or_null<TimepointAwaitOp>(
- operand.value().getDefiningOp())) {
+ if (auto awaitOp = operand.value().getDefiningOp<TimepointAwaitOp>()) {
newTimepoints.push_back(awaitOp.timepoint());
replacements.push_back(std::make_pair(
operand.index(), awaitOp.getTiedResultOperand(operand.value())));
@@ -2051,8 +2102,8 @@
rewriter.startRootUpdate(op);
bool didChange = false;
for (auto operand : llvm::enumerate(op.operands())) {
- auto subviewOp = dyn_cast_or_null<IREE::Stream::ResourceSubviewOp>(
- operand.value().getDefiningOp());
+ auto subviewOp =
+ operand.value().getDefiningOp<IREE::Stream::ResourceSubviewOp>();
if (!subviewOp) continue;
didChange = true;
unsigned operandIdx = static_cast<unsigned>(operand.index());
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index cd39e77..50389ee 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -1091,6 +1091,41 @@
}
//===----------------------------------------------------------------------===//
+// stream.async.load
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncLoadOp op) {
+ if (failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// stream.async.store
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyOp(AsyncStoreOp op) {
+ if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
+ return failure();
+ }
+ return success();
+}
+
+Value AsyncStoreOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
+}
+
+::llvm::Optional<unsigned> AsyncStoreOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // target
+}
+
+SmallVector<int64_t, 4> AsyncStoreOp::getTiedResultOperandIndices() {
+ return {0}; // target
+}
+
+//===----------------------------------------------------------------------===//
// stream.async.dispatch
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.td b/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 9d77a6e..74554ce 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -200,6 +200,7 @@
Value getResultSize(unsigned idx) { return {}; }
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -949,7 +950,6 @@
Variadic<Stream_Dim>:$start_indices,
Variadic<Stream_Dim>:$lengths,
Stream_PrimitiveType:$value,
- OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
@@ -963,7 +963,7 @@
`->`
$target_encoding (`` `{` $target_encoding_dims^ `}`)?
`in`
- custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ custom<ShapedTiedResult>(type($target), $target_size)
attr-dict-with-keyword
}];
@@ -1009,7 +1009,6 @@
TypeAttr:$update_encoding,
Stream_ShapeDynamicDims:$update_encoding_dims,
Stream_Size:$update_size,
- OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
@@ -1025,7 +1024,7 @@
`->`
$target_encoding (`` `{` $target_encoding_dims^ `}`)?
`in`
- custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ custom<ShapedTiedResult>(type($target), $target_size)
attr-dict-with-keyword
}];
@@ -1090,7 +1089,7 @@
let hasCanonicalizer = 1;
}
-def Stream_TensorStoreOp : Stream_Op<"tensor.store", [
+def Stream_TensorStoreOp : Stream_PureOp<"tensor.store", [
AttrSizedOperandSegments,
AllTypesMatch<["target", "result"]>,
Stream_TensorPhaseOp,
@@ -1113,8 +1112,7 @@
Stream_ShapeDynamicDims:$target_encoding_dims,
Stream_Size:$target_size,
Variadic<Stream_Dim>:$indices,
- AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$value,
- OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands
+ AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$value
);
let results = (outs
Stream_StagingResource:$result
@@ -1127,7 +1125,7 @@
`->`
$target_encoding (`` `{` $target_encoding_dims^ `}`)?
`in`
- custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ custom<ShapedTiedResult>(type($target), $target_size)
attr-dict-with-keyword
}];
@@ -1385,7 +1383,6 @@
Stream_Offset:$target_end,
Stream_Size:$target_length,
Stream_PrimitiveType:$value,
- OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
@@ -1397,7 +1394,7 @@
$value `,`
$target `[` $target_offset `to` $target_end `for` $target_length `]` `:`
type($value) `->`
- custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ custom<ShapedTiedResult>(type($target), $target_size)
attr-dict-with-keyword
}];
@@ -1439,7 +1436,6 @@
Stream_Offset:$target_end,
Stream_AnyStreamResource:$update,
Stream_Size:$update_size,
- OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
@@ -1451,7 +1447,7 @@
$update `,`
$target `[` $target_offset `to` $target_end `]` `:`
type($update) `` `{` $update_size `}` `->`
- custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ custom<ShapedTiedResult>(type($target), $target_size)
attr-dict-with-keyword
}];
@@ -1500,7 +1496,6 @@
Stream_Offset:$source_offset,
Stream_Offset:$source_end,
Stream_Size:$length,
- OptionalAttr<Util_TiedOpStorageAttr>:$tied_operands,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
@@ -1513,7 +1508,7 @@
$target `[` $target_offset `to` $target_end `]` `,`
$length `:`
type($source) `` `{` $source_size `}` `->`
- custom<ShapedTiedResult>(type($target), $target_size, $tied_operands)
+ custom<ShapedTiedResult>(type($target), $target_size)
attr-dict-with-keyword
}];
@@ -1580,6 +1575,87 @@
let hasFolder = 1;
}
+def Stream_AsyncLoadOp : Stream_PureOp<"async.load", [
+ Stream_AsyncPhaseOp,
+ Util_SizeAwareOp,
+]> {
+ let summary = [{loads a value from a resource}];
+ let description = [{
+ Returns the element at the given location from within the resource.
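+
+ Example (operand names are placeholders):
+
+ ```mlir
+ %0 = stream.async.load %source[%offset] : !stream.resource<staging>{%size} -> f32
+ ```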
+ }];
+
+ let arguments = (ins
+ Stream_StagingResource:$source,
+ Stream_Size:$source_size,
+ Stream_Offset:$source_offset
+ );
+ let results = (outs
+ AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$result
+ );
+
+ let assemblyFormat = [{
+ $source `[` $source_offset `]` `:`
+ type($source) `` `{` $source_size `}`
+ `->`
+ type($result)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return source_size(); }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def Stream_AsyncStoreOp : Stream_PureOp<"async.store", [
+ AllTypesMatch<["target", "result"]>,
+ Stream_AsyncPhaseOp,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+]> {
+ let summary = [{stores a value into a resource}];
+ let description = [{
+ Returns a resource with the element at the given offset set to the given
+ value.
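+
+ Example (operand names are placeholders):
+
+ ```mlir
+ %0 = stream.async.store %value, %target[%offset] : f32 -> %target as !stream.resource<staging>{%size}
+ ```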
+ }];
+
+ let arguments = (ins
+ Stream_StagingResource:$target,
+ Stream_Size:$target_size,
+ Stream_Offset:$target_offset,
+ AnyTypeOf<[Stream_PrimitiveType, AnyVector]>:$value
+ );
+ let results = (outs
+ Stream_StagingResource:$result
+ );
+
+ let assemblyFormat = [{
+ $value `,`
+ $target `[` $target_offset `]` `:`
+ type($value)
+ `->`
+ custom<ShapedTiedResult>(type($target), $target_size)
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return target_size(); }
+ Value getResultSize(unsigned idx) { return target_size(); }
+ }];
+
+ let verifier = [{ return verifyOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
def Stream_AsyncDispatchOp : Stream_Op<"async.dispatch", [
AttrSizedOperandSegments,
Stream_AffinityOp,
diff --git a/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
index 352a521..b02fe78 100644
--- a/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
+++ b/iree/compiler/Dialect/Stream/IR/test/async_ops.mlir
@@ -89,6 +89,26 @@
// -----
+// CHECK-LABEL: @asyncLoad
+func @asyncLoad(%arg0: !stream.resource<staging>, %arg1: index) -> f32 {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.async.load %arg0[%c0] : !stream.resource<staging>{%arg1} -> f32
+ %0 = stream.async.load %arg0[%c0] : !stream.resource<staging>{%arg1} -> f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @asyncStore
+func @asyncStore(%arg0: !stream.resource<staging>, %arg1: index, %arg2: f32) -> !stream.resource<staging> {
+ %c0 = arith.constant 0 : index
+ // CHECK: = stream.async.store %arg2, %arg0[%c0] : f32 -> %arg0 as !stream.resource<staging>{%arg1}
+ %0 = stream.async.store %arg2, %arg0[%c0] : f32 -> %arg0 as !stream.resource<staging>{%arg1}
+ return %0 : !stream.resource<staging>
+}
+
+// -----
+
// CHECK-LABEL: @asyncDispatch
func @asyncDispatch(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> {
%c1 = arith.constant 1 : index
diff --git a/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir b/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
index 0c7b333..5bd3609 100644
--- a/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
+++ b/iree/compiler/Dialect/Stream/IR/test/resource_folding.mlir
@@ -13,6 +13,23 @@
// -----
+// CHECK-LABEL: @SelectResourceSizeOp
+func @SelectResourceSizeOp(%arg0: !stream.resource<staging>, %arg1: index, %arg2: !stream.resource<staging>, %arg3: index, %arg4: i1) -> (!stream.resource<staging>, index) {
+ // CHECK: %[[ARG0_T:.+]] = stream.async.transfer %arg0 {{.+}} -> !stream.resource<*>{%[[ARG0_SZ:.+]]}
+ %0 = stream.async.transfer %arg0 : !stream.resource<staging>{%arg1} -> !stream.resource<*>{%arg1}
+ // CHECK: %[[ARG2_T:.+]] = stream.async.transfer %arg2 {{.+}} -> !stream.resource<*>{%[[ARG2_SZ:.+]]}
+ %1 = stream.async.transfer %arg2 : !stream.resource<staging>{%arg3} -> !stream.resource<*>{%arg3}
+ // CHECK: %[[RET_T:.+]] = select %arg4, %[[ARG0_T]], %[[ARG2_T]] : !stream.resource<*>
+ %2 = select %arg4, %0, %1 : !stream.resource<*>
+ // CHECK: %[[RET_SIZE:.+]] = select %arg4, %[[ARG0_SZ]], %[[ARG2_SZ]] : index
+ %3 = stream.resource.size %2 : !stream.resource<*>
+ // CHECK: = stream.async.transfer %[[RET_T]] : !stream.resource<*>{%[[RET_SIZE]]}
+ %4 = stream.async.transfer %2 : !stream.resource<*>{%3} -> !stream.resource<staging>{%3}
+ return %4, %3 : !stream.resource<staging>, index
+}
+
+// -----
+
// CHECK-LABEL: @FoldSubviewIntoLoadOp
func @FoldSubviewIntoLoadOp(%arg0: !stream.resource<staging>, %arg1: index) -> i32 {
%c64 = arith.constant 64 : index
diff --git a/iree/compiler/Dialect/Stream/Transforms/BUILD b/iree/compiler/Dialect/Stream/Transforms/BUILD
new file mode 100644
index 0000000..1c4848c
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/BUILD
@@ -0,0 +1,71 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "Transforms",
+ srcs = [
+ "ConvertToStream.cpp",
+ "OutlineConstants.cpp",
+ "PassDetail.h",
+ "Passes.cpp",
+ "VerifyLowerings.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ "Passes.h.inc",
+ ],
+ deps = [
+ ":PassesIncGen",
+ "//iree/compiler/Dialect/Flow/IR",
+ "//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/Shape/Transforms",
+ "//iree/compiler/Dialect/Shape/Utils:TypeConversion",
+ "//iree/compiler/Dialect/Stream/Conversion/FlowToStream",
+ "//iree/compiler/Dialect/Stream/Conversion/StandardToStream",
+ "//iree/compiler/Dialect/Stream/Conversion/UtilToStream",
+ "//iree/compiler/Dialect/Stream/IR",
+ "//iree/compiler/Dialect/Util/Conversion",
+ "//iree/compiler/Dialect/Util/IR",
+ "//iree/compiler/Dialect/Util/Transforms",
+ "//iree/compiler/Utils",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:Analysis",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+gentbl_cc_library(
+ name = "PassesIncGen",
+ tbl_outs = [
+ (
+ ["-gen-pass-decls"],
+ "Passes.h.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "Passes.td",
+ deps = [
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..267b6f6
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -0,0 +1,63 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Transforms/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Transforms
+ HDRS
+ "Passes.h"
+ "Passes.h.inc"
+ SRCS
+ "ConvertToStream.cpp"
+ "OutlineConstants.cpp"
+ "PassDetail.h"
+ "Passes.cpp"
+ "VerifyLowerings.cpp"
+ DEPS
+ ::PassesIncGen
+ LLVMSupport
+ MLIRAffine
+ MLIRAnalysis
+ MLIRArithmetic
+ MLIRIR
+ MLIRPass
+ MLIRSCF
+ MLIRStandard
+ MLIRSupport
+ MLIRTensor
+ MLIRTransformUtils
+ MLIRTransforms
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::Shape::Transforms
+ iree::compiler::Dialect::Shape::Utils::TypeConversion
+ iree::compiler::Dialect::Stream::Conversion::FlowToStream
+ iree::compiler::Dialect::Stream::Conversion::StandardToStream
+ iree::compiler::Dialect::Stream::Conversion::UtilToStream
+ iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::Conversion
+ iree::compiler::Dialect::Util::IR
+ iree::compiler::Dialect::Util::Transforms
+ iree::compiler::Utils
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ PassesIncGen
+ TD_FILE
+ "Passes.td"
+ OUTS
+ -gen-pass-decls Passes.h.inc
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
new file mode 100644
index 0000000..01ec65f
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -0,0 +1,257 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/Builders.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.h"
+#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h"
+#include "iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+namespace {
+
+// Builds a stream.tensor.import op that imports an external tensor value into
+// a stream resource.
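+//
+// For illustration, the produced IR looks roughly like the following sketch
+// (patterned after the convert_to_stream.mlir test; the %source, %dim0, and
+// %size names are hypothetical):
+//   %size = stream.tensor.sizeof tensor<?x4xf32>{%dim0} : index
+//   %import = stream.tensor.import %source : tensor<?x4xf32> ->
+//       tensor<?x4xf32>{%dim0} in !stream.resource<external>{%size}
+//   %result = stream.async.transfer %import :
+//       !stream.resource<external>{%size} -> !stream.resource<*>{%size}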
+static Value buildTensorImportOp(Location loc, Value sourceTensor,
+ Type targetType, OpBuilder &builder) {
+ // Gather dynamic dimensions from the input value.
+ auto dynamicDims =
+ Shape::buildOrFindDynamicDimsForValue(loc, sourceTensor, builder);
+
+  // Compute the size of the tensor once it is in the stream resource.
+  // This may differ from the external encoding of the tensor as imports are
+  // transfer operations that may need to reformat the tensor.
+ auto encodingAttr = TypeAttr::get(sourceTensor.getType());
+ auto resultSize = builder.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ loc, builder.getIndexType(), encodingAttr, dynamicDims,
+ /*affinity=*/nullptr);
+
+ // Associate the external SSA value, encoding, and shape information with the
+ // stream resource. When lowering we'll then have all the metadata required
+ // even after erasing it all on the resource.
+ auto externalType = IREE::Stream::ResourceType::get(
+ builder.getContext(), IREE::Stream::Lifetime::External);
+ auto importOp = builder.create<IREE::Stream::TensorImportOp>(
+ loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize,
+ /*affinity=*/nullptr);
+
+ // If needed insert a transfer to the target lifetime.
+ Value result = importOp.result();
+ if (targetType != externalType) {
+ result = builder
+ .create<IREE::Stream::AsyncTransferOp>(
+                     loc, targetType, result, resultSize, resultSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr)
+ .result();
+ }
+
+ return result;
+}
+
+// Builds a stream.tensor.export op that exports a stream resource into an
+// external tensor value.
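+//
+// Roughly the inverse of the import above (sketch; SSA names hypothetical):
+//   %transfer = stream.async.transfer %resource :
+//       !stream.resource<*>{%size} -> !stream.resource<external>{%size}
+//   %export = stream.tensor.export %transfer : tensor<?xf32>{%dim0} in
+//       !stream.resource<external>{%size} -> tensor<?xf32>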
+static Value buildTensorExportOp(Location loc, Value sourceResource,
+ TensorType targetType, ValueRange dynamicDims,
+ OpBuilder &builder) {
+ // Query the size of the resource - which may differ from the target external
+ // value if we changed the encoding.
+ auto sourceSize = builder.createOrFold<IREE::Stream::ResourceSizeOp>(
+ loc, builder.getIndexType(), sourceResource);
+
+ // If needed insert a transfer to external resource lifetime.
+ auto externalType = IREE::Stream::ResourceType::get(
+ builder.getContext(), IREE::Stream::Lifetime::External);
+ if (sourceResource.getType() != externalType) {
+ sourceResource = builder.create<IREE::Stream::AsyncTransferOp>(
+ loc, externalType, sourceResource, sourceSize, sourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ }
+
+ // Associate the stream resource and external encoding and shape information.
+ auto newOp = builder.create<IREE::Stream::TensorExportOp>(
+ loc, targetType, sourceResource, TypeAttr::get(targetType), dynamicDims,
+ sourceSize,
+ /*affinity=*/nullptr);
+ return newOp.result();
+}
+
+// Returns true if |op| has tensor I/O that is not yet imported/exported using
+// the stream ops that capture encodings and shapes.
+static bool doesOperationNeedWrapping(Operation *op) {
+ return llvm::any_of(
+ op->getOperands(),
+ [&](Value operand) {
+ if (!operand.getType().isa<TensorType>()) return false;
+ return !isa_and_nonnull<TensorExportOp>(operand.getDefiningOp());
+ }) ||
+ llvm::any_of(op->getResults(), [&](Value result) {
+ if (!result.getType().isa<TensorType>()) return false;
+ return !llvm::all_of(result.getUsers(), [&](Operation *user) {
+ return isa<TensorImportOp>(user);
+ });
+ });
+}
+
+// Fallback handler for unknown ops taking/returning tensors that need to be
+// marshaled into/out of stream resource types.
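+//
+// Sketch of the wrapping for a hypothetical unknown op (names invented):
+//   %t0 = stream.tensor.export %r0 : tensor<4xf32> in
+//       !stream.resource<external>{%sz0} -> tensor<4xf32>
+//   %t1 = "some.dialect.op"(%t0) : (tensor<4xf32>) -> tensor<4xf32>
+//   %r1 = stream.tensor.import %t1 : tensor<4xf32> ->
+//       tensor<4xf32> in !stream.resource<external>{%sz1}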
+struct GenericResourcePattern : public ConversionPattern {
+ GenericResourcePattern(MLIRContext *context, TypeConverter &converter)
+ : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!doesOperationNeedWrapping(op)) return failure();
+
+ // Export resources into tensor operands for the op to consume.
+ SmallVector<Value> newOperands;
+ newOperands.reserve(op->getNumOperands());
+ rewriter.setInsertionPoint(op);
+ for (auto it : llvm::zip(op->getOperands(), operands)) {
+ auto oldOperand = std::get<0>(it);
+ auto newOperand = std::get<1>(it);
+ if (!newOperand.getType().isa<IREE::Stream::ResourceType>()) {
+ newOperands.push_back(newOperand);
+ continue;
+ }
+ auto tensorType = oldOperand.getType().dyn_cast<TensorType>();
+ assert(tensorType && "must have a tensor type to map to a resource");
+
+ auto dynamicDims = Shape::buildOrFindDynamicDimsForValue(
+ op->getLoc(), oldOperand, rewriter);
+ newOperands.push_back(buildTensorExportOp(
+ op->getLoc(), newOperand, tensorType, dynamicDims, rewriter));
+ }
+ rewriter.updateRootInPlace(op, [&]() { op->setOperands(newOperands); });
+
+ // Import into resources from tensor results produced by the op.
+ rewriter.setInsertionPointAfter(op);
+ for (auto result : op->getResults()) {
+ auto tensorType = result.getType().dyn_cast<TensorType>();
+ if (!tensorType) continue;
+
+ auto dynamicDims =
+ Shape::buildOrFindDynamicDimsForValue(op->getLoc(), result, rewriter);
+ auto importedValue = buildTensorImportOp(
+ op->getLoc(), result, IREE::Stream::ResourceType::get(getContext()),
+ rewriter);
+ result.replaceAllUsesExcept(importedValue, importedValue.getDefiningOp());
+ }
+
+ return success();
+ }
+};
+
+class ConvertToStreamPass : public ConvertToStreamBase<ConvertToStreamPass> {
+ public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<ShapeDialect>();
+ registry.insert<mlir::StandardOpsDialect>();
+ registry.insert<mlir::arith::ArithmeticDialect>();
+ registry.insert<IREE::Stream::StreamDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ auto *context = &getContext();
+
+ TypeConverter typeConverter;
+ ConversionTarget conversionTarget(getContext());
+ OwningRewritePatternList patterns(&getContext());
+
+    // Always allow lowering target dialects and reasonable types.
+ conversionTarget.addLegalDialect<IREE::Stream::StreamDialect>();
+ typeConverter.addConversion(
+ [](IREE::Stream::ResourceType type) { return type; });
+
+ // Disallow tensor dialects; the goal here is to remove all tensors and
+ // turn them into stream resource ops.
+ auto indexType = IndexType::get(context);
+ conversionTarget.addIllegalDialect<tensor::TensorDialect>();
+ typeConverter.addConversion(
+ [=](TensorType type, SmallVectorImpl<Type> &resultTypes) {
+ // Expand tensors to [resource, sizeof resource].
+ resultTypes.push_back(IREE::Stream::ResourceType::get(context));
+ resultTypes.push_back(indexType);
+ return success();
+ });
+ typeConverter.addArgumentMaterialization(
+ [](OpBuilder &builder, TensorType resultType, ValueRange inputs,
+ Location loc) -> Optional<Value> {
+        assert(inputs.size() == 2 &&
+               "expecting 2 operands (resource + size)");
+        auto resourceValue = inputs[0];
+        auto resourceSize = inputs[1];
+ return builder
+ .create<IREE::Stream::AsyncTransferOp>(
+ loc, resourceValue.getType(), resourceValue, resourceSize,
+ resourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr)
+ .result();
+ });
+
+ populateUtilConversionPatterns(context, conversionTarget, typeConverter,
+ patterns);
+ populateUtilToStreamConversionPatterns(context, conversionTarget,
+ typeConverter, patterns);
+
+ populateStandardToStreamConversionPatterns(context, conversionTarget,
+ typeConverter, patterns);
+ populateFlowToStreamConversionPatterns(context, conversionTarget,
+ typeConverter, patterns);
+
+ conversionTarget.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); });
+ patterns.insert<GenericResourcePattern>(context, typeConverter);
+
+    // NOTE: unknown ops are permitted so that custom dialects that don't
+    // need anything Stream-specific can pass through unmodified.
+ conversionTarget.addLegalOp<UnrealizedConversionCastOp>();
+ if (failed(applyPartialConversion(getOperation(), conversionTarget,
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertToStreamPass() {
+ return std::make_unique<ConvertToStreamPass>();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Transforms/OutlineConstants.cpp b/iree/compiler/Dialect/Stream/Transforms/OutlineConstants.cpp
new file mode 100644
index 0000000..60fa6ef
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/OutlineConstants.cpp
@@ -0,0 +1,129 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+// Returns true if |value| is worth outlining (large, etc).
+static bool isOutlinableValue(Attribute value) {
+ if (auto elementsAttr = value.dyn_cast<DenseElementsAttr>()) {
+ // Don't outline splats - we want those fused.
+ return !elementsAttr.isSplat();
+ }
+ return false;
+}
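+
+// For example (mirroring test/outline_constants.mlir): the splat
+// `arith.constant dense<1.2> : tensor<512x128xf32>` stays inline, while the
+// non-splat `arith.constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>`
+// is outlined into a util.global by the pass below.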
+
+struct ConstantDef {
+ Operation *op;
+ Type type;
+ ElementsAttr value;
+};
+
+// Returns a list of all constant-like shaped data ops in the module.
+static SmallVector<ConstantDef> findConstantsInModule(mlir::ModuleOp moduleOp) {
+ SmallVector<ConstantDef> results;
+ for (auto callableOp : moduleOp.getOps<CallableOpInterface>()) {
+ for (auto &block : *callableOp.getCallableRegion()) {
+ for (auto &op : block.getOperations()) {
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
+ if (isOutlinableValue(constantOp.value())) {
+ results.push_back(ConstantDef{
+ constantOp,
+ constantOp.getType(),
+ constantOp.value().cast<ElementsAttr>(),
+ });
+ }
+ }
+ }
+ }
+ }
+ return results;
+}
+
+class OutlineConstantsPass : public OutlineConstantsBase<OutlineConstantsPass> {
+ public:
+ OutlineConstantsPass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<mlir::StandardOpsDialect>();
+ registry.insert<mlir::arith::ArithmeticDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ if (moduleOp.getBody()->empty()) return;
+
+ SymbolTable moduleSymbols(moduleOp);
+ std::string baseName = "_constant";
+
+ // Create all top-level util.globals from constants in the module.
+ OpBuilder moduleBuilder(&moduleOp.getBody()->front());
+ std::vector<std::pair<Operation *, IREE::Util::GlobalOp>> replacements;
+ for (auto &def : findConstantsInModule(moduleOp)) {
+ // New immutable global takes the constant attribute in its specified
+ // encoding.
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ def.op->getLoc(), baseName, /*isMutable=*/false, def.type, def.value);
+ globalOp.setPrivate();
+ moduleSymbols.insert(globalOp); // uniques name
+ replacements.emplace_back(def.op, globalOp);
+
+ // Prevent the variable from being re-inlined if the canonicalizer runs.
+ // By the time we've outlined things here we are sure we want them
+ // outlined even if the user runs an arbitrary number of passes between
+ // now and when we may use that information (HAL constant pooling, etc).
+ globalOp->setAttr("noinline", moduleBuilder.getUnitAttr());
+ }
+
+ // Replace all of the constants with lookups for the new variables.
+ for (auto pair : replacements) {
+ auto *originalOp = pair.first;
+ auto globalOp = pair.second;
+ OpBuilder builder(moduleOp.getContext());
+ builder.setInsertionPoint(originalOp);
+ auto loadOp = builder.create<IREE::Util::GlobalLoadOp>(
+ originalOp->getLoc(), globalOp.type(), SymbolRefAttr::get(globalOp));
+
+ Value replacement;
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(originalOp)) {
+ // Directly replace constant with global constant value.
+ replacement = loadOp.result();
+ } else {
+ llvm_unreachable("unhandled constant op type");
+ }
+
+ originalOp->getResult(0).replaceAllUsesWith(replacement);
+ originalOp->erase();
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createOutlineConstantsPass() {
+ return std::make_unique<OutlineConstantsPass>();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Transforms/PassDetail.h b/iree/compiler/Dialect/Stream/Transforms/PassDetail.h
new file mode 100644
index 0000000..6fab181
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/PassDetail.h
@@ -0,0 +1,25 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_TRANSFORMS_PASS_DETAIL_H_
+#define IREE_COMPILER_DIALECT_STREAM_TRANSFORMS_PASS_DETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+#define GEN_PASS_CLASSES
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" // IWYU pragma: keep
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_TRANSFORMS_PASS_DETAIL_H_
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
new file mode 100644
index 0000000..dd108bd
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -0,0 +1,207 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+
+#include <memory>
+
+#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static void addCleanupPatterns(OpPassManager &passManager) {
+ // Standard MLIR cleanup.
+ passManager.addPass(mlir::createCanonicalizerPass());
+ passManager.addPass(mlir::createCSEPass());
+
+ // Cleanup and canonicalization of util.global (and other util ops).
+ passManager.addPass(IREE::Util::createApplyPatternsPass());
+ passManager.addPass(IREE::Util::createFoldGlobalsPass());
+ passManager.addPass(IREE::Util::createFuseGlobalsPass());
+
+ // Simplify util.global accesses; this can help with data flow tracking as
+ // redundant store-loads are removed.
+ passManager.addNestedPass<IREE::Util::InitializerOp>(
+ IREE::Util::createSimplifyGlobalAccessesPass());
+ passManager.addNestedPass<mlir::FuncOp>(
+ IREE::Util::createSimplifyGlobalAccessesPass());
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-tensor-transformation-pipeline
+//===----------------------------------------------------------------------===//
+
+void buildStreamTensorPassPipeline(OpPassManager &passManager,
+ const TransformOptions &transformOptions) {
+ //----------------------------------------------------------------------------
+ // Input cleanup and simplification
+ //----------------------------------------------------------------------------
+
+ // Verify we support the program.
+ passManager.addPass(IREE::Stream::createVerifyInputPass());
+
+ // Turn all constant ops into global variables and fix up the IR.
+  // As many locations change and constants are deduplicated we'll end up with
+  // a lot of extraneous IR (mostly global loads), so we clean that up here.
+ passManager.addPass(IREE::Stream::createOutlineConstantsPass());
+
+  // Perform cleanup after constant simplification as more canonicalizers may be
+ // able to kick in.
+ addCleanupPatterns(passManager);
+
+ //----------------------------------------------------------------------------
+ // Conversion
+ //----------------------------------------------------------------------------
+
+ // Converts from all input dialects into various levels of the stream dialect.
+  // Tensor-like things go to stream.tensor.* ops while lower-level buffer-like
+ // things will go to stream.async.* ops.
+ passManager.addPass(IREE::Stream::createConvertToStreamPass());
+
+  // No more tensor.*/etc ops are allowed. This is conservative - it may not
+  // cover every op we convert but it will catch the majority of stragglers.
+ passManager.addPass(IREE::Stream::createVerifyLoweringToTensorsPass());
+
+ //----------------------------------------------------------------------------
+ // Constant/variable optimization
+ //----------------------------------------------------------------------------
+
+  // Clean up globals that were created during conversion.
+ addCleanupPatterns(passManager);
+
+ // Bring all initializers together so that we can schedule them.
+ passManager.addPass(IREE::Util::createCombineInitializersPass());
+
+ //----------------------------------------------------------------------------
+ // Stream affinity/assignment
+ //----------------------------------------------------------------------------
+
+ // TODO(benvanik): pin based on target backends here.
+ // TODO(benvanik): compute affinities for executables.
+ // TODO(benvanik): annotate all dispatches with preferred executable affinity.
+ // TODO(benvanik): DFA to specify all value affinities and pin dispatches.
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-async-transformation-pipeline
+//===----------------------------------------------------------------------===//
+
+void buildStreamAsyncPassPipeline(OpPassManager &passManager,
+ const TransformOptions &transformOptions) {}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-cmd-transformation-pipeline
+//===----------------------------------------------------------------------===//
+
+void buildStreamCmdPassPipeline(OpPassManager &passManager,
+ const TransformOptions &transformOptions) {}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-optimization-pipeline
+//===----------------------------------------------------------------------===//
+
+void buildStreamOptimizationPassPipeline(
+ OpPassManager &passManager, const TransformOptions &transformOptions) {}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-transformation-pipeline
+//===----------------------------------------------------------------------===//
+
+void buildStreamTransformPassPipeline(
+ OpPassManager &passManager, const TransformOptions &transformOptions) {
+ //----------------------------------------------------------------------------
+ // Primary pipeline stages (required)
+ //----------------------------------------------------------------------------
+
+ buildStreamTensorPassPipeline(passManager, transformOptions);
+ buildStreamAsyncPassPipeline(passManager, transformOptions);
+ buildStreamCmdPassPipeline(passManager, transformOptions);
+
+ //----------------------------------------------------------------------------
+ // Optimizations (may be required by some targets)
+ //----------------------------------------------------------------------------
+
+ buildStreamOptimizationPassPipeline(passManager, transformOptions);
+
+ //----------------------------------------------------------------------------
+ // Post-pipeline cleanup
+ //----------------------------------------------------------------------------
+
+ // Forming streams involves a fair amount of subgraph stitching, which can
+ // cause duplication. Run CSE to collapse.
+ addCleanupPatterns(passManager);
+
+  // Symbol DCE any remaining variables/functions that are no longer required.
+ passManager.addPass(mlir::createSymbolDCEPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+void registerStreamTransformPassPipelines() {
+ PassPipelineRegistration<TransformOptions> tensorPassPipeline(
+ "iree-stream-tensor-transformation-pipeline",
+ "Lowers source dialects into stream.tensor.* IR.",
+ [](OpPassManager &passManager, const TransformOptions &transformOptions) {
+ buildStreamTensorPassPipeline(passManager, transformOptions);
+ });
+ PassPipelineRegistration<TransformOptions> asyncPassPipeline(
+ "iree-stream-async-transformation-pipeline",
+ "Lowers stream.tensor.* to stream.async.* IR.",
+ [](OpPassManager &passManager, const TransformOptions &transformOptions) {
+ buildStreamAsyncPassPipeline(passManager, transformOptions);
+ });
+ PassPipelineRegistration<TransformOptions> cmdPassPipeline(
+ "iree-stream-cmd-transformation-pipeline",
+ "Lowers stream.async.* to stream.cmd.* IR.",
+ [](OpPassManager &passManager, const TransformOptions &transformOptions) {
+ buildStreamCmdPassPipeline(passManager, transformOptions);
+ });
+ PassPipelineRegistration<TransformOptions> optimizationPassPipeline(
+ "iree-stream-optimization-pipeline",
+ "Optimizes stream commands and resources (may be required for some "
+ "targets).",
+ [](OpPassManager &passManager, const TransformOptions &transformOptions) {
+ buildStreamOptimizationPassPipeline(passManager, transformOptions);
+ });
+ PassPipelineRegistration<TransformOptions> transformPassPipeline(
+ "iree-stream-transformation-pipeline",
+ "Runs the full IREE stream dialect transformation pipeline.",
+ [](OpPassManager &passManager, const TransformOptions &transformOptions) {
+ buildStreamTransformPassPipeline(passManager, transformOptions);
+ });
+}
+
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" // IWYU pragma: export
+} // namespace
+
+void registerStreamPasses() {
+ // Generated.
+ registerPasses();
+
+ // Pipelines.
+ registerStreamTransformPassPipelines();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.h b/iree/compiler/Dialect/Stream/Transforms/Passes.h
new file mode 100644
index 0000000..513fdd0
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.h
@@ -0,0 +1,104 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_TRANSFORMS_PASSES_H_
+#define IREE_COMPILER_DIALECT_STREAM_TRANSFORMS_PASSES_H_
+
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
+struct TransformOptions : public PassPipelineOptions<TransformOptions> {
+ // TODO(benvanik): options for async/sync overrides.
+
+ Option<bool> optimizeBindings{
+ *this, "optimize-bindings",
+ llvm::cl::desc(
+ "Enables binding fusion and dispatch site specialization."),
+ llvm::cl::init(true)};
+};
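+
+// The option above can be set when invoking the registered pipelines, e.g.
+// (sketch; exact tool and quoting may vary):
+//   iree-opt --iree-stream-transformation-pipeline="optimize-bindings=false"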
+
+// Adds a set of passes to the given pass manager that run the required stream
+// transforms in the canonical order.
+//
+// Most translation code should prefer to use this instead of manually adding
+// the passes themselves to ensure that expected pass ordering is observed.
+//
+// The expected usage is:
+// Input legalization by one of:
+// - Directly passing supported flow plus core ops
+// buildStreamTransformPassPipeline
+//   <run conversion from stream to hal/vm/etc>
+//
+// This is equivalent to:
+// buildStreamTensorPassPipeline
+// buildStreamAsyncPassPipeline
+// buildStreamCmdPassPipeline
+void buildStreamTransformPassPipeline(OpPassManager &passManager,
+ const TransformOptions &transformOptions);
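+//
+// Minimal usage sketch (assumes an existing ModuleOp; error handling elided):
+//   PassManager passManager(moduleOp.getContext());
+//   IREE::Stream::TransformOptions transformOptions;
+//   IREE::Stream::buildStreamTransformPassPipeline(passManager,
+//                                                  transformOptions);
+//   if (failed(passManager.run(moduleOp))) return failure();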
+
+// Lowers from source dialects into stream.tensor.* IR.
+void buildStreamTensorPassPipeline(OpPassManager &passManager,
+ const TransformOptions &transformOptions);
+// Lowers stream.tensor.* IR into stream.async.* IR.
+void buildStreamAsyncPassPipeline(OpPassManager &passManager,
+ const TransformOptions &transformOptions);
+// Lowers stream.async.* IR into stream.cmd.* IR.
+void buildStreamCmdPassPipeline(OpPassManager &passManager,
+ const TransformOptions &transformOptions);
+// Optimizes stream commands (mostly optional).
+void buildStreamOptimizationPassPipeline(
+ OpPassManager &passManager, const TransformOptions &transformOptions);
+
+void registerStreamTransformPassPipelines();
+
+//===----------------------------------------------------------------------===//
+// Optimizations
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createOutlineConstantsPass();
+
+//===----------------------------------------------------------------------===//
+// Conversion
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertToStreamPass();
+
+//===----------------------------------------------------------------------===//
+// Diagnostics
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyInputPass();
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyLoweringToTensorsPass();
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyLoweringToAsyncPass();
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyLoweringToCmdPass();
+
+//===----------------------------------------------------------------------===//
+// Register all Passes
+//===----------------------------------------------------------------------===//
+
+void registerStreamPasses();
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_TRANSFORMS_PASSES_H_
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.td b/iree/compiler/Dialect/Stream/Transforms/Passes.td
new file mode 100644
index 0000000..259183b
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -0,0 +1,72 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_STREAM_PASSES
+#define IREE_DIALECT_STREAM_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// Optimizations
+//===----------------------------------------------------------------------===//
+
+def OutlineConstants :
+ Pass<"iree-stream-outline-constants", "mlir::ModuleOp"> {
+ let summary = "Outlines tensor constants into util.globals at the module level.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createOutlineConstantsPass()
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion
+//===----------------------------------------------------------------------===//
+
+def ConvertToStream :
+ Pass<"iree-stream-conversion", "mlir::ModuleOp"> {
+ let summary = "Converts from flow/std/etc dialects into the stream dialect.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createConvertToStreamPass()
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Diagnostics
+//===----------------------------------------------------------------------===//
+
+def VerifyInput :
+ Pass<"iree-stream-verify-input", "mlir::ModuleOp"> {
+  let summary = "Verifies that input dialects are supported by the stream dialect.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createVerifyInputPass()
+ }];
+}
+
+def VerifyLoweringToTensors :
+ Pass<"iree-stream-verify-lowering-to-tensors", "mlir::ModuleOp"> {
+ let summary = "Verifies that input dialects are converted to stream.tensor.* ops.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createVerifyLoweringToTensorsPass()
+ }];
+}
+
+def VerifyLoweringToAsync :
+ Pass<"iree-stream-verify-lowering-to-async", "mlir::ModuleOp"> {
+ let summary = "Verifies that all stream.tensor.* ops and types are fully lowered to stream.async.* ops.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createVerifyLoweringToAsyncPass()
+ }];
+}
+
+def VerifyLoweringToCmd :
+ Pass<"iree-stream-verify-lowering-to-cmd", "mlir::ModuleOp"> {
+ let summary = "Verifies that all stream.async.* ops and types are fully lowered to stream.cmd.* ops.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createVerifyLoweringToCmdPass()
+ }];
+}
+
+#endif // IREE_DIALECT_STREAM_PASSES
diff --git a/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp b/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp
new file mode 100644
index 0000000..8ecdf2c
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp
@@ -0,0 +1,420 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+//===----------------------------------------------------------------------===//
+// Base pass utility
+//===----------------------------------------------------------------------===//
+
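+// Minimal usage sketch for the Verifier below (mirrors the passes later in
+// this file; assumes a root op to check):
+//   Verifier verifier;
+//   verifier.addIllegalDialect("tensor");
+//   verifier.addRecursivelyLegalOp<IREE::Stream::ExecutableOp>();
+//   if (failed(verifier.run(getOperation()))) return signalPassFailure();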
+class Verifier {
+ public:
+ enum class Legality {
+ LEGAL,
+ RECURSIVELY_LEGAL,
+ ILLEGAL,
+ };
+
+ using OpVerifierFn = std::function<Optional<Legality>(Operation *op)>;
+ using TypeVerifierFn = std::function<Legality(Type type)>;
+
+ void addIllegalDialect(StringRef dialectName) {
+ dialectLegality.insert({dialectName, Legality::ILLEGAL});
+ }
+ template <typename DialectT>
+ void addIllegalDialect() {
+ addIllegalDialect(DialectT::getDialectNamespace());
+ }
+
+ template <typename OpT>
+ void addLegalOp() {
+ opLegality.insert({OpT::getOperationName(), Legality::LEGAL});
+ }
+
+ template <typename OpT>
+ void addRecursivelyLegalOp() {
+ opLegality.insert({OpT::getOperationName(), Legality::RECURSIVELY_LEGAL});
+ }
+
+ template <typename OpT>
+ void addIllegalOp() {
+ opLegality.insert({OpT::getOperationName(), Legality::ILLEGAL});
+ }
+
+ void addOpVerifier(std::function<Optional<Legality>(Operation *)> fn) {
+ opVerifiers.push_back(fn);
+ }
+
+ template <typename OpT>
+ void addOpVerifier(std::function<Optional<Legality>(OpT)> fn) {
+ auto wrapperFn = [=](Operation *baseOp) -> Optional<Legality> {
+ if (auto op = dyn_cast<OpT>(baseOp)) {
+ return fn(op);
+ }
+ return llvm::None;
+ };
+ opVerifiers.push_back(wrapperFn);
+ }
+
+ template <typename TypeT>
+ void addIllegalType() {
+ typeLegality.insert({TypeID::get<TypeT>(), Legality::ILLEGAL});
+ }
+
+ template <typename TypeT>
+ void addTypeVerifier(std::function<Legality(TypeT)> fn) {
+ auto wrapperFn = [=](Type baseType) { return fn(baseType.cast<TypeT>()); };
+    if (!typeVerifiers.insert({TypeID::get<TypeT>(), wrapperFn}).second) {
+      llvm_unreachable("already registered for this type");
+ }
+ }
+
+ LogicalResult run(Operation *rootOp) {
+ bool foundAnyIllegal = false;
+ rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
+ auto walkResult = WalkResult::advance();
+
+ // Check for op legality - can skip the expensive work if known-illegal.
+ auto legality = getOpLegality(op);
+ switch (legality) {
+ case Legality::LEGAL:
+ // Op itself is legal but may not have valid operands/results.
+ break;
+ case Legality::RECURSIVELY_LEGAL:
+ // If the entire op w/ nested ops is legal then skip.
+ return WalkResult::skip();
+ default:
+ case Legality::ILLEGAL:
+ // Early-exit on illegal ops without recursing.
+ emitIllegalOpError(op);
+ foundAnyIllegal = true;
+ return WalkResult::skip();
+ }
+
+ // Check types for operands/results.
+ for (auto operandType : llvm::enumerate(op->getOperandTypes())) {
+ if (isTypeLegal(operandType.value())) continue;
+ emitIllegalTypeError(op, "operand", operandType.index(),
+ operandType.value());
+ foundAnyIllegal = true;
+ }
+ for (auto resultType : llvm::enumerate(op->getResultTypes())) {
+ if (isTypeLegal(resultType.value())) continue;
+ emitIllegalTypeError(op, "result", resultType.index(),
+ resultType.value());
+ foundAnyIllegal = true;
+ }
+
+ return walkResult;
+ });
+ return success(!foundAnyIllegal);
+ }
+
+ private:
+ Legality getOpLegality(Operation *op) {
+ auto opName = op->getName();
+
+ // Check specific ops first (we may override dialect settings).
+ {
+ auto legalityIt = opLegality.find(opName.getStringRef());
+ if (legalityIt != opLegality.end()) {
+ return legalityIt->second;
+ }
+ }
+
+ // Check all op verifiers (usually used for interface checks).
+ for (auto &opVerifier : opVerifiers) {
+ auto legalOr = opVerifier(op);
+ if (legalOr.hasValue()) {
+ return legalOr.getValue();
+ }
+ }
+
+ // If no op carveout is applied then check to see if the dialect is
+ // allowed at all.
+ {
+ auto legalityIt = dialectLegality.find(opName.getDialectNamespace());
+ if (legalityIt != dialectLegality.end()) {
+ return legalityIt->second;
+ }
+ }
+
+ // Assume legal by default.
+ return Legality::LEGAL;
+ }
+
+ bool isTypeLegal(Type type) {
+ // TODO(benvanik): subelements interface checks using recursive legality.
+
+ // Defer to verifiers first.
+ auto it = typeVerifiers.find(type.getTypeID());
+ if (it != typeVerifiers.end()) {
+ return it->second(type) != Legality::ILLEGAL;
+ }
+
+ // Check legality of the base type.
+ {
+ auto legalityIt = typeLegality.find(type.getTypeID());
+ if (legalityIt != typeLegality.end()) {
+ return legalityIt->second != Legality::ILLEGAL;
+ }
+ }
+
+ // Assume legal by default.
+ return true;
+ }
+
+ void emitIllegalOpError(Operation *op) {
+ op->emitOpError()
+ << "illegal for this phase of lowering in the stream dialect; "
+ "expected to have been converted or removed";
+ }
+
+ void emitIllegalTypeError(Operation *op, StringRef location, unsigned idx,
+ Type type) {
+ op->emitOpError()
+ << location << " " << idx << " type " << type
+ << " illegal for this phase of lowering in the stream dialect";
+ }
+
+ DenseMap<StringRef, Legality> dialectLegality;
+ DenseMap<StringRef, Legality> opLegality;
+ SmallVector<OpVerifierFn> opVerifiers;
+ DenseMap<TypeID, Legality> typeLegality;
+ DenseMap<TypeID, TypeVerifierFn> typeVerifiers;
+};
+
+static void markStreamTensorOpsIllegal(Verifier &verifier) {
+ verifier.addOpVerifier([](Operation *op) -> Optional<Verifier::Legality> {
+ if (op->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
+ return Verifier::Legality::ILLEGAL;
+ }
+ return llvm::None;
+ });
+}
+
+static void markStreamAsyncOpsIllegal(Verifier &verifier) {
+ verifier.addOpVerifier([](Operation *op) -> Optional<Verifier::Legality> {
+ if (op->hasTrait<OpTrait::IREE::Stream::AsyncPhaseOp>()) {
+ return Verifier::Legality::ILLEGAL;
+ }
+ return llvm::None;
+ });
+}
+
+static void markStreamCmdOpsIllegal(Verifier &verifier) {
+ verifier.addOpVerifier([](Operation *op) -> Optional<Verifier::Legality> {
+ if (op->hasTrait<OpTrait::IREE::Stream::CmdPhaseOp>()) {
+ return Verifier::Legality::ILLEGAL;
+ }
+ return llvm::None;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-verify-input
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class VerifyInputPass : public VerifyInputBase<VerifyInputPass> {
+ public:
+ VerifyInputPass() = default;
+
+ void runOnOperation() override {
+ Verifier verifier;
+
+ // TODO(#7432): add indirect global expansion support to streams.
+ verifier.addIllegalOp<IREE::Util::GlobalAddressOp>();
+ verifier.addIllegalOp<IREE::Util::GlobalLoadIndirectOp>();
+ verifier.addIllegalOp<IREE::Util::GlobalStoreIndirectOp>();
+
+ if (failed(verifier.run(getOperation()))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyInputPass() {
+ return std::make_unique<VerifyInputPass>();
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-verify-lowering-to-tensors
+//===----------------------------------------------------------------------===//
+
+static void markTensorInputsIllegal(Verifier &verifier) {
+ // Tensorish dialects should all be either converted or outlined into
+ // executables. Everything should be in resources now.
+ verifier.addIllegalDialect("tensor");
+ verifier.addIllegalDialect("linalg");
+
+  // We don't allow the flow dialect except inside of executables, for which
+  // we don't yet have a full mapping in the stream dialect.
+ // TODO(#7277): remove this carveout once we switch over to streams fully.
+ verifier.addIllegalDialect("flow");
+ verifier.addRecursivelyLegalOp<IREE::Stream::ExecutableOp>();
+}
+
+namespace {
+
+class VerifyLoweringToTensorsPass
+ : public VerifyLoweringToTensorsBase<VerifyLoweringToTensorsPass> {
+ public:
+ VerifyLoweringToTensorsPass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Stream::StreamDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ // We cannot have stream.cmd.* ops mixed with stream.tensor/async.* ops
+ // as they use different memory models.
+ Verifier verifier;
+ markTensorInputsIllegal(verifier);
+ markStreamCmdOpsIllegal(verifier);
+ if (failed(verifier.run(getOperation()))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyLoweringToTensorsPass() {
+ return std::make_unique<VerifyLoweringToTensorsPass>();
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-verify-lowering-to-async
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class VerifyLoweringToAsyncPass
+ : public VerifyLoweringToAsyncBase<VerifyLoweringToAsyncPass> {
+ public:
+ VerifyLoweringToAsyncPass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Stream::StreamDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ // We cannot have stream.cmd.* ops mixed with stream.tensor/async.* ops
+ // as they use different memory models.
+ Verifier verifier;
+ markTensorInputsIllegal(verifier);
+ markStreamTensorOpsIllegal(verifier);
+ markStreamCmdOpsIllegal(verifier);
+
+ // All resources should have had their usage assigned.
+ verifier.addTypeVerifier<IREE::Stream::ResourceType>([](auto type) {
+ if (type.getLifetime() == IREE::Stream::Lifetime::Unknown) {
+ return Verifier::Legality::ILLEGAL;
+ }
+ return Verifier::Legality::LEGAL;
+ });
+
+ // All streamable ops should be inside of execution regions.
+ verifier.addOpVerifier<IREE::Stream::StreamableOpInterface>(
+ [](auto op) -> Optional<Verifier::Legality> {
+ // Allow metadata ops outside of execution regions.
+ if (op.isMetadata()) return Verifier::Legality::LEGAL;
+
+ // TODO(benvanik): execution region interface to make this generic.
+ if (!op->template getParentOfType<IREE::Stream::AsyncExecuteOp>()) {
+ op->emitOpError()
+              << "streamable op expected to be in an execution region";
+ return Verifier::Legality::ILLEGAL;
+ }
+ return llvm::None;
+ });
+
+ if (failed(verifier.run(getOperation()))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createVerifyLoweringToAsyncPass() {
+ return std::make_unique<VerifyLoweringToAsyncPass>();
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-verify-lowering-to-cmd
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class VerifyLoweringToCmdPass
+ : public VerifyLoweringToCmdBase<VerifyLoweringToCmdPass> {
+ public:
+ VerifyLoweringToCmdPass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Stream::StreamDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ Verifier verifier;
+ markTensorInputsIllegal(verifier);
+ markStreamTensorOpsIllegal(verifier);
+ markStreamAsyncOpsIllegal(verifier);
+
+ // All resources should have had their usage assigned.
+ verifier.addTypeVerifier<IREE::Stream::ResourceType>([](auto type) {
+ if (type.getLifetime() == IREE::Stream::Lifetime::Unknown) {
+ return Verifier::Legality::ILLEGAL;
+ }
+ return Verifier::Legality::LEGAL;
+ });
+
+ if (failed(verifier.run(getOperation()))) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyLoweringToCmdPass() {
+ return std::make_unique<VerifyLoweringToCmdPass>();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/BUILD b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
new file mode 100644
index 0000000..9b9e29f
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "convert_to_stream.mlir",
+ "outline_constants.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
new file mode 100644
index 0000000..7401676
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Stream/Transforms/test/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "convert_to_stream.mlir"
+ "outline_constants.mlir"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir b/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir
new file mode 100644
index 0000000..eb00f10
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir
@@ -0,0 +1,50 @@
+// RUN: iree-opt -split-input-file -iree-stream-conversion %s | IreeFileCheck %s
+
+// CHECK: stream.executable private @executable
+flow.executable private @executable {
+ // CHECK: stream.executable.export public @dispatch
+ flow.dispatch.entry public @dispatch attributes {workgroup_rank = 3 : index}
+ builtin.module {
+ // CHECK: func @dispatch(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index)
+ func @dispatch(%arg0: !flow.dispatch.tensor<readonly:?x4xf32>, %arg1: !flow.dispatch.tensor<writeonly:?xf32>,
+ %arg0_dim0: index, %arg1_dim0: index) {
+ // CHECK: %[[ARG0_SHAPE:.+]] = shapex.make_ranked_shape %arg2 : (index) -> !shapex.ranked_shape<[?,4]>
+ %arg0_shape = shapex.make_ranked_shape %arg0_dim0 : (index) -> !shapex.ranked_shape<[?,4]>
+ // CHECK: %[[ARG0_DIM0:.+]] = shapex.ranked_dim %[[ARG0_SHAPE]][0] : !shapex.ranked_shape<[?,4]> -> index
+ // CHECK: %[[ARG0_SPAN:.+]] = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:?x4xf32>{%[[ARG0_DIM0]]}
+ // CHECK: = flow.dispatch.tie_shape %[[ARG0_SPAN]], %[[ARG0_SHAPE]] : (!flow.dispatch.tensor<readonly:?x4xf32>, !shapex.ranked_shape<[?,4]>) -> !flow.dispatch.tensor<readonly:?x4xf32>
+ %arg0_shaped = flow.dispatch.tie_shape %arg0, %arg0_shape : (!flow.dispatch.tensor<readonly:?x4xf32>, !shapex.ranked_shape<[?,4]>) -> !flow.dispatch.tensor<readonly:?x4xf32>
+
+ // CHECK: %[[ARG1_SHAPE:.+]] = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?]>
+ %arg1_shape = shapex.make_ranked_shape %arg1_dim0 : (index) -> !shapex.ranked_shape<[?]>
+ // CHECK: %[[ARG1_DIM0:.+]] = shapex.ranked_dim %[[ARG1_SHAPE]][0] : !shapex.ranked_shape<[?]> -> index
+ // CHECK: %[[ARG1_SPAN:.+]] = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:?xf32>{%[[ARG1_DIM0]]}
+ // CHECK: = flow.dispatch.tie_shape %[[ARG1_SPAN]], %[[ARG1_SHAPE]] : (!flow.dispatch.tensor<writeonly:?xf32>, !shapex.ranked_shape<[?]>) -> !flow.dispatch.tensor<writeonly:?xf32>
+ %arg1_shaped = flow.dispatch.tie_shape %arg1, %arg1_shape : (!flow.dispatch.tensor<writeonly:?xf32>, !shapex.ranked_shape<[?]>) -> !flow.dispatch.tensor<writeonly:?xf32>
+
+ return
+ }
+ }
+}
+
+// CHECK-LABEL: @simple_mul
+func @simple_mul(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+ // CHECK: %[[DIM0:.+]] = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
+ %dim0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
+ // CHECK: %[[ARG0_SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
+ // CHECK: %[[ARG0_IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[ARG0_SIZE]]}
+ // CHECK: %[[ARG0_T:.+]] = stream.async.transfer %[[ARG0_IMPORT]] : !stream.resource<external>{%[[ARG0_SIZE]]} -> !stream.resource<*>{%[[ARG0_SIZE]]}
+ %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
+
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ // CHECK: %[[RET0_SIZE:.+]] = stream.tensor.sizeof tensor<?xf32>{%[[DIM0]]} : index
+ // CHECK: %[[RET0:.+]] = stream.async.dispatch @executable::@dispatch[%c4, %c1, %c1](%[[ARG0_T]]) : (!stream.resource<*>{%[[ARG0_SIZE]]}) -> !stream.resource<*>{%[[RET0_SIZE]]}
+ %1 = flow.dispatch @executable::@dispatch[%c4, %c1, %c1](%0) : (tensor<?x4xf32>{%dim0}) -> tensor<?xf32>{%dim0}
+
+ // CHECK: %[[RET0_T:.+]] = stream.async.transfer %[[RET0]] : !stream.resource<*>{%[[RET0_SIZE]]} -> !stream.resource<external>{%[[RET0_SIZE]]}
+ // CHECK: %[[RET0_EXPORT:.+]] = stream.tensor.export %[[RET0_T]] : tensor<?xf32>{%[[DIM0]]} in !stream.resource<external>{%[[RET0_SIZE]]} -> !hal.buffer_view
+ %2 = hal.tensor.cast %1 : tensor<?xf32>{%dim0} -> !hal.buffer_view
+ // CHECK: return %[[RET0_EXPORT]] : !hal.buffer_view
+ return %2 : !hal.buffer_view
+}
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/outline_constants.mlir b/iree/compiler/Dialect/Stream/Transforms/test/outline_constants.mlir
new file mode 100644
index 0000000..9bebe73
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/outline_constants.mlir
@@ -0,0 +1,30 @@
+// RUN: iree-opt -split-input-file -iree-stream-outline-constants %s | IreeFileCheck %s
+
+// CHECK-LABEL: @scalarConstant
+func @scalarConstant() {
+ // CHECK: = arith.constant 0 : i32
+ %cst = arith.constant 0 : i32
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @splatConstant
+func @splatConstant() {
+ // CHECK: = arith.constant dense<1.200000e+00> : tensor<512x128xf32>
+ %cst = arith.constant dense<1.2> : tensor<512x128xf32>
+ return
+}
+
+// -----
+
+// CHECK: util.global private @_constant {noinline} = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
+// CHECK-NEXT: util.global private @_constant_0 {noinline} = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]> : tensor<8xf32>
+// CHECK-LABEL: @denseConstants
+func @denseConstants() {
+ // CHECK: = util.global.load @_constant : tensor<2xf32>
+ %cst_0 = arith.constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
+ // CHECK-NEXT: = util.global.load @_constant_0 : tensor<8xf32>
+ %cst_1 = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]> : tensor<8xf32>
+ return
+}
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 792b34b..c4eb495 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -234,6 +234,13 @@
ParseResult parseShapedTiedResult(
OpAsmParser &parser, Type &resultType,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultDims) {
+ ArrayAttr tiedOperands;
+ return parseShapedTiedResult(parser, resultType, resultDims, tiedOperands);
+}
+
+ParseResult parseShapedTiedResult(
+ OpAsmParser &parser, Type &resultType,
SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
ArrayAttr &tiedOperands) {
OpAsmParser::OperandType tiedResult;
@@ -270,7 +277,7 @@
}
void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
- ValueRange resultDims, ArrayAttr tiedOperands) {
+ ValueRange resultDims) {
auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(0);
if (tiedOperandIndex.hasValue()) {
@@ -301,6 +308,11 @@
}
}
+void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
+ ValueRange resultDims, ArrayAttr tiedOperands) {
+ printShapedTiedResult(p, op, resultType, resultDims);
+}
+
//===----------------------------------------------------------------------===//
// custom<ShapedFunctionType>
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.h b/iree/compiler/Dialect/Util/IR/UtilOps.h
index a00447a..02cc7ad 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.h
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.h
@@ -92,6 +92,22 @@
ParseResult parseShapedTiedResult(
OpAsmParser &parser, Type &resultType,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultDims);
+inline ParseResult parseShapedTiedResult(OpAsmParser &parser, Type &resultType,
+ OpAsmParser::OperandType &resultDim) {
+ SmallVector<OpAsmParser::OperandType, 1> resultDims;
+ if (failed(parseShapedTiedResult(parser, resultType, resultDims))) {
+ return failure();
+ }
+ assert(resultDims.size() == 1 && "requires one dim");
+ resultDim = std::move(resultDims.front());
+ return success();
+}
+void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
+ ValueRange resultDims);
+
+ParseResult parseShapedTiedResult(
+ OpAsmParser &parser, Type &resultType,
SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
ArrayAttr &tiedOperands);
void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
@@ -105,6 +121,7 @@
tiedOperands))) {
return failure();
}
+ assert(resultDims.size() == 1 && "requires one dim");
resultDim = std::move(resultDims.front());
return success();
}
diff --git a/iree/compiler/Translation/BUILD b/iree/compiler/Translation/BUILD
index 886a5a2..4b8de51 100644
--- a/iree/compiler/Translation/BUILD
+++ b/iree/compiler/Translation/BUILD
@@ -24,6 +24,7 @@
"//iree/compiler/Dialect/HAL/Conversion/HALToVM",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Transforms",
+ "//iree/compiler/Dialect/Stream/Transforms",
"//iree/compiler/Dialect/Util/Transforms",
"//iree/compiler/Dialect/VM/Conversion",
"//iree/compiler/Dialect/VM/Conversion/StandardToVM",
diff --git a/iree/compiler/Translation/CMakeLists.txt b/iree/compiler/Translation/CMakeLists.txt
index 823575e..bdaaa39 100644
--- a/iree/compiler/Translation/CMakeLists.txt
+++ b/iree/compiler/Translation/CMakeLists.txt
@@ -32,6 +32,7 @@
iree::compiler::Dialect::HAL::Conversion::HALToVM
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Transforms
+ iree::compiler::Dialect::Stream::Transforms
iree::compiler::Dialect::Util::Transforms
iree::compiler::Dialect::VM::Conversion
iree::compiler::Dialect::VM::Conversion::StandardToVM
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index 360a490..ed6b6c4 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Bindings/TFLite/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/TranslationFlags.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
@@ -29,6 +30,16 @@
namespace mlir {
namespace iree_compiler {

+// TODO(#7277): remove flag when on by default.
+static bool getExperimentalStreamsModeFromFlags() {
+ static llvm::cl::opt<bool> *enableFlag = new llvm::cl::opt<bool>{
+ "iree-experimental-streams",
+ llvm::cl::desc("Enables experimental stream dialect pipelines."),
+ llvm::cl::init(false),
+ };
+ return *enableFlag;
+}
+
static BindingOptions getBindingOptionsFromFlags() {
static llvm::cl::OptionCategory bindingOptionsCategory(
"IREE translation binding support options");
@@ -72,59 +83,6 @@
return options;
}

-// Performs initial dialect conversion to get the canonical input lowered into
-// the IREE execution/dataflow dialect.
-//
-// This will fail if we cannot support the input yet. The hope is that any
-// error that happens after this point is either backend-specific (like
-// unsupported SPIR-V lowering) or a bug.
-static LogicalResult convertToFlowModule(ModuleOp moduleOp) {
- PassManager passManager(moduleOp.getContext());
- mlir::applyPassManagerCLOptions(passManager);
- mlir::applyDefaultTimingPassManagerCLOptions(passManager);
- passManager.addInstrumentation(std::make_unique<PassTracing>());
- IREE::Flow::TransformOptions flowOptions;
- IREE::Flow::buildFlowTransformPassPipeline(passManager, flowOptions);
- if (failed(passManager.run(moduleOp))) {
- return moduleOp.emitError()
- << "failed to run flow transformation pass pipeline";
- }
- return success();
-}
-
-// Runs the flow->HAL transform pipeline to lower a flow module and compile
-// executables for the specified target backends.
-static LogicalResult convertToHALModule(
- ModuleOp moduleOp, IREE::HAL::TargetOptions executableOptions) {
- PassManager passManager(moduleOp.getContext());
- mlir::applyPassManagerCLOptions(passManager);
- mlir::applyDefaultTimingPassManagerCLOptions(passManager);
- passManager.addInstrumentation(std::make_unique<PassTracing>());
- IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
- if (failed(passManager.run(moduleOp))) {
- return moduleOp.emitError()
- << "failed to run HAL transformation pass pipeline";
- }
- return success();
-}
-
-// Converts the lowered module to a canonical vm.module containing only vm ops.
-// This uses patterns to convert from standard ops and other dialects to their
-// vm ABI form.
-static LogicalResult convertToVMModule(ModuleOp moduleOp,
- IREE::VM::TargetOptions targetOptions) {
- PassManager passManager(moduleOp.getContext());
- mlir::applyPassManagerCLOptions(passManager);
- mlir::applyDefaultTimingPassManagerCLOptions(passManager);
- passManager.addInstrumentation(std::make_unique<PassTracing>());
- IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
- if (failed(passManager.run(moduleOp))) {
- return moduleOp.emitError()
- << "failed to run VM transformation pass pipeline";
- }
- return success();
-}
-
void buildIREEVMTransformPassPipeline(
BindingOptions bindingOptions, InputDialectOptions inputOptions,
IREE::HAL::TargetOptions executableOptions,
@@ -147,11 +105,18 @@
break;
}

- IREE::Flow::TransformOptions flowOptions;
+ bool enableNewStreamsDialect = getExperimentalStreamsModeFromFlags();
buildCommonInputConversionPassPipeline(passManager);
+ IREE::Flow::TransformOptions flowOptions;
+ flowOptions.streamFormation = !enableNewStreamsDialect;
IREE::Flow::buildFlowTransformPassPipeline(passManager, flowOptions);

- IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
+ if (enableNewStreamsDialect) {
+ IREE::Stream::TransformOptions streamOptions;
+ IREE::Stream::buildStreamTransformPassPipeline(passManager, streamOptions);
+ } else {
+ IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
+ }
IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
passManager.addPass(IREE::Util::createDropCompilerHintsPass());
}
@@ -241,6 +206,7 @@
#endif // IREE_HAVE_EMITC_DIALECT
void registerIREEVMTranslationFlags() {
+ getExperimentalStreamsModeFromFlags();
getBindingOptionsFromFlags();
getInputDialectOptionsFromFlags();
}
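
A note on the flag pattern in IREEVM.cpp above: the llvm::cl::opt is heap-allocated inside a function-local static so it is constructed, and thereby registered with the command-line parser, only on first use, and it is intentionally never destroyed to sidestep static destruction ordering. Calling the getter from registerIREEVMTranslationFlags() forces registration before cl::ParseCommandLineOptions runs, so --help lists the flag. Reduced to its essentials, with a stand-in flag name rather than the real one:

#include "llvm/Support/CommandLine.h"

// Sketch only: the lazily-registered, intentionally-leaked CL flag pattern.
// "example-feature" is hypothetical; the patch itself registers
// -iree-experimental-streams.
static bool getHypotheticalFeatureFromFlags() {
  static llvm::cl::opt<bool> *enableFlag = new llvm::cl::opt<bool>{
      "example-feature",
      llvm::cl::desc("Enables a hypothetical feature."),
      llvm::cl::init(false),
  };
  return *enableFlag;  // re-reads the parsed value on every call
}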
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 152abc4..3f6ee17 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -113,6 +113,7 @@
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Stream/IR",
+ "//iree/compiler/Dialect/Stream/Transforms",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/Util/Transforms",
"//iree/compiler/Dialect/VM/Analysis",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 6050277..bf2c5d1 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -228,6 +228,7 @@
iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Stream::Transforms
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Dialect::VM::Analysis
diff --git a/iree/tools/init_iree_passes.h b/iree/tools/init_iree_passes.h
index 7c8daa3..bc5909a 100644
--- a/iree/tools/init_iree_passes.h
+++ b/iree/tools/init_iree_passes.h
@@ -21,6 +21,7 @@
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/Analysis/TestPasses.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
@@ -48,6 +49,7 @@
Shape::registerShapePasses();
IREE::Flow::registerFlowPasses();
IREE::HAL::registerHALPasses();
+ IREE::Stream::registerStreamPasses();
IREE::Util::registerTransformPasses();
IREE::VM::registerVMPasses();
IREE::VM::registerVMAnalysisTestPasses();
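
registerStreamPasses() makes the individual stream passes visible to tools such as iree-opt. If one also wanted the whole pipeline runnable by name from the command line, the usual MLIR mechanism is a PassPipelineRegistration; a sketch under the assumption that buildStreamTransformPassPipeline takes an OpPassManager plus TransformOptions (as the IREEVM.cpp call above suggests), with a made-up pipeline name:

#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "mlir/Pass/PassRegistry.h"

// Sketch only: exposing the stream pipeline under a textual name so
// `iree-opt --example-stream-transform-pipeline` could run it end to end.
// The name is hypothetical; registerStreamPasses() may already register
// something equivalent internally.
static mlir::PassPipelineRegistration<> exampleStreamPipeline(
    "example-stream-transform-pipeline",
    "Runs the full stream dialect transformation pipeline.",
    [](mlir::OpPassManager &passManager) {
      mlir::iree_compiler::IREE::Stream::TransformOptions options;
      mlir::iree_compiler::IREE::Stream::buildStreamTransformPassPipeline(
          passManager, options);
    });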