Moving OutlineConstantsPass to flow and adding parameter support. (#17303)
This allows us to hide the stream dialect attributes from frontends and
use inline flow.tensor.constant ops with parameter attrs. Outlining now
also properly preserves hoistable attrs such as stream affinity. By
running IPO at the head of the flow pipeline we gain fusion
opportunities for hoisted (by user or by global opt) constants and then
we clean up the inlined constants at the end of flow so that the stream
dialect can handle all values consistently.
diff --git a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
index 589b710..574e0e5 100644
--- a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
+++ b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
@@ -113,7 +113,7 @@
):
return 1
if not inv.execute_text_pass_pipeline(
- "iree-util-outline-constants, iree-util-strip-and-splat-constants"
+ "iree-flow-outline-constants, iree-util-strip-and-splat-constants"
):
return 2
write_output(inv, output, args)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
index 6528cf5..aac974b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
@@ -241,6 +241,9 @@
def FLOW_NamedParameterAttr :
AttrDef<Flow_Dialect, "NamedParameter", [
TypedAttrInterface,
+ DeclareAttrInterfaceMethods<Util_SizedStorageAttr, [
+ "getStorageSize",
+ ]>,
]> {
let mnemonic = "parameter.named";
let summary = [{named parameter referenced an optional scope and key}];
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
index 91a970a..58c853f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
@@ -310,4 +310,21 @@
p << "\"" << keyAttr.getValue() << "\"";
}
+//===----------------------------------------------------------------------===//
+// #flow.parameter.named<...>
+//===----------------------------------------------------------------------===//
+
+int64_t NamedParameterAttr::getStorageSize() const {
+ if (auto configAttr = getConfig()) {
+ if (auto lengthAttr = configAttr.getAs<IntegerAttr>("length")) {
+ return lengthAttr.getInt();
+ }
+ }
+ if (auto shapedType = llvm::dyn_cast<ShapedType>(getType())) {
+ return IREE::Util::getRoundedPhysicalStorageSize(shapedType);
+ } else {
+ return IREE::Util::getTypePhysicalStorageBitWidth(getType());
+ }
+}
+
} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 1b2339e..2c40f82 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -57,6 +57,7 @@
"InjectTensorTracing.cpp",
"InsertDispatchDebugTargets.cpp",
"InterchangeTransposeGenericOps.cpp",
+ "OutlineConstants.cpp",
"OutlineDispatchExterns.cpp",
"OutlineDispatchRegions.cpp",
"Passes.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 7b14e04..007891c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -57,6 +57,7 @@
"InjectTensorTracing.cpp"
"InsertDispatchDebugTargets.cpp"
"InterchangeTransposeGenericOps.cpp"
+ "OutlineConstants.cpp"
"OutlineDispatchExterns.cpp"
"OutlineDispatchRegions.cpp"
"Passes.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
index 1980b3d..5d6c3fe 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
@@ -178,19 +178,25 @@
OpBuilder &moduleBuilder,
Explorer &explorer) {
std::string name = namePrefix + "_arg" + std::to_string(arg.getArgNumber());
- return TypeSwitch<Type, IREE::Util::GlobalOp>(arg.getType())
- .Case([&](IREE::HAL::BufferViewType type) {
- return createImportBufferViewGlobalOp(name, arg, symbolTable,
+ auto globalOp =
+ TypeSwitch<Type, IREE::Util::GlobalOp>(arg.getType())
+ .Case([&](IREE::HAL::BufferViewType type) {
+ return createImportBufferViewGlobalOp(name, arg, symbolTable,
+ moduleBuilder, explorer);
+ })
+ .Case([&](IREE::HAL::BufferType type) {
+ return createExportBufferGlobalOp(name, arg, symbolTable,
moduleBuilder, explorer);
- })
- .Case([&](IREE::HAL::BufferType type) {
- return createExportBufferGlobalOp(name, arg, symbolTable, moduleBuilder,
- explorer);
- })
- .Default([&](Type type) {
- return createPrimitiveDefaultGlobalOp(name, arg.getLoc(), type,
- symbolTable, moduleBuilder);
- });
+ })
+ .Default([&](Type type) {
+ return createPrimitiveDefaultGlobalOp(name, arg.getLoc(), type,
+ symbolTable, moduleBuilder);
+ });
+ if (globalOp) {
+ // Prevent globals from folding so that we have unique buffers for each arg.
+ globalOp->setAttr("flow.unique_id", moduleBuilder.getStringAttr(name));
+ }
+ return globalOp;
}
static LogicalResult
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp
new file mode 100644
index 0000000..0e1562b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp
@@ -0,0 +1,169 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Utils/StringUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::Flow {
+
+#define GEN_PASS_DEF_OUTLINECONSTANTSPASS
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"
+
+namespace {
+
+// Returns true if |value| is worth outlining (large, etc).
+static bool isOutlinableValue(Attribute value) {
+ if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(value)) {
+ // Don't outline splats - we want those fused.
+ return !elementsAttr.isSplat();
+ } else if (isa<IREE::Flow::NamedParameterAttr>(value)) {
+ // Always outline parameter constants.
+ return true;
+ }
+ return false;
+}
+
+struct ConstantDef {
+ Operation *op;
+ Type type;
+ TypedAttr value;
+};
+
+// Returns a list of all constant-like shaped data ops in the module.
+static SmallVector<ConstantDef> findConstantsInModule(mlir::ModuleOp moduleOp) {
+ SmallVector<ConstantDef> results;
+ for (auto callableOp : moduleOp.getOps<CallableOpInterface>()) {
+ auto *region = callableOp.getCallableRegion();
+ if (!region)
+ continue;
+ region->walk([&](Operation *op) {
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
+ if (isOutlinableValue(constantOp.getValue())) {
+ results.push_back(ConstantDef{
+ constantOp,
+ constantOp.getType(),
+ constantOp.getValue(),
+ });
+ }
+ } else if (auto constantOp = dyn_cast<IREE::Flow::TensorConstantOp>(op)) {
+ if (isOutlinableValue(constantOp.getValue())) {
+ results.push_back(ConstantDef{
+ constantOp,
+ constantOp.getType(),
+ constantOp.getValue(),
+ });
+ }
+ }
+ });
+ }
+ return results;
+}
+
+// Returns the operation containing |childOp| that is a direct child of
+// |ancestorOp|. May return |childOp|.
+static Operation *getParentInOp(Operation *childOp, Operation *ancestorOp) {
+ assert(childOp != ancestorOp && "child can't be its own ancestor");
+ do {
+ auto *parentOp = childOp->getParentOp();
+ if (parentOp == ancestorOp)
+ return childOp;
+ childOp = parentOp;
+ } while (childOp);
+ assert(false && "child must be nested under ancestor");
+ return nullptr;
+}
+
+static std::string getConstantName(ConstantDef &def) {
+ std::string str;
+ llvm::raw_string_ostream os(str);
+ if (auto parameterAttr =
+ dyn_cast<IREE::Flow::NamedParameterAttr>(def.value)) {
+ os << "__parameter_";
+ if (parameterAttr.getScope() && !parameterAttr.getScope().empty())
+ os << parameterAttr.getScope().getValue() << "_";
+ os << parameterAttr.getKey().getValue() << "_";
+ } else {
+ os << "__constant_";
+ }
+ def.type.print(os);
+ str = sanitizeSymbolName(str);
+ if (str.substr(str.size() - 1) == "_")
+ str = str.substr(0, str.size() - 1); // strip trailing _
+ return str;
+}
+
+//===----------------------------------------------------------------------===//
+// --iree-flow-outline-constants
+//===----------------------------------------------------------------------===//
+
+struct OutlineConstantsPass
+ : public IREE::Flow::impl::OutlineConstantsPassBase<OutlineConstantsPass> {
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ if (moduleOp.getBody()->empty())
+ return;
+
+ SymbolTable moduleSymbols(moduleOp);
+
+ // Create all top-level util.globals from constants in the module.
+ std::vector<std::pair<Operation *, IREE::Util::GlobalOp>> replacements;
+ for (auto &def : findConstantsInModule(moduleOp)) {
+ // Position the global immediately preceding the top-level op that
+ // contains the constant.
+ OpBuilder moduleBuilder(&moduleOp.getBody()->front());
+ auto parentFuncOp = getParentInOp(def.op, moduleOp);
+ if (parentFuncOp)
+ moduleBuilder.setInsertionPoint(parentFuncOp);
+
+ // New immutable global takes the constant attribute in its specified
+ // encoding.
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ def.op->getLoc(), getConstantName(def), /*isMutable=*/false, def.type,
+ def.value);
+ globalOp.setPrivate();
+ IREE::Util::HoistableAttrInterface::gatherHoistableAttrs(def.op,
+ globalOp);
+ moduleSymbols.insert(globalOp); // uniques name
+ replacements.emplace_back(def.op, globalOp);
+
+ // Prevent the variable from being re-inlined if the canonicalizer runs.
+ // By the time we've outlined things here we are sure we want them
+ // outlined even if the user runs an arbitrary number of passes between
+ // now and when we may use that information (HAL constant pooling, etc).
+ globalOp.setInliningPolicyAttr(
+ moduleBuilder.getAttr<IREE::Util::InlineNeverAttr>());
+ }
+
+ // Replace all of the constants with lookups for the new variables.
+ for (auto pair : replacements) {
+ auto *originalOp = pair.first;
+ auto globalOp = pair.second;
+ OpBuilder builder(moduleOp.getContext());
+ builder.setInsertionPoint(originalOp);
+ auto loadOp = globalOp.createLoadOp(originalOp->getLoc(), builder);
+ loadOp.setGlobalImmutable(true);
+ originalOp->getResult(0).replaceAllUsesWith(
+ loadOp.getLoadedGlobalValue());
+ originalOp->erase();
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 941d66e..698629c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -121,6 +121,36 @@
using FunctionLikeNest =
MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static void addCleanupPatterns(OpPassManager &passManager) {
+ FunctionLikeNest(passManager)
+ // Standard MLIR cleanup.
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+ // Simplify util.global accesses; this can help with data flow tracking as
+ // redundant store-loads are removed.
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+
+ // Cleanup and canonicalization of util.global (and other util ops).
+ passManager.addPass(IREE::Util::createApplyPatternsPass());
+ passManager.addPass(IREE::Util::createFoldGlobalsPass());
+ passManager.addPass(IREE::Util::createFuseGlobalsPass());
+
+ // Large IPO pass. Note that this can introduce a significant amount of
+ // duplication/inlined constants and we'll want to ensure we're running
+ // cleanup again after (this entire set of patterns is run in a fixed-point
+ // iteration to do that).
+ passManager.addPass(IREE::Util::createIPOPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
// 1. Do some simple elementwise op fusion. This could be skipped,
// but could reduce the surface area of ops to handle later.
@@ -240,6 +270,22 @@
clEnableFusePaddingIntoLinalgConsumerOps}));
}
+ {
+ // We run these under a fixed-point iteration such that we can perform
+ // inter-procedural, intra-procedural, and canonicalization as separably
+ // verifiable/reusable passes. IPO will fold duplicate arguments/results
+ // and inline constants to allow the local optimizations to work more
+ // effectively.
+ OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());
+
+ // IPO and other cleanups.
+ addCleanupPatterns(ipoPipeline);
+
+ // Run fixed-point iteration on the IPO pipeline.
+ passManager.addPass(
+ IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline)));
+ }
+
addDispatchRegionCreationPasses(passManager, transformOptions);
FunctionLikeNest(passManager)
@@ -325,9 +371,28 @@
// passes above after we've formed dispatch regions.
.addPass(IREE::Flow::createInjectTensorTracingPass)
// Cleanup the IR after we are done.
- .addPass(IREE::Flow::createCleanupTensorShapesPass)
- .addPass(mlir::createCanonicalizerPass)
- .addPass(mlir::createCSEPass);
+ .addPass(IREE::Flow::createCleanupTensorShapesPass);
+
+ {
+ // We run these under a fixed-point iteration such that we can perform
+ // inter-procedural, intra-procedural, and canonicalization as separably
+ // verifiable/reusable passes. IPO will fold duplicate arguments/results
+ // and inline constants to allow the local optimizations to work more
+ // effectively.
+ OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());
+
+ // Turn all constant ops into global variables and fix up the IR.
+ // As many locations change and constants are deduplicated we'll end up with
+ // a lot of extraneous IR (mostly global loads) and clean those up here.
+ ipoPipeline.addPass(IREE::Flow::createOutlineConstantsPass());
+
+ // IPO and other cleanups.
+ addCleanupPatterns(ipoPipeline);
+
+ // Run fixed-point iteration on the IPO pipeline.
+ passManager.addPass(
+ IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline)));
+ }
// Cleanup executable contents.
{
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 2f3360a..592d015 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -391,6 +391,20 @@
];
}
+def OutlineConstantsPass :
+ Pass<"iree-flow-outline-constants", "mlir::ModuleOp"> {
+ let summary = "Outlines tensor constants into util.globals at the module level.";
+ let description = [{
+ Outlines tensor constants throughout the program into globals initialized
+ with stream ops.
+ }];
+ let dependentDialects = [
+ "mlir::arith::ArithDialect",
+ "IREE::Flow::FlowDialect",
+ "IREE::Util::UtilDialect",
+ ];
+}
+
def OutlineDispatchExternsPass :
Pass<"iree-flow-outline-dispatch-externs", "mlir::ModuleOp"> {
let summary = "Outlines external dispatches into executables.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index bb23266..b9ce61b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -42,6 +42,7 @@
"inject_tensor_tracing.mlir",
"insert_dispatch_debug_targets.mlir",
"interchange_transpose_generic_ops.mlir",
+ "outline_constants.mlir",
"outline_dispatch_externs.mlir",
"outline_dispatch_regions.mlir",
"pad_fusion_with_consumer.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 4fda873..e7df9de 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -40,6 +40,7 @@
"inject_tensor_tracing.mlir"
"insert_dispatch_debug_targets.mlir"
"interchange_transpose_generic_ops.mlir"
+ "outline_constants.mlir"
"outline_dispatch_externs.mlir"
"outline_dispatch_regions.mlir"
"pad_fusion_with_consumer.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
index dfb93d5..03f5b56 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
@@ -11,8 +11,8 @@
util.return %3 : !hal.buffer_view
}
-// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view
-// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view
+// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {
+// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {
// CHECK: util.func public @simpleMul_benchmark() attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "entry"}} {
// CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : !hal.buffer_view
@@ -37,12 +37,12 @@
util.return %5 : i32
}
-// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} = 0 : i32
-// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {inlining_policy = #util.inline.never} = 0 : i32
+// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {{{.+}}} = 0 : i32
+// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {{{.+}}} = 0 : i32
// CHECK: util.func public @while_benchmark()
-// CHECK-DAG: %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : i32
-// CHECK-DAG: %[[ARG1:.+]] = util.global.load @[[GLOBAL_ARG1]] : i32
+// CHECK-DAG: %[[ARG0:.+]] = util.global.load immutable @[[GLOBAL_ARG0]] : i32
+// CHECK-DAG: %[[ARG1:.+]] = util.global.load immutable @[[GLOBAL_ARG1]] : i32
// CHECK: %[[RET0:.+]] = util.call @while(%[[ARG0]], %[[ARG1]])
// CHECK: util.optimization_barrier %[[RET0]] : i32
// CHECK: util.return
@@ -59,7 +59,7 @@
util.return %2 : !hal.buffer_view
}
-// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view
+// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {
// CHECK: util.initializer {
// CHECK-DAG: %[[SPLAT:.+]] = flow.tensor.splat %c0_i32
// CHECK-DAG: %[[EXPORT:.+]] = hal.tensor.export %[[SPLAT]] : tensor<4xi32> -> !hal.buffer_view
@@ -99,14 +99,14 @@
util.return %2 : !hal.buffer_view
}
-// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {inlining_policy = #util.inline.never} : !hal.buffer_view
+// CHECK: util.global private @[[GLOBAL_ARG0:.+]] {
// CHECK: util.initializer {
// CHECK-DAG: %[[SPLAT0:.+]] = flow.tensor.splat %c0_i32
// CHECK-DAG: %[[EXPORT0:.+]] = hal.tensor.export %[[SPLAT0]] : tensor<4xi32> -> !hal.buffer_view
// CHECK-DAG: %[[DNO0:.+]] = util.optimization_barrier %[[EXPORT0]]
// CHECK-NEXT: util.global.store %[[DNO0]], @[[GLOBAL_ARG0]]
-// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {inlining_policy = #util.inline.never} : !hal.buffer
+// CHECK: util.global private @[[GLOBAL_ARG1:.+]] {
// CHECK: util.initializer {
// CHECK-DAG: %[[SPLAT1:.+]] = flow.tensor.splat %c0_i32
// CHECK-DAG: %[[EXPORT1:.+]] = hal.tensor.export %[[SPLAT1]] : tensor<4xi32> -> !hal.buffer
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir
new file mode 100644
index 0000000..e3db1b6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir
@@ -0,0 +1,79 @@
+// RUN: iree-opt --split-input-file --iree-flow-outline-constants %s | FileCheck %s
+
+// Tests that we don't outline splats (as we want them to be transients).
+
+// CHECK-LABEL: @splatConstant
+util.func @splatConstant() {
+ // CHECK-DAG: = arith.constant dense<1> : tensor<512x128xi32>
+ %arith_cst = arith.constant dense<1> : tensor<512x128xi32>
+ // CHECK-DAG: = flow.tensor.constant dense<1> : tensor<512x128xi32>
+ %flow_cst = flow.tensor.constant dense<1> : tensor<512x128xi32>
+ util.return
+}
+
+// -----
+
+// Tests that constant parameters are outlined.
+
+// CHECK: util.global private @__parameter_scope_key_tensor_4x2xi32 {inlining_policy = #util.inline.never} = #flow.parameter.named<"scope"::"key"> : tensor<4x2xi32>
+// CHECK-LABEL: @parameterConstant
+util.func @parameterConstant() {
+ // CHECK: = util.global.load immutable @__parameter_scope_key_tensor_4x2xi32 : tensor<4x2xi32>
+ %cst = flow.tensor.constant #flow.parameter.named<"scope"::"key"> : tensor<4x2xi32>
+ util.return
+}
+
+// -----
+
+// Tests that multiple constants will be hoisted and named uniquely.
+
+// CHECK: util.global private @__constant_tensor_2xf32 {inlining_policy = #util.inline.never} = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
+// CHECK-NEXT: util.global private @__constant_tensor_2xf32_0 {inlining_policy = #util.inline.never} = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
+// CHECK-NEXT: util.func private @denseConstants
+util.func private @denseConstants() {
+ // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xf32 : tensor<2xf32>
+ %cst_0 = arith.constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
+ // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xf32_0 : tensor<2xf32>
+ %cst_1 = flow.tensor.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
+ util.return
+}
+
+// -----
+
+// Tests that constants are outlined to the module scope above their use to
+// preserve ordering of existing functions/globals.
+
+// CHECK: util.func private @external_func
+util.func private @external_func()
+// CHECK-NEXT: util.global private @__constant_tensor_2xi32
+// CHECK-NEXT: util.func private @func_0()
+util.func private @func_0() {
+ // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xi32
+ %cst_0 = arith.constant dense<[0, 1]> : tensor<2xi32>
+ util.return
+}
+
+// CHECK: util.global private @existing_global
+util.global private @existing_global : tensor<4xf32>
+// CHECK-NEXT: util.global private @__constant_tensor_3xi32
+// CHECK-NEXT: util.func private @func_1()
+util.func private @func_1() {
+ // CHECK-NEXT: = util.global.load immutable @__constant_tensor_3xi32
+ %cst_1 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
+ util.return
+}
+
+// -----
+
+// Tests that any hoistable attrs are propagated to the outlined globals.
+
+// CHECK: util.global private @__constant_tensor_2xi32
+// CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]>
+// CHECK-NEXT: util.func private @set_affinity
+util.func private @set_affinity() attributes {
+ stream.affinity = #hal.affinity.queue<[0]>
+} {
+ // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xi32
+ %cst = arith.constant dense<[0, 1]> : tensor<2xi32>
+ util.return
+}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
index 3116a86..fbc0e51 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel
@@ -21,6 +21,7 @@
"PatternUtils.h",
],
deps = [
+ "//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
index 472fdb9..05bbb79 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt
@@ -23,6 +23,7 @@
MLIRIR
MLIRTransformUtils
MLIRTransforms
+ iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Stream::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 93f1aef..0942148 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -46,7 +46,8 @@
getContext(), IREE::Stream::Lifetime::Constant);
auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
- constantOp.getLoc(), constantType, constantOp.getValue(),
+ constantOp.getLoc(), constantType,
+ convertAttributeToStream(constantOp.getValue()),
TypeAttr::get(constantOp.getType()), ValueRange{}, affinityAttr);
// Transfer to unknown lifetime.
@@ -94,7 +95,8 @@
getContext(), IREE::Stream::Lifetime::Constant);
auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
- constantOp.getLoc(), constantType, constantOp.getValue(),
+ constantOp.getLoc(), constantType,
+ convertAttributeToStream(constantOp.getValue()),
TypeAttr::get(resultType), dynamicDims, affinityAttr);
// Transfer to unknown lifetime.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
index 1eb0dec..9a1272f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir
@@ -17,7 +17,7 @@
// CHECK-DAG: %[[CST:.+]] = stream.tensor.constant : tensor<4x2xi32> in !stream.resource<constant> = #stream.parameter.named<"scope"::"key"> : tensor<4x2xi32>
// CHECK-DAG: %[[SIZE:.+]] = stream.resource.size %[[CST]] : !stream.resource<constant>
// CHECK-DAG: %[[TRANSFER:.+]] = stream.async.transfer %[[CST]] : !stream.resource<constant>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
- %cst = flow.tensor.constant #stream.parameter.named<"scope"::"key"> : tensor<4x2xi32>
+ %cst = flow.tensor.constant #flow.parameter.named<"scope"::"key"> : tensor<4x2xi32>
// CHECK: util.return %[[TRANSFER]], %[[SIZE]]
util.return %cst : tensor<4x2xi32>
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
index 46c1c83..6bb26f1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp
@@ -6,11 +6,23 @@
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
namespace mlir::iree_compiler {
+TypedAttr convertAttributeToStream(TypedAttr attr) {
+ if (!attr)
+ return {};
+ if (auto parameterAttr = dyn_cast<IREE::Flow::NamedParameterAttr>(attr)) {
+ return IREE::Stream::NamedParameterAttr::get(
+ attr.getContext(), parameterAttr.getType(), parameterAttr.getScope(),
+ parameterAttr.getKey(), parameterAttr.getConfig());
+ }
+ return attr;
+}
+
void expandResourceOperand(Location loc, Value operand,
SmallVectorImpl<Value> &newOperands,
OpBuilder &builder) {
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
index a7a864f..fd9249e 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h
@@ -13,6 +13,10 @@
namespace mlir::iree_compiler {
+// Converts a supported attribute type to the corresponding stream dialect
+// value. Returns the provided value if it is natively supported.
+TypedAttr convertAttributeToStream(TypedAttr attr);
+
void expandResourceOperand(Location loc, Value operand,
SmallVectorImpl<Value> &newOperands,
OpBuilder &builder);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
index 127bd74..5ff99f7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
@@ -34,7 +35,7 @@
auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp);
auto newOp = rewriter.create<IREE::Stream::TensorConstantOp>(
constantOp.getLoc(), constantType,
- llvm::cast<ElementsAttr>(constantOp.getValue()),
+ convertAttributeToStream(constantOp.getValue()),
TypeAttr::get(constantOp.getType()),
/*result_encoding_dims=*/ValueRange{}, affinityAttr);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
index a8e8d07..4d1fa5f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
@@ -249,7 +249,8 @@
affinityAttr);
} else {
initialValue = rewriter.create<IREE::Stream::TensorConstantOp>(
- globalOp.getLoc(), resourceOp.getType(), initialValueAttr,
+ globalOp.getLoc(), resourceOp.getType(),
+ convertAttributeToStream(initialValueAttr),
TypeAttr::get(globalOp.getType()),
/*result_encoding_dims=*/ValueRange{}, affinityAttr);
initialValueSize = rewriter.create<IREE::Stream::ResourceSizeOp>(
@@ -404,7 +405,7 @@
[&](IREE::Util::GlobalOp op) {
return typeConverter.isLegal(op.getType()) &&
(!op.getInitialValueAttr() ||
- !llvm::isa<TensorType>(op.getInitialValueAttr().getType()));
+ !isExpandedType(op.getInitialValueAttr().getType()));
});
conversionTarget.addDynamicallyLegalOp<IREE::Util::GlobalAddressOp>(
[&](IREE::Util::GlobalAddressOp op) {
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 1ee27c0..c994e65 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -11,6 +11,7 @@
include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
@@ -1323,7 +1324,7 @@
}];
let arguments = (ins
- AnyAttr:$value,
+ TypedAttrInterface:$value,
TypeAttr:$result_encoding,
Stream_ShapeDynamicDims:$result_encoding_dims,
OptionalAttr<Stream_AffinityAttr>:$affinity
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index 5800ded..a0861e7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -64,15 +64,6 @@
// propagation or fusion that needs to happen first.
addCleanupPatterns(passManager);
- // Turn all constant ops into global variables and fix up the IR.
- // As many locations change and constants are deduplicated we'll end up with
- // a lot of extraneous IR (mostly global loads) and clean those up here.
- passManager.addPass(IREE::Util::createOutlineConstantsPass());
-
- // Perform cleanup after constant simplification as more canonicalizers may be
- // able to kick in.
- addCleanupPatterns(passManager);
-
//----------------------------------------------------------------------------
// Conversion
//----------------------------------------------------------------------------
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
index 5135610..03cc0f4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
@@ -887,6 +887,46 @@
};
//===----------------------------------------------------------------------===//
+// IREE::Util::Hoistable*Interface
+//===----------------------------------------------------------------------===//
+
+// Walks |fromOp| and up to gather all dialect attributes that want to be
+// hoisted along with it. If the same named attribute is present on multiple
+// ancestors only the most narrowly scoped value will be used.
+// static
+void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp,
+ NamedAttrList &dialectAttrs) {
+ for (auto attr : fromOp->getDialectAttrs()) {
+ if (auto hoistableAttr = llvm::dyn_cast<IREE::Util::HoistableAttrInterface>(
+ attr.getValue())) {
+ if (hoistableAttr.shouldAttachToHoistedOps() &&
+ !dialectAttrs.get(attr.getName())) {
+ dialectAttrs.push_back(attr);
+ }
+ }
+ }
+ if (auto *parentOp = fromOp->getParentOp())
+ gatherHoistableAttrs(parentOp, dialectAttrs);
+}
+
+// static
+void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp,
+ Operation *toOp) {
+ // Get the attributes specified on the target op first as those take
+ // precedence over any from ancestors. We also want to preserve any
+ // non-hoistable attrs when we reassign the dialect attrs.
+ NamedAttrList dialectAttrs;
+ for (auto attr : toOp->getDialectAttrs())
+ dialectAttrs.push_back(attr);
+
+ // Gather attributes from the op and its parents, only adding ones not already
+ // set on the op.
+ HoistableAttrInterface::gatherHoistableAttrs(fromOp, dialectAttrs);
+
+ toOp->setDialectAttrs(dialectAttrs);
+}
+
+//===----------------------------------------------------------------------===//
// IREE::Util::UtilDialect
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index c2155e4..eaae0d4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -1157,6 +1157,17 @@
}]
>,
];
+
+ let extraClassDeclaration = [{
+ // Walks |fromOp| and up to gather all dialect attributes that want to be
+ // hoisted along with it. If the same named attribute is present on multiple
+ // ancestors only the most narrowly scoped value will be used.
+ static void gatherHoistableAttrs(Operation *fromOp,
+ NamedAttrList &dialectAttrs);
+
+ // Copies any hoistable attributes from the source op to the target op.
+ static void gatherHoistableAttrs(Operation *fromOp, Operation *toOp);
+ }];
}
def Util_HoistableOpInterface : OpInterface<"HoistableOpInterface"> {
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
index 769e6bf..ce557ec 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
@@ -27,7 +27,6 @@
"HoistIntoGlobals.cpp",
"IPO.cpp",
"ImportResources.cpp",
- "OutlineConstants.cpp",
"PassDetail.h",
"Passes.cpp",
"Patterns.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index 85dade3..d07b205 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -30,7 +30,6 @@
"HoistIntoGlobals.cpp"
"IPO.cpp"
"ImportResources.cpp"
- "OutlineConstants.cpp"
"PassDetail.h"
"Passes.cpp"
"Patterns.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
index 2467a5d..aa360cb 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "iree/compiler/Utils/StringUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
@@ -31,22 +32,15 @@
// Maps an original value in the program to the symbol name of a global.
using HoistedValueMap = llvm::DenseMap<Value, GlobalOp>;
-// Walks |fromOp| and up to gather all dialect attributes that want to be
-// hoisted along with it. If the same named attribute is present on multiple
-// ancestors only the most narrowly scoped value will be used.
-static void gatherHoistableAttrs(Operation *fromOp,
- NamedAttrList &dialectAttrs) {
- for (auto attr : fromOp->getDialectAttrs()) {
- if (auto hoistableAttr =
- dyn_cast<IREE::Util::HoistableAttrInterface>(attr.getValue())) {
- if (hoistableAttr.shouldAttachToHoistedOps() &&
- !dialectAttrs.get(attr.getName())) {
- dialectAttrs.push_back(attr);
- }
- }
- }
- if (auto *parentOp = fromOp->getParentOp())
- gatherHoistableAttrs(parentOp, dialectAttrs);
+static std::string getHoistedName(Type type) {
+ std::string str;
+ llvm::raw_string_ostream os(str);
+ os << "__hoisted_";
+ type.print(os);
+ str = sanitizeSymbolName(str);
+ if (str.substr(str.size() - 1) == "_")
+ str = str.substr(0, str.size() - 1); // strip trailing _
+ return str;
}
// Hoist expressions into globals. It is not expected that such a greedy
@@ -130,22 +124,27 @@
OpBuilder builder(&getContext());
for (auto [originalValue, globalOp] : hoistedMap) {
builder.setInsertionPointAfterValue(originalValue);
- Value load = globalOp.createLoadOp(globalOp->getLoc(), builder)
- .getLoadedGlobalValue();
+ auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder);
+ if (!originalValue.getDefiningOp()
+ ->getParentOfType<IREE::Util::InitializerOpInterface>()) {
+ loadOp.setGlobalImmutable(true);
+ }
+ Value loadedValue = loadOp.getLoadedGlobalValue();
// Call user hook to cast back to the original type.
if (auto hoistableType = dyn_cast<IREE::Util::HoistableTypeInterface>(
originalValue.getType())) {
- load = hoistableType.decodeStorageType(builder, load.getLoc(),
- originalValue.getType(), load);
+ loadedValue = hoistableType.decodeStorageType(
+ builder, loadedValue.getLoc(), originalValue.getType(),
+ loadedValue);
}
- if (load.getType() != originalValue.getType()) {
+ if (loadedValue.getType() != originalValue.getType()) {
getOperation().emitError()
<< "Unresolved conflict between casted global of type "
- << load.getType() << " and original type "
+ << loadedValue.getType() << " and original type "
<< originalValue.getType();
return signalPassFailure();
}
- originalValue.replaceAllUsesWith(load);
+ originalValue.replaceAllUsesWith(loadedValue);
}
cleanupDeadOps(constExprs);
}
@@ -168,7 +167,8 @@
// Gather any dialect attributes we may need to preserve.
auto *topLevelOp = getTopLevelOp(originalValue.getDefiningOp());
NamedAttrList dialectAttrs;
- gatherHoistableAttrs(topLevelOp, dialectAttrs);
+ IREE::Util::HoistableAttrInterface::gatherHoistableAttrs(topLevelOp,
+ dialectAttrs);
// No existing mapping - create a new global.
OpBuilder moduleBuilder(topLevelOp);
@@ -269,12 +269,12 @@
// functions for setting the preferred storage type.
auto hoistableType =
dyn_cast<IREE::Util::HoistableTypeInterface>(globalType);
- // Get the preferred global storage type.
if (hoistableType) {
+ // Allow the storage type of the global to differ from the local type.
globalType = hoistableType.getPreferredStorageType();
}
auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, "hoisted", false, globalType);
+ loc, getHoistedName(globalType), false, globalType);
moduleSymbols.insert(globalOp);
SymbolTable::setSymbolVisibility(globalOp,
SymbolTable::Visibility::Private);
@@ -290,17 +290,16 @@
clonedResult.print(llvm::dbgs());
llvm::dbgs() << "\n";
});
- // Cast to the preferred global storage type.
if (hoistableType) {
+ // Allow casting to the global type if it differs from the local type.
clonedResult = hoistableType.encodeStorageType(
initializerBuilder, clonedResult.getLoc(), globalType,
clonedResult);
}
if (clonedResult.getType() != globalType) {
- globalOp.emitError()
- << "Unresolved conflict between global of type " << globalType
- << " and stored type " << clonedResult.getType();
- return failure();
+ return globalOp.emitError()
+ << "unresolved conflict between global of type " << globalType
+ << " and stored type " << clonedResult.getType();
}
globalOp.createStoreOp(loc, clonedResult, initializerBuilder);
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
deleted file mode 100644
index 023031a..0000000
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
+++ /dev/null
@@ -1,124 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <utility>
-
-#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
-#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
-#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler::IREE::Util {
-
-// Returns true if |value| is worth outlining (large, etc).
-static bool isOutlinableValue(Attribute value) {
- if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(value)) {
- // Don't outline splats - we want those fused.
- return !elementsAttr.isSplat();
- }
- return false;
-}
-
-struct ConstantDef {
- Operation *op;
- Type type;
- ElementsAttr value;
-};
-
-// Returns a list of all constant-like shaped data ops in the module.
-static SmallVector<ConstantDef> findConstantsInModule(mlir::ModuleOp moduleOp) {
- SmallVector<ConstantDef> results;
- for (auto callableOp : moduleOp.getOps<CallableOpInterface>()) {
- auto *region = callableOp.getCallableRegion();
- if (!region)
- continue;
- for (auto &block : *region) {
- for (auto &op : block.getOperations()) {
- if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
- if (isOutlinableValue(constantOp.getValue())) {
- results.push_back(ConstantDef{
- constantOp,
- constantOp.getType(),
- llvm::cast<ElementsAttr>(constantOp.getValue()),
- });
- }
- }
- }
- }
- }
- return results;
-}
-
-class OutlineConstantsPass : public OutlineConstantsBase<OutlineConstantsPass> {
-public:
- OutlineConstantsPass() = default;
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<mlir::arith::ArithDialect>();
- registry.insert<IREE::Util::UtilDialect>();
- }
-
- void runOnOperation() override {
- auto moduleOp = getOperation();
- if (moduleOp.getBody()->empty())
- return;
-
- SymbolTable moduleSymbols(moduleOp);
- std::string baseName = "_constant";
-
- // Create all top-level util.globals from constants in the module.
- OpBuilder moduleBuilder(&moduleOp.getBody()->front());
- std::vector<std::pair<Operation *, IREE::Util::GlobalOp>> replacements;
- for (auto &def : findConstantsInModule(moduleOp)) {
- // New immutable global takes the constant attribute in its specified
- // encoding.
- auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
- def.op->getLoc(), baseName, /*isMutable=*/false, def.type, def.value);
- globalOp.setPrivate();
- moduleSymbols.insert(globalOp); // uniques name
- replacements.emplace_back(def.op, globalOp);
-
- // Prevent the variable from being re-inlined if the canonicalizer runs.
- // By the time we've outlined things here we are sure we want them
- // outlined even if the user runs an arbitrary number of passes between
- // now and when we may use that information (HAL constant pooling, etc).
- globalOp.setInliningPolicyAttr(
- moduleBuilder.getAttr<IREE::Util::InlineNeverAttr>());
- }
-
- // Replace all of the constants with lookups for the new variables.
- for (auto pair : replacements) {
- auto *originalOp = pair.first;
- auto globalOp = pair.second;
- OpBuilder builder(moduleOp.getContext());
- builder.setInsertionPoint(originalOp);
- auto loadOp = globalOp.createLoadOp(originalOp->getLoc(), builder);
-
- Value replacement;
- if (auto constantOp = dyn_cast<arith::ConstantOp>(originalOp)) {
- // Directly replace constant with global constant value.
- replacement = loadOp.getLoadedGlobalValue();
- } else {
- assert(false && "unhandled constant op type");
- }
-
- originalOp->getResult(0).replaceAllUsesWith(replacement);
- originalOp->erase();
- }
- }
-};
-
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createOutlineConstantsPass() {
- return std::make_unique<OutlineConstantsPass>();
-}
-
-} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 402ee2e..a2aa226 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -30,7 +30,6 @@
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFoldGlobalsPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFuseGlobalsPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createIPOPass();
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createOutlineConstantsPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubrangesPass();
std::unique_ptr<OperationPass<void>> createSimplifyGlobalAccessesPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>>
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index 191da7d..f3072d1 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -65,14 +65,6 @@
}];
}
-def OutlineConstants :
- Pass<"iree-util-outline-constants", "mlir::ModuleOp"> {
- let summary = "Outlines tensor constants into util.globals at the module level.";
- let constructor = [{
- mlir::iree_compiler::IREE::Util::createOutlineConstantsPass()
- }];
-}
-
def PropagateSubranges : Pass<"iree-util-propagate-subranges", "mlir::ModuleOp"> {
let summary = "Propagates resource subranges across the program.";
let constructor = [{
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index df0835a..6c608cc 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -28,7 +28,6 @@
"hoist_into_globals_linalg.mlir",
"import_resources.mlir",
"ipo.mlir",
- "outline_constants.mlir",
"patterns.mlir",
"promote_bf16_to_f32.mlir",
"promote_f16_to_f32.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index 53d73ec..2ed4d40 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -26,7 +26,6 @@
"hoist_into_globals_linalg.mlir"
"import_resources.mlir"
"ipo.mlir"
- "outline_constants.mlir"
"patterns.mlir"
"promote_bf16_to_f32.mlir"
"promote_f16_to_f32.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
index 28a7c20..37f0605 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir
@@ -16,7 +16,7 @@
%1 = arith.constant 1 : i32
// CHECK-NOT: arith.constant
// CHECK-NOT: iree_unregistered.const_expr
- // CHECK: %[[VAL:.*]] = util.global.load @[[HOISTED_SYM]] : i32
+ // CHECK: %[[VAL:.*]] = util.global.load immutable @[[HOISTED_SYM]] : i32
// CHECK: util.return %[[VAL]]
%2 = "iree_unregistered.const_expr"(%0, %1) : (i32, i32) -> i32
util.return %2 : i32
@@ -141,8 +141,8 @@
// CHECK: util.func public @main
util.func public @main() -> (i32, i32, i32) {
- // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load @[[HOISTED_0]] : i32
- // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load @[[HOISTED_1]] : i32
+ // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load immutable @[[HOISTED_0]] : i32
+ // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load immutable @[[HOISTED_1]] : i32
// CHECK-DAG: %[[RESULT:.*]] = "iree_unregistered.var_expr"(%[[LOAD_HOISTED_1]])
// CHECK: util.return %[[LOAD_HOISTED_0]], %[[LOAD_HOISTED_1]], %[[RESULT]]
%0 = arith.constant 0 : i32
@@ -171,7 +171,7 @@
// CHECK: }
// CHECK: util.func public @main
util.func public @main() -> i32 {
- // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load @[[HOISTED_0]] : i32
+ // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load immutable @[[HOISTED_0]] : i32
// CHECK-DAG: %[[RESULT:.*]] = "iree_unregistered.var_expr"(%[[LOAD_HOISTED_0]])
// CHECK: util.return %[[RESULT]]
%0 = arith.constant 0 : i32
@@ -201,8 +201,8 @@
// CHECK: }
// CHECK: util.func public @main
util.func public @main() -> (i32) {
- // CHECK: %[[LOAD_HOISTED:.*]] = util.global.load @[[HOISTED]] : i32
- // CHECK: %[[RESULT:.*]] = "iree_unregistered.non_leaf_const_expr"(%hoisted)
+ // CHECK: %[[LOAD_HOISTED:.*]] = util.global.load immutable @[[HOISTED]] : i32
+ // CHECK: %[[RESULT:.*]] = "iree_unregistered.non_leaf_const_expr"(%[[LOAD_HOISTED]])
// CHECK: util.return %[[RESULT]]
%0 = arith.constant 0 : i32
%1 = arith.constant 1 : i32
@@ -236,7 +236,7 @@
%1 = arith.constant 1 : i32
// CHECK-NOT: arith.constant
// CHECK-NOT: iree_unregistered.const_expr
- // CHECK: %[[VAL:.*]] = util.global.load @[[HOISTED_SYM]] : i32
+ // CHECK: %[[VAL:.*]] = util.global.load immutable @[[HOISTED_SYM]] : i32
// CHECK: util.return %[[VAL]]
%2 = "iree_unregistered.const_expr"(%0) ({
^bb0(%inner0 : i32):
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir
index cb41b5e..1760ff2 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals_linalg.mlir
@@ -27,7 +27,7 @@
linalg.yield %42 : f32
} -> tensor<5x6xf32>
- // CHECK: %[[RESULT:.*]] = util.global.load @[[HOISTED]] : tensor<5x6xf32>
+ // CHECK: %[[RESULT:.*]] = util.global.load immutable @[[HOISTED]] : tensor<5x6xf32>
// CHECK: util.return %[[RESULT]]
util.return %3 : tensor<5x6xf32>
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/outline_constants.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/outline_constants.mlir
deleted file mode 100644
index 76b27da..0000000
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/outline_constants.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-util-outline-constants %s | FileCheck %s
-
-// CHECK-LABEL: @scalarConstant
-util.func @scalarConstant() {
- // CHECK: = arith.constant 0 : i32
- %cst = arith.constant 0 : i32
- util.return
-}
-
-// -----
-
-// CHECK-LABEL: @splatConstant
-util.func @splatConstant() {
- // CHECK: = arith.constant dense<1.200000e+00> : tensor<512x128xf32>
- %cst = arith.constant dense<1.2> : tensor<512x128xf32>
- util.return
-}
-
-// -----
-
-// CHECK: util.global private @_constant {inlining_policy = #util.inline.never} = dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
-// CHECK-NEXT: util.global private @_constant_0 {inlining_policy = #util.inline.never} = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]> : tensor<8xf32>
-// CHECK-LABEL: @denseConstants
-util.func @denseConstants() {
- // CHECK: = util.global.load @_constant : tensor<2xf32>
- %cst_0 = arith.constant dense<[0.0287729427, 0.0297581609]> : tensor<2xf32>
- // CHECK-NEXT: = util.global.load @_constant_0 : tensor<8xf32>
- %cst_1 = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]> : tensor<8xf32>
- util.return
-}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
index 61d3b6c..65ad8c9 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -100,6 +100,8 @@
// If there are stores mark the global as mutable.
globalInfo->op.setGlobalMutable(!globalInfo->getStores().empty());
}
+ for (auto loadOp : globalInfo->getLoads())
+ loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable());
});
for (auto *deadOp : deadOps)
deadOp->erase();
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir
index 757fef2..c7930a5 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir
@@ -35,6 +35,43 @@
// -----
+// CHECK-LABEL: @mutability_change
+vm.module @mutability_change {
+ // CHECK: vm.global.i32 private @g0
+ vm.global.i32 private mutable @g0 : i32 = 0 : i32
+ // CHECK: vm.global.i32 private mutable @g1 : i32
+ vm.global.i32 private mutable @g1 = 123 : i32
+ // CHECK: vm.global.i32 private mutable @g2 : i32
+ vm.global.i32 private @g2 : i32
+
+ vm.initializer {
+ %c456 = vm.const.i32 456
+ vm.global.store.i32 %c456, @g2 : i32
+ vm.return
+ }
+
+ // CHECK: vm.func public @func
+ vm.func public @func() {
+ // CHECK: vm.global.load.i32 immutable @g0
+ vm.global.load.i32 @g0 : i32
+ // CHECK: vm.global.load.i32 @g1
+ vm.global.load.i32 @g1 : i32
+ // CHECK: vm.global.load.i32 @g2
+ vm.global.load.i32 immutable @g2 : i32
+ vm.return
+ }
+
+ // CHECK: vm.func private @__init() {
+ // CHECK-NEXT: %c123 = vm.const.i32 123
+ // CHECK-NEXT: vm.global.store.i32 %c123, @g1
+ // CHECK-NEXT: %c456 = vm.const.i32 456
+ // CHECK-NEXT: vm.global.store.i32 %c456, @g2
+ // CHECK-NEXT: vm.return
+ // CHECK-NEXT: }
+}
+
+// -----
+
// CHECK-LABEL: @init_ref
vm.module @init_ref {
// CHECK: vm.global.ref private mutable @g0 : !vm.ref<?>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
index 46705a9..4082bbf 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
@@ -10,7 +10,7 @@
// CHECK: util.return
// CHECK: util.func public @main() -> tensor<64xi4>
- // CHECK: %[[GLOBAL_LD:.+]] = util.global.load @{{.*}} : tensor<32xi8>
+ // CHECK: %[[GLOBAL_LD:.+]] = util.global.load immutable @{{.*}} : tensor<32xi8>
// CHECK: %[[ORIG_VAL:.+]] = flow.tensor.bitcast %[[GLOBAL_LD]] : tensor<32xi8> -> tensor<64xi4>
// CHECK: util.return %[[ORIG_VAL]]
util.func public @main() -> (tensor<64xi4>) {
@@ -48,9 +48,9 @@
// CHECK: util.func public @main
util.func public @main() -> (tensor<8xi4>, tensor<8xi4>, tensor<8xi4>) {
- // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load @[[HOISTED_0]] : tensor<4xi8>
+ // CHECK-DAG: %[[LOAD_HOISTED_0:.*]] = util.global.load immutable @[[HOISTED_0]] : tensor<4xi8>
// CHECK-DAG: %[[BITCAST_0:.*]] = flow.tensor.bitcast %[[LOAD_HOISTED_0]] : tensor<4xi8> -> tensor<8xi4>
- // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load @[[HOISTED_1]] : tensor<4xi8>
+ // CHECK-DAG: %[[LOAD_HOISTED_1:.*]] = util.global.load immutable @[[HOISTED_1]] : tensor<4xi8>
// CHECK-DAG: %[[BITCAST_1:.*]] = flow.tensor.bitcast %[[LOAD_HOISTED_1]] : tensor<4xi8> -> tensor<8xi4>
// CHECK-DAG: %[[RESULT:.*]] = "iree_unregistered.var_expr"(%[[BITCAST_1]])
// CHECK: util.return %[[BITCAST_0]], %[[BITCAST_1]], %[[RESULT]]
@@ -128,7 +128,7 @@
// CHECK-NEXT: flow.tensor.constant #flow.parameter.named<"compile"::"constant_hoisted_0">
// CHECK-NEXT: "iree_unregistered.const_expr"
util.func public @main() -> tensor<i32> {
- // CHECK: util.global.load @[[HOISTED]]
+ // CHECK: util.global.load immutable @[[HOISTED]]
%parameter = flow.tensor.constant #flow.parameter.named<"compile"::"constant_hoisted_0"> : tensor<i32>
%0 = "iree_unregistered.const_expr"(%parameter) : (tensor<i32>) -> tensor<i32>
util.return %0 : tensor<i32>
@@ -142,7 +142,7 @@
// CHECK-LABEL: @hoist_dialect_attrs
module @hoist_dialect_attrs {
- // CHECK: util.global private @[[HOISTED:[a-z0-9]+]]
+ // CHECK: util.global private @[[HOISTED:[a-z0-9_]+]]
// CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]>
// CHECK: util.initializer
// CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]>
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
index 4fb4155..6717abb 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/BUILD.bazel
@@ -28,7 +28,7 @@
],
deps = [
":PassesIncGen",
- "//compiler/src/iree/compiler/Dialect/Stream/IR",
+ "//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//runtime/src/iree/base",
"//runtime/src/iree/hal",
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
index 7c27305..41704dd 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/CMakeLists.txt
@@ -33,7 +33,7 @@
MLIRTransformUtils
MLIRTransforms
iree::base
- iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Util::IR
iree::hal
iree::io::file_handle
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
index 4e67fb6..539a850 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -174,7 +174,7 @@
}
// Change the global to reference the parameter.
- globalOp.setGlobalInitialValue(IREE::Stream::NamedParameterAttr::get(
+ globalOp.setGlobalInitialValue(IREE::Flow::NamedParameterAttr::get(
context, globalOp.getGlobalType(), StringAttr::get(context, scope),
StringAttr::get(context, name), DictionaryAttr()));
}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp
index 4a6dfc2..a319589 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp
@@ -4,7 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -24,6 +25,34 @@
namespace {
+// Returns a set of all unique parameters and the locations using them.
+static SmallVector<std::pair<Location, IREE::Flow::NamedParameterAttr>>
+findAllParameters(ModuleOp moduleOp) {
+ llvm::MapVector<IREE::Flow::NamedParameterAttr, SmallVector<Location>>
+ parameterAttrs;
+ moduleOp.walk([&](Operation *op) {
+ if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) {
+ if (auto parameterAttr =
+ dyn_cast_if_present<IREE::Flow::NamedParameterAttr>(
+ globalOp.getGlobalInitialValue())) {
+ parameterAttrs[parameterAttr].push_back(globalOp.getLoc());
+ }
+ } else if (auto constantOp = dyn_cast<IREE::Flow::TensorConstantOp>(op)) {
+ if (auto parameterAttr =
+ dyn_cast_if_present<IREE::Flow::NamedParameterAttr>(
+ constantOp.getValue())) {
+ parameterAttrs[parameterAttr].push_back(constantOp.getLoc());
+ }
+ }
+ });
+ SmallVector<std::pair<Location, IREE::Flow::NamedParameterAttr>> locAttrs;
+ for (auto &entry : parameterAttrs) {
+ locAttrs.push_back(std::make_pair(
+ FusedLoc::get(moduleOp.getContext(), entry.second), entry.first));
+ }
+ return locAttrs;
+}
+
static Attribute getDefaultSplatAttr(Type elementType) {
// Today we only support basic types where 0 bits represent zeros - that lets
// us just splat out the right number of bits.
@@ -50,20 +79,16 @@
if (failed(builder))
return signalPassFailure();
- // Walk the globals in the module.
- for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOpInterface>()) {
+ // Find all parameters in the module and add them to the builder.
+ // NOTE: there may be no parameters but we still will create the archive
+ // so that subsequent tooling that tries to load it succeeds.
+ auto parameterAttrs = findAllParameters(moduleOp);
+ for (auto [loc, parameterAttr] : parameterAttrs) {
// Only support types we can meaningfully generate splats for.
- auto shapedType = dyn_cast<ShapedType>(globalOp.getGlobalType());
+ auto shapedType = dyn_cast<ShapedType>(parameterAttr.getType());
if (!shapedType)
continue;
- // Look for globals backed by parameters.
- auto parameterAttr =
- dyn_cast_if_present<IREE::Stream::NamedParameterAttr>(
- globalOp.getGlobalInitialValue());
- if (!parameterAttr)
- continue;
-
// TODO: support other patterns/generators.
auto elementAttr = getDefaultSplatAttr(shapedType.getElementType());
@@ -71,7 +96,7 @@
SmallVector<char, IREE_IO_PARAMETER_MAX_SPLAT_PATTERN_LENGTH> pattern;
llvm::raw_svector_ostream os(pattern);
if (failed(IREE::Util::SerializableAttrInterface::serializeSplatValue(
- globalOp.getLoc(), elementAttr,
+ loc, elementAttr,
/*count=*/1, llvm::endianness::little, os))) {
return signalPassFailure();
}
@@ -94,10 +119,6 @@
}
}
- // Early exit if no parameter backed globals present.
- if (iree_io_parameter_archive_builder_is_empty(builder->get()))
- return;
-
// Create the parameter archive file.
auto fileStreamIndexOr =
createParameterIndex(moduleOp, std::move(builder.value()), filePath);
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp
index 8a3888e..288495e 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -255,7 +255,7 @@
// Import the given |parameterAttr| from |entry|.
static FailureOr<TypedAttr>
importParameter(StringRef fullName, ShapedType globalType,
- IREE::Stream::NamedParameterAttr parameterAttr,
+ IREE::Flow::NamedParameterAttr parameterAttr,
const iree_io_parameter_index_entry_t *entry) {
switch (entry->type) {
case IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT:
@@ -292,7 +292,7 @@
for (auto &key : keys)
importKeys.insert(key);
auto shouldImportParameter =
- [&](IREE::Stream::NamedParameterAttr parameterAttr) -> bool {
+ [&](IREE::Flow::NamedParameterAttr parameterAttr) -> bool {
// Always try to import explicitly named parameters.
if (importKeys.contains(parameterAttr.getKey().getValue()))
return true; // key match
@@ -308,9 +308,8 @@
// Find all parameters and try to import them.
for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOpInterface>()) {
// Only inspect parameter globals.
- auto parameterAttr =
- dyn_cast_if_present<IREE::Stream::NamedParameterAttr>(
- globalOp.getGlobalInitialValue());
+ auto parameterAttr = dyn_cast_if_present<IREE::Flow::NamedParameterAttr>(
+ globalOp.getGlobalInitialValue());
if (!parameterAttr)
continue;
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
index 00d7dad..603d23f 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/Passes.td
@@ -14,7 +14,7 @@
let summary = "Exports all global constants to an archive file when "
"they are larger than the specified minimum size.";
let dependentDialects = [
- "IREE::Stream::StreamDialect",
+ "IREE::Flow::FlowDialect",
"IREE::Util::UtilDialect",
];
let options = [
@@ -42,7 +42,7 @@
Pass<"iree-io-import-parameters", "mlir::ModuleOp"> {
let summary = "Imports parameters from an archive file.";
let dependentDialects = [
- "IREE::Stream::StreamDialect",
+ "IREE::Flow::FlowDialect",
"IREE::Util::UtilDialect",
];
let options = [
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir
index 81ba9da..7f54149 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/export_parameters.mlir
@@ -1,42 +1,42 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-io-export-parameters{path="opt=%t.irpa" minimum-size=0})" %s | FileCheck %s
// RUN: iree-dump-parameters --parameters=%t.irpa | FileCheck %s --check-prefix=DUMP
-// CHECK: util.global private @constant_scalar_i1 = #stream.parameter.named<"opt"::"constant_scalar_i1"> : tensor<i1>
+// CHECK: util.global private @constant_scalar_i1 = #flow.parameter.named<"opt"::"constant_scalar_i1"> : tensor<i1>
// DUMP: - | - | 1 | `constant_scalar_i1`
util.global private @constant_scalar_i1 = dense<true> : tensor<i1>
-// CHECK-NEXT: util.global private @constant_dense_2xi1 = #stream.parameter.named<"opt"::"constant_dense_2xi1"> : tensor<2xi1>
+// CHECK-NEXT: util.global private @constant_dense_2xi1 = #flow.parameter.named<"opt"::"constant_dense_2xi1"> : tensor<2xi1>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `constant_dense_2xi1`
util.global private @constant_dense_2xi1 = dense<[true, false]> : tensor<2xi1>
-// CHECK-NEXT: util.global private @constant_dense_3xi4 = #stream.parameter.named<"opt"::"constant_dense_3xi4"> : tensor<3xi4>
+// CHECK-NEXT: util.global private @constant_dense_3xi4 = #flow.parameter.named<"opt"::"constant_dense_3xi4"> : tensor<3xi4>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `constant_dense_3xi4`
util.global private @constant_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4>
-// CHECK-NEXT: util.global private @constant_dense_2xi8 = #stream.parameter.named<"opt"::"constant_dense_2xi8"> : tensor<2xi8>
+// CHECK-NEXT: util.global private @constant_dense_2xi8 = #flow.parameter.named<"opt"::"constant_dense_2xi8"> : tensor<2xi8>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `constant_dense_2xi8`
util.global private @constant_dense_2xi8 = dense<[4, 5]> : tensor<2xi8>
-// CHECK-NEXT: util.global private @constant_dense_2xf32 = #stream.parameter.named<"opt"::"constant_dense_2xf32"> : tensor<2xf32>
+// CHECK-NEXT: util.global private @constant_dense_2xf32 = #flow.parameter.named<"opt"::"constant_dense_2xf32"> : tensor<2xf32>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 8 | `constant_dense_2xf32`
util.global private @constant_dense_2xf32 = dense<[11.0, 12.0]> : tensor<2xf32>
-// CHECK-NEXT: util.global private @constant_splat_2xf32 = #stream.parameter.named<"opt"::"constant_splat_2xf32"> : tensor<2xf32>
+// CHECK-NEXT: util.global private @constant_splat_2xf32 = #flow.parameter.named<"opt"::"constant_splat_2xf32"> : tensor<2xf32>
// DUMP-NEXT: - | - | 8 | `constant_splat_2xf32`
util.global private @constant_splat_2xf32 = dense<11.0> : tensor<2xf32>
-// CHECK-NEXT: util.global private mutable @mutable_scalar_i1 = #stream.parameter.named<"opt"::"mutable_scalar_i1"> : tensor<i1>
+// CHECK-NEXT: util.global private mutable @mutable_scalar_i1 = #flow.parameter.named<"opt"::"mutable_scalar_i1"> : tensor<i1>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 1 | `mutable_scalar_i1`
util.global private mutable @mutable_scalar_i1 = dense<true> : tensor<i1>
-// CHECK-NEXT: util.global private mutable @mutable_dense_3xi4 = #stream.parameter.named<"opt"::"mutable_dense_3xi4"> : tensor<3xi4>
+// CHECK-NEXT: util.global private mutable @mutable_dense_3xi4 = #flow.parameter.named<"opt"::"mutable_dense_3xi4"> : tensor<3xi4>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 2 | `mutable_dense_3xi4`
util.global private mutable @mutable_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4>
-// CHECK-NEXT: util.global private mutable @mutable_dense_2xf32 = #stream.parameter.named<"opt"::"mutable_dense_2xf32"> : tensor<2xf32>
+// CHECK-NEXT: util.global private mutable @mutable_dense_2xf32 = #flow.parameter.named<"opt"::"mutable_dense_2xf32"> : tensor<2xf32>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 8 | `mutable_dense_2xf32`
util.global private mutable @mutable_dense_2xf32 = dense<[11.0, 12.0]> : tensor<2xf32>
-// CHECK-NEXT: util.global private mutable @mutable_splat_2xf32 = #stream.parameter.named<"opt"::"mutable_splat_2xf32"> : tensor<2xf32>
+// CHECK-NEXT: util.global private mutable @mutable_splat_2xf32 = #flow.parameter.named<"opt"::"mutable_splat_2xf32"> : tensor<2xf32>
// DUMP-NEXT: {{[0-9]+}} | {{[0-9]+}} | 8 | `mutable_splat_2xf32`
util.global private mutable @mutable_splat_2xf32 = dense<11.0> : tensor<2xf32>
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir
index 4944c01..215d1cc 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/generate_splat_parameter_archive.mlir
@@ -3,20 +3,27 @@
// CHECK: util.global private @tensor_i1
// DUMP: - | - | 1 | `tensor_i1`
-util.global private @tensor_i1 = #stream.parameter.named<"opt"::"tensor_i1"> : tensor<i1>
+util.global private @tensor_i1 = #flow.parameter.named<"opt"::"tensor_i1"> : tensor<i1>
// CHECK-NEXT: util.global private @tensor_i8
// DUMP-NEXT: - | - | 1 | `tensor_i8`
-util.global private @tensor_i8 = #stream.parameter.named<"opt"::"tensor_i8"> : tensor<i8>
+util.global private @tensor_i8 = #flow.parameter.named<"opt"::"tensor_i8"> : tensor<i8>
// CHECK-NEXT: util.global private @tensor_1x2xi32
// DUMP-NEXT: - | - | 8 | `tensor_1x2xi32`
-util.global private @tensor_1x2xi32 = #stream.parameter.named<"opt"::"tensor_1x2xi32"> : tensor<1x2xi32>
+util.global private @tensor_1x2xi32 = #flow.parameter.named<"opt"::"tensor_1x2xi32"> : tensor<1x2xi32>
// CHECK-NEXT: util.global private @tensor_2x2xi4
// DUMP-NEXT: - | - | 2 | `tensor_2x2xi4`
-util.global private @tensor_2x2xi4 = #stream.parameter.named<"opt"::"tensor_2x2xi4"> : tensor<2x2xi4>
+util.global private @tensor_2x2xi4 = #flow.parameter.named<"opt"::"tensor_2x2xi4"> : tensor<2x2xi4>
// CHECK-NEXT: util.global private @tensor_3xi4
// DUMP-NEXT: - | - | 2 | `tensor_3xi4`
-util.global private @tensor_3xi4 = #stream.parameter.named<"opt"::"tensor_3xi4"> : tensor<3xi4>
+util.global private @tensor_3xi4 = #flow.parameter.named<"opt"::"tensor_3xi4"> : tensor<3xi4>
+
+util.func private @function() {
+ // CHECK: flow.tensor.constant
+ // DUMP-NEXT: - | - | 4 | `inline`
+ flow.tensor.constant #flow.parameter.named<"opt"::"inline"> : tensor<4xi8>
+ util.return
+}
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir
index 075b9d0..45eb7b7 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/import_parameters.mlir
@@ -1,15 +1,15 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-io-export-parameters{path="opt=%t.irpa" minimum-size=0},iree-io-import-parameters{paths="opt=%t.irpa"})" %s | FileCheck %s
// NOTE: packed types not supported for import yet.
-// CHECK: util.global private @constant_scalar_i1 = #stream.parameter.named
+// CHECK: util.global private @constant_scalar_i1 = #flow.parameter.named
util.global private @constant_scalar_i1 = dense<true> : tensor<i1>
// NOTE: packed types not supported for import yet.
-// CHECK: util.global private @constant_dense_2xi1 = #stream.parameter.named
+// CHECK: util.global private @constant_dense_2xi1 = #flow.parameter.named
util.global private @constant_dense_2xi1 = dense<[true, false]> : tensor<2xi1>
// NOTE: packed types not supported for import yet.
-// CHECK: util.global private @constant_dense_3xi4 = #stream.parameter.named
+// CHECK: util.global private @constant_dense_3xi4 = #flow.parameter.named
util.global private @constant_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4>
// CHECK: util.global private @constant_dense_2xi8 = dense<[4, 5]> : tensor<2xi8>
@@ -22,11 +22,11 @@
util.global private @constant_splat_2xf32 = dense<1.100000e+01> : tensor<2xf32>
// NOTE: packed types not supported for import yet.
-// CHECK: util.global private mutable @mutable_scalar_i1 = #stream.parameter.named
+// CHECK: util.global private mutable @mutable_scalar_i1 = #flow.parameter.named
util.global private mutable @mutable_scalar_i1 = dense<true> : tensor<i1>
// NOTE: packed types not supported for import yet.
-// CHECK: util.global private mutable @mutable_dense_3xi4 = #stream.parameter.named
+// CHECK: util.global private mutable @mutable_dense_3xi4 = #flow.parameter.named
util.global private mutable @mutable_dense_3xi4 = dense<[4, 5, 6]> : tensor<3xi4>
// CHECK: util.global private mutable @mutable_dense_2xf32 = dense<[1.100000e+01, 1.200000e+01]> : tensor<2xf32>
diff --git a/runtime/bindings/python/tests/io_runtime_test.py b/runtime/bindings/python/tests/io_runtime_test.py
index 417b99d..3d07e30 100644
--- a/runtime/bindings/python/tests/io_runtime_test.py
+++ b/runtime/bindings/python/tests/io_runtime_test.py
@@ -17,10 +17,10 @@
TEST_COMPILED = None
TEST_ASM = r"""
-util.global private @a0 = #stream.parameter.named<"a"::"a0"> : tensor<4xi64>
-util.global private @a1 = #stream.parameter.named<"a"::"a1"> : tensor<4xi64>
-util.global private @b0 = #stream.parameter.named<"b"::"b0"> : tensor<8xi64>
-util.global private @b1 = #stream.parameter.named<"b"::"b1"> : tensor<8xi64>
+util.global private @a0 = #flow.parameter.named<"a"::"a0"> : tensor<4xi64>
+util.global private @a1 = #flow.parameter.named<"a"::"a1"> : tensor<4xi64>
+util.global private @b0 = #flow.parameter.named<"b"::"b0"> : tensor<8xi64>
+util.global private @b1 = #flow.parameter.named<"b"::"b1"> : tensor<8xi64>
func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) {
%a0 = util.global.load @a0 : tensor<4xi64>
%a1 = util.global.load @a1 : tensor<4xi64>
diff --git a/tests/e2e/parameters/generate_splat_archive.mlir b/tests/e2e/parameters/generate_splat_archive.mlir
index a7b3607..74a6725 100644
--- a/tests/e2e/parameters/generate_splat_archive.mlir
+++ b/tests/e2e/parameters/generate_splat_archive.mlir
@@ -13,10 +13,10 @@
// CHECK-LABEL: EXEC @main
// CHECK: 1x2xi32=[0 0]
-util.global private @array_global_0 = #stream.parameter.named<"scope"::"global_0"> : tensor<1x2xi32>
-util.global private @dense_global_1 = #stream.parameter.named<"scope"::"global_1"> : tensor<2x2xi32>
-util.global private @dense_global_2 = #stream.parameter.named<"scope"::"global_2"> : tensor<1x2xi32>
-util.global private @dense_global_3 = #stream.parameter.named<"scope"::"global_3"> : tensor<2x2xi32>
+util.global private @array_global_0 = #flow.parameter.named<"scope"::"global_0"> : tensor<1x2xi32>
+util.global private @dense_global_1 = #flow.parameter.named<"scope"::"global_1"> : tensor<2x2xi32>
+util.global private @dense_global_2 = #flow.parameter.named<"scope"::"global_2"> : tensor<1x2xi32>
+util.global private @dense_global_3 = #flow.parameter.named<"scope"::"global_3"> : tensor<2x2xi32>
func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2xi32> {
%cst = arith.constant 0 : i32
%3 = util.global.load @array_global_0 : tensor<1x2xi32>
diff --git a/tools/test/parameters_scoped.mlir b/tools/test/parameters_scoped.mlir
index 2bc7c9e..9b294ab 100644
--- a/tools/test/parameters_scoped.mlir
+++ b/tools/test/parameters_scoped.mlir
@@ -14,10 +14,10 @@
// provide content for a single scope but not to have a single file provide
// content for multiple scopes. Since parameter keys only need to be unique
// within a scope this test could use the same name for both scopes if needed.
-util.global private @a0 = #stream.parameter.named<"a"::"a0"> : tensor<4xi64>
-util.global private @a1 = #stream.parameter.named<"a"::"a1"> : tensor<4xi64>
-util.global private @b0 = #stream.parameter.named<"b"::"b0"> : tensor<8xi64>
-util.global private @b1 = #stream.parameter.named<"b"::"b1"> : tensor<8xi64>
+util.global private @a0 = #flow.parameter.named<"a"::"a0"> : tensor<4xi64>
+util.global private @a1 = #flow.parameter.named<"a"::"a1"> : tensor<4xi64>
+util.global private @b0 = #flow.parameter.named<"b"::"b0"> : tensor<8xi64>
+util.global private @b1 = #flow.parameter.named<"b"::"b1"> : tensor<8xi64>
func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) {
%a0 = util.global.load @a0 : tensor<4xi64>
%a1 = util.global.load @a1 : tensor<4xi64>
diff --git a/tools/test/parameters_unscoped.mlir b/tools/test/parameters_unscoped.mlir
index 933cb77..715f686 100644
--- a/tools/test/parameters_unscoped.mlir
+++ b/tools/test/parameters_unscoped.mlir
@@ -12,10 +12,10 @@
// Simple named parameters with no scope. Parameter files are combined at
// runtime to allow for filesystem sharding while still providing a flat set of
// parameters in the compiler input.
-util.global private @a0 = #stream.parameter.named<"a0"> : tensor<4xi64>
-util.global private @a1 = #stream.parameter.named<"a1"> : tensor<4xi64>
-util.global private @b0 = #stream.parameter.named<"b0"> : tensor<8xi64>
-util.global private @b1 = #stream.parameter.named<"b1"> : tensor<8xi64>
+util.global private @a0 = #flow.parameter.named<"a0"> : tensor<4xi64>
+util.global private @a1 = #flow.parameter.named<"a1"> : tensor<4xi64>
+util.global private @b0 = #flow.parameter.named<"b0"> : tensor<8xi64>
+util.global private @b1 = #flow.parameter.named<"b1"> : tensor<8xi64>
func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) {
%a0 = util.global.load @a0 : tensor<4xi64>
%a1 = util.global.load @a1 : tensor<4xi64>